Skip to content

Commit f099e6e

Browse files
feat: Add fair unified memory pool (#1369)
## Which issue does this PR close? ## Rationale for this change Current Comet unified memory pool is a greedy pool. One thread (consumer) can take a large amount of memory that causes OOM for other threads, especially for aggregation. ## What changes are included in this PR? Added a fair version of unified memory pool similar to DataFusion `FairSpilPool` that caps the memory usage at `pool_size/num` The fair unified memory pool is the default for off-heap mode with this PR ## How are these changes tested? Exisiting tests
1 parent a1e6a39 commit f099e6e

File tree

8 files changed

+186
-5
lines changed

8 files changed

+186
-5
lines changed

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -504,8 +504,8 @@ object CometConf extends ShimCometConf {
504504
.doc(
505505
"The type of memory pool to be used for Comet native execution. " +
506506
"Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', " +
507-
"'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global', By default, " +
508-
"this config is 'greedy_task_shared'.")
507+
"'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global'. For off-heap " +
508+
"types are 'unified' and `fair_unified`.")
509509
.stringConf
510510
.createWithDefault("greedy_task_shared")
511511

docs/source/user-guide/configs.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ Comet provides the following configuration settings.
4848
| spark.comet.exec.hashJoin.enabled | Whether to enable hashJoin by default. | true |
4949
| spark.comet.exec.initCap.enabled | Whether to enable initCap by default. | false |
5050
| spark.comet.exec.localLimit.enabled | Whether to enable localLimit by default. | true |
51-
| spark.comet.exec.memoryPool | The type of memory pool to be used for Comet native execution. Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', 'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global', By default, this config is 'greedy_task_shared'. | greedy_task_shared |
51+
| spark.comet.exec.memoryPool | The type of memory pool to be used for Comet native execution. Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', 'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global'. For off-heap types are 'unified' and `fair_unified`. | greedy_task_shared |
5252
| spark.comet.exec.project.enabled | Whether to enable project by default. | true |
5353
| spark.comet.exec.replaceSortMergeJoin | Experimental feature to force Spark to replace SortMergeJoin with ShuffledHashJoin for improved performance. This feature is not stable yet. For more information, refer to the Comet Tuning Guide (https://datafusion.apache.org/comet/user-guide/tuning.html). | false |
5454
| spark.comet.exec.shuffle.compression.codec | The codec of Comet native shuffle used to compress shuffle data. lz4, zstd, and snappy are supported. Compression can be disabled by setting spark.shuffle.compress=false. | lz4 |

native/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

native/core/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ datafusion-comet-proto = { workspace = true }
7777
object_store = { workspace = true }
7878
url = { workspace = true }
7979
chrono = { workspace = true }
80+
parking_lot = "0.12.3"
8081

8182
[dev-dependencies]
8283
pprof = { version = "0.14.0", features = ["flamegraph"] }
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::{
19+
fmt::{Debug, Formatter, Result as FmtResult},
20+
sync::Arc,
21+
};
22+
23+
use jni::objects::GlobalRef;
24+
25+
use crate::{
26+
errors::CometResult,
27+
jvm_bridge::{jni_call, JVMClasses},
28+
};
29+
use datafusion::{
30+
common::DataFusionError,
31+
execution::memory_pool::{MemoryPool, MemoryReservation},
32+
};
33+
use datafusion_common::resources_err;
34+
use datafusion_execution::memory_pool::MemoryConsumer;
35+
use parking_lot::Mutex;
36+
37+
/// A DataFusion fair `MemoryPool` implementation for Comet. Internally this is
38+
/// implemented via delegating calls to [`crate::jvm_bridge::CometTaskMemoryManager`].
39+
pub struct CometFairMemoryPool {
40+
task_memory_manager_handle: Arc<GlobalRef>,
41+
pool_size: usize,
42+
state: Mutex<CometFairPoolState>,
43+
}
44+
45+
struct CometFairPoolState {
46+
used: usize,
47+
num: usize,
48+
}
49+
50+
impl Debug for CometFairMemoryPool {
51+
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
52+
let state = self.state.lock();
53+
f.debug_struct("CometFairMemoryPool")
54+
.field("pool_size", &self.pool_size)
55+
.field("used", &state.used)
56+
.field("num", &state.num)
57+
.finish()
58+
}
59+
}
60+
61+
impl CometFairMemoryPool {
62+
pub fn new(
63+
task_memory_manager_handle: Arc<GlobalRef>,
64+
pool_size: usize,
65+
) -> CometFairMemoryPool {
66+
Self {
67+
task_memory_manager_handle,
68+
pool_size,
69+
state: Mutex::new(CometFairPoolState { used: 0, num: 0 }),
70+
}
71+
}
72+
73+
fn acquire(&self, additional: usize) -> CometResult<i64> {
74+
let mut env = JVMClasses::get_env()?;
75+
let handle = self.task_memory_manager_handle.as_obj();
76+
unsafe {
77+
jni_call!(&mut env,
78+
comet_task_memory_manager(handle).acquire_memory(additional as i64) -> i64)
79+
}
80+
}
81+
82+
fn release(&self, size: usize) -> CometResult<()> {
83+
let mut env = JVMClasses::get_env()?;
84+
let handle = self.task_memory_manager_handle.as_obj();
85+
unsafe {
86+
jni_call!(&mut env, comet_task_memory_manager(handle).release_memory(size as i64) -> ())
87+
}
88+
}
89+
}
90+
91+
unsafe impl Send for CometFairMemoryPool {}
92+
unsafe impl Sync for CometFairMemoryPool {}
93+
94+
impl MemoryPool for CometFairMemoryPool {
95+
fn register(&self, _: &MemoryConsumer) {
96+
let mut state = self.state.lock();
97+
state.num = state
98+
.num
99+
.checked_add(1)
100+
.expect("unexpected amount of register happened");
101+
}
102+
103+
fn unregister(&self, _: &MemoryConsumer) {
104+
let mut state = self.state.lock();
105+
state.num = state
106+
.num
107+
.checked_sub(1)
108+
.expect("unexpected amount of unregister happened");
109+
}
110+
111+
fn grow(&self, reservation: &MemoryReservation, additional: usize) {
112+
self.try_grow(reservation, additional).unwrap();
113+
}
114+
115+
fn shrink(&self, reservation: &MemoryReservation, subtractive: usize) {
116+
if subtractive > 0 {
117+
let mut state = self.state.lock();
118+
let size = reservation.size();
119+
if size < subtractive {
120+
panic!("Failed to release {subtractive} bytes where only {size} bytes reserved")
121+
}
122+
self.release(subtractive)
123+
.unwrap_or_else(|_| panic!("Failed to release {} bytes", subtractive));
124+
state.used = state.used.checked_sub(subtractive).unwrap();
125+
}
126+
}
127+
128+
fn try_grow(
129+
&self,
130+
reservation: &MemoryReservation,
131+
additional: usize,
132+
) -> Result<(), DataFusionError> {
133+
if additional > 0 {
134+
let mut state = self.state.lock();
135+
let num = state.num;
136+
let limit = self.pool_size.checked_div(num).unwrap();
137+
let size = reservation.size();
138+
if limit < size + additional {
139+
return resources_err!(
140+
"Failed to acquire {additional} bytes where {size} bytes already reserved and the fair limit is {limit} bytes, {num} registered"
141+
);
142+
}
143+
144+
let acquired = self.acquire(additional)?;
145+
// If the number of bytes we acquired is less than the requested, return an error,
146+
// and hopefully will trigger spilling from the caller side.
147+
if acquired < additional as i64 {
148+
// Release the acquired bytes before throwing error
149+
self.release(acquired as usize)?;
150+
151+
return resources_err!(
152+
"Failed to acquire {} bytes, only got {} bytes. Reserved: {} bytes",
153+
additional,
154+
acquired,
155+
state.used
156+
);
157+
}
158+
state.used = state.used.checked_add(additional).unwrap();
159+
}
160+
Ok(())
161+
}
162+
163+
fn reserved(&self) -> usize {
164+
self.state.lock().used
165+
}
166+
}

native/core/src/execution/jni_api.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ use std::num::NonZeroUsize;
6363
use std::sync::Mutex;
6464
use tokio::runtime::Runtime;
6565

66+
use crate::execution::fair_memory_pool::CometFairMemoryPool;
6667
use crate::execution::operators::ScanExec;
6768
use crate::execution::shuffle::{read_ipc_compressed, CompressionCodec};
6869
use crate::execution::spark_plan::SparkPlan;
@@ -110,6 +111,7 @@ struct ExecutionContext {
110111
#[derive(PartialEq, Eq)]
111112
enum MemoryPoolType {
112113
Unified,
114+
FairUnified,
113115
Greedy,
114116
FairSpill,
115117
GreedyTaskShared,
@@ -295,11 +297,14 @@ fn parse_memory_pool_config(
295297
memory_limit: i64,
296298
memory_limit_per_task: i64,
297299
) -> CometResult<MemoryPoolConfig> {
300+
let pool_size = memory_limit as usize;
298301
let memory_pool_config = if use_unified_memory_manager {
299-
MemoryPoolConfig::new(MemoryPoolType::Unified, 0)
302+
match memory_pool_type.as_str() {
303+
"fair_unified" => MemoryPoolConfig::new(MemoryPoolType::FairUnified, pool_size),
304+
_ => MemoryPoolConfig::new(MemoryPoolType::Unified, 0),
305+
}
300306
} else {
301307
// Use the memory pool from DF
302-
let pool_size = memory_limit as usize;
303308
let pool_size_per_task = memory_limit_per_task as usize;
304309
match memory_pool_type.as_str() {
305310
"fair_spill_task_shared" => {
@@ -337,6 +342,12 @@ fn create_memory_pool(
337342
let memory_pool = CometMemoryPool::new(comet_task_memory_manager);
338343
Arc::new(memory_pool)
339344
}
345+
MemoryPoolType::FairUnified => {
346+
// Set Comet fair memory pool for native
347+
let memory_pool =
348+
CometFairMemoryPool::new(comet_task_memory_manager, memory_pool_config.pool_size);
349+
Arc::new(memory_pool)
350+
}
340351
MemoryPoolType::Greedy => Arc::new(TrackConsumersPool::new(
341352
GreedyMemoryPool::new(memory_pool_config.pool_size),
342353
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),

native/core/src/execution/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ pub(crate) mod util;
2929
pub use datafusion_comet_spark_expr::timezone;
3030
pub(crate) mod utils;
3131

32+
mod fair_memory_pool;
3233
mod memory_pool;
3334
pub use memory_pool::*;
3435

spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ class CometTPCHQuerySuite extends QueryTest with TPCBase with ShimCometTPCHQuery
9494
conf.set(CometConf.COMET_SHUFFLE_MODE.key, "jvm")
9595
conf.set(MEMORY_OFFHEAP_ENABLED.key, "true")
9696
conf.set(MEMORY_OFFHEAP_SIZE.key, "2g")
97+
conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g")
9798
}
9899

99100
protected override def createSparkSession: TestSparkSession = {

0 commit comments

Comments
 (0)