Skip to content

Commit 5a7f83c

Browse files
committed
save progress [skip ci]
1 parent 8b0b12d commit 5a7f83c

10 files changed

Lines changed: 210 additions & 192 deletions

File tree

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -318,6 +318,17 @@ object CometConf extends ShimCometConf {
318318
.checkValue(v => v >= 0, "The fast path threshold must be non-negative.")
319319
.createWithDefault(64L * 1024 * 1024) // 64 MB
320320

321+
val COMET_EXEC_GRACE_HASH_JOIN_MAX_CONCURRENT_PARTITIONS: ConfigEntry[Int] =
322+
conf(s"$COMET_EXEC_CONFIG_PREFIX.graceHashJoin.maxConcurrentPartitions")
323+
.category(CATEGORY_EXEC)
324+
.doc(
325+
"Maximum number of partitions to join in parallel during Grace Hash Join's " +
326+
"slow path. Higher values improve latency at the cost of concurrent hash-table " +
327+
"memory. Keep low when many Spark tasks share a single memory pool.")
328+
.intConf
329+
.checkValue(v => v > 0, "The max concurrent partitions must be positive.")
330+
.createWithDefault(2)
331+
321332
val COMET_NATIVE_COLUMNAR_TO_ROW_ENABLED: ConfigEntry[Boolean] =
322333
conf(s"$COMET_EXEC_CONFIG_PREFIX.columnarToRow.native.enabled")
323334
.category(CATEGORY_EXEC)

native/core/src/execution/operators/grace_hash_join.rs

Lines changed: 42 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -484,6 +484,8 @@ pub struct GraceHashJoinExec {
484484
build_left: bool,
485485
/// Maximum build-side bytes for the fast path (0 = disabled)
486486
fast_path_threshold: usize,
487+
/// Maximum number of partitions to join concurrently in Phase 3
488+
max_concurrent_partitions: usize,
487489
/// Output schema
488490
schema: SchemaRef,
489491
/// Plan properties cache
@@ -503,6 +505,7 @@ impl GraceHashJoinExec {
503505
num_partitions: usize,
504506
build_left: bool,
505507
fast_path_threshold: usize,
508+
max_concurrent_partitions: usize,
506509
) -> DFResult<Self> {
507510
// Build the output schema using HashJoinExec's logic.
508511
// HashJoinExec expects left=build, right=probe. When build_left=false,
@@ -540,6 +543,7 @@ impl GraceHashJoinExec {
540543
},
541544
build_left,
542545
fast_path_threshold,
546+
max_concurrent_partitions: max_concurrent_partitions.max(1),
543547
schema,
544548
cache,
545549
metrics: ExecutionPlanMetricsSet::new(),
@@ -596,6 +600,7 @@ impl ExecutionPlan for GraceHashJoinExec {
596600
self.num_partitions,
597601
self.build_left,
598602
self.fast_path_threshold,
603+
self.max_concurrent_partitions,
599604
)?))
600605
}
601606

@@ -657,6 +662,7 @@ impl ExecutionPlan for GraceHashJoinExec {
657662
let num_partitions = self.num_partitions;
658663
let build_left = self.build_left;
659664
let fast_path_threshold = self.fast_path_threshold;
665+
let max_concurrent_partitions = self.max_concurrent_partitions;
660666

661667
let result_stream = futures::stream::once(async move {
662668
execute_grace_hash_join(
@@ -670,6 +676,7 @@ impl ExecutionPlan for GraceHashJoinExec {
670676
num_partitions,
671677
build_left,
672678
fast_path_threshold,
679+
max_concurrent_partitions,
673680
build_schema,
674681
probe_schema,
675682
context,
@@ -750,6 +757,7 @@ async fn execute_grace_hash_join(
750757
num_partitions: usize,
751758
build_left: bool,
752759
fast_path_threshold: usize,
760+
max_concurrent_partitions: usize,
753761
build_schema: SchemaRef,
754762
probe_schema: SchemaRef,
755763
context: Arc<TaskContext>,
@@ -758,11 +766,9 @@ async fn execute_grace_hash_join(
758766
let ghj_id = GHJ_INSTANCE_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
759767

760768
// Set up memory reservation (shared across build and probe phases)
761-
let mut reservation = MutableReservation(
762-
MemoryConsumer::new("GraceHashJoinExec")
763-
.with_can_spill(true)
764-
.register(&context.runtime_env().memory_pool),
765-
);
769+
let mut reservation = MemoryConsumer::new("GraceHashJoinExec")
770+
.with_can_spill(true)
771+
.register(&context.runtime_env().memory_pool);
766772

767773
info!(
768774
"GHJ#{}: started. build_left={}, join_type={:?}, pool reserved={}",
@@ -903,6 +909,7 @@ async fn execute_grace_hash_join(
903909
join_type,
904910
num_partitions,
905911
build_left,
912+
max_concurrent_partitions,
906913
build_schema,
907914
probe_schema,
908915
context,
@@ -948,6 +955,7 @@ async fn execute_grace_hash_join(
948955
join_type,
949956
num_partitions,
950957
build_left,
958+
max_concurrent_partitions,
951959
build_schema,
952960
probe_schema,
953961
context,
@@ -971,7 +979,7 @@ enum BuildBufferResult {
971979
/// or signals memory pressure with the partially-buffered data and remaining stream.
972980
async fn buffer_build_optimistic(
973981
mut input: SendableRecordBatchStream,
974-
reservation: &mut MutableReservation,
982+
reservation: &mut MemoryReservation,
975983
metrics: &GraceHashJoinMetrics,
976984
) -> DFResult<BuildBufferResult> {
977985
let mut batches = Vec::new();
@@ -1011,7 +1019,7 @@ fn partition_from_buffer(
10111019
num_partitions: usize,
10121020
schema: &SchemaRef,
10131021
partitions: &mut [HashPartition],
1014-
reservation: &mut MutableReservation,
1022+
reservation: &mut MemoryReservation,
10151023
context: &Arc<TaskContext>,
10161024
metrics: &GraceHashJoinMetrics,
10171025
scratch: &mut ScratchSpace,
@@ -1021,7 +1029,6 @@ fn partition_from_buffer(
10211029
continue;
10221030
}
10231031

1024-
let total_batch_size = batch.get_array_memory_size();
10251032
let total_rows = batch.num_rows();
10261033

10271034
scratch.compute_partitions(&batch, keys, num_partitions, 0)?;
@@ -1038,11 +1045,7 @@ fn partition_from_buffer(
10381045
} else {
10391046
scratch.take_partition(&batch, part_idx)?.unwrap()
10401047
};
1041-
let batch_size = if total_rows > 0 {
1042-
(total_batch_size as u64 * sub_rows as u64 / total_rows as u64) as usize
1043-
} else {
1044-
0
1045-
};
1048+
let batch_size = sub_batch.get_array_memory_size();
10461049

10471050
if partitions[part_idx].build_spilled() {
10481051
if let Some(ref mut writer) = partitions[part_idx].build_spill_writer {
@@ -1173,11 +1176,12 @@ async fn execute_slow_path(
11731176
join_type: JoinType,
11741177
num_partitions: usize,
11751178
build_left: bool,
1179+
max_concurrent_partitions: usize,
11761180
build_schema: SchemaRef,
11771181
probe_schema: SchemaRef,
11781182
context: Arc<TaskContext>,
11791183
metrics: GraceHashJoinMetrics,
1180-
mut reservation: MutableReservation,
1184+
mut reservation: MemoryReservation,
11811185
mut scratch: ScratchSpace,
11821186
) -> DFResult<impl Stream<Item = DFResult<RecordBatch>>> {
11831187
let build_spilled = partitions.iter().any(|p| p.build_spilled());
@@ -1241,7 +1245,7 @@ async fn execute_slow_path(
12411245
total_probe_rows,
12421246
total_probe_bytes,
12431247
probe_spilled,
1244-
reservation.0.size(),
1248+
reservation.size(),
12451249
context.runtime_env().memory_pool.reserved(),
12461250
);
12471251
}
@@ -1279,19 +1283,17 @@ async fn execute_slow_path(
12791283
info!(
12801284
"GHJ#{}: freeing reservation ({} bytes) before Phase 3. pool reserved={}",
12811285
ghj_id,
1282-
reservation.0.size(),
1286+
reservation.size(),
12831287
context.runtime_env().memory_pool.reserved(),
12841288
);
12851289
reservation.free();
12861290

1287-
// Phase 3: Join partitions sequentially.
1288-
// We use a concurrency limit of 1 to avoid creating multiple simultaneous
1289-
// HashJoinInput reservations per task. With multiple Spark tasks sharing
1290-
// the same memory pool, even modest build sides (e.g. 22 MB) can exhaust
1291-
// memory when many tasks run concurrent hash table builds simultaneously.
1292-
const MAX_CONCURRENT_PARTITIONS: usize = 1;
1293-
let semaphore = Arc::new(tokio::sync::Semaphore::new(MAX_CONCURRENT_PARTITIONS));
1294-
let (tx, rx) = mpsc::channel::<DFResult<RecordBatch>>(MAX_CONCURRENT_PARTITIONS * 2);
1291+
// Phase 3: Join partitions with bounded concurrency. Keeping this low
1292+
// avoids creating many simultaneous HashJoinInput reservations per task
1293+
// when multiple Spark tasks share the same memory pool.
1294+
let max_concurrent_partitions = max_concurrent_partitions.max(1);
1295+
let semaphore = Arc::new(tokio::sync::Semaphore::new(max_concurrent_partitions));
1296+
let (tx, rx) = mpsc::channel::<DFResult<RecordBatch>>(max_concurrent_partitions * 2);
12951297

12961298
for partition in finished_partitions {
12971299
let tx = tx.clone();
@@ -1365,23 +1367,6 @@ async fn execute_slow_path(
13651367
Ok(result_stream.boxed())
13661368
}
13671369

1368-
/// Wraps MemoryReservation to allow mutation through reference.
1369-
struct MutableReservation(MemoryReservation);
1370-
1371-
impl MutableReservation {
1372-
fn try_grow(&mut self, additional: usize) -> DFResult<()> {
1373-
self.0.try_grow(additional)
1374-
}
1375-
1376-
fn shrink(&mut self, amount: usize) {
1377-
self.0.shrink(amount);
1378-
}
1379-
1380-
fn free(&mut self) -> usize {
1381-
self.0.free()
1382-
}
1383-
}
1384-
13851370
// ---------------------------------------------------------------------------
13861371
// ScratchSpace: reusable buffers for efficient hash partitioning
13871372
// ---------------------------------------------------------------------------
@@ -1421,10 +1406,11 @@ impl ScratchSpace {
14211406
.map(|expr| expr.evaluate(batch).and_then(|cv| cv.into_array(num_rows)))
14221407
.collect::<DFResult<Vec<_>>>()?;
14231408

1424-
// Hash
1409+
// Hash. `create_hashes` XORs into the existing values, so the buffer
1410+
// must be zeroed. `clear()` + `resize()` produces a fresh zeroed buffer
1411+
// of the right length regardless of its previous size.
1412+
self.hashes.clear();
14251413
self.hashes.resize(num_rows, 0);
1426-
self.hashes.truncate(num_rows);
1427-
self.hashes.fill(0);
14281414
let random_state = partition_random_state(recursion_level);
14291415
create_hashes(&key_columns, &random_state, &mut self.hashes)?;
14301416

@@ -1526,7 +1512,7 @@ async fn partition_build_side(
15261512
num_partitions: usize,
15271513
schema: &SchemaRef,
15281514
partitions: &mut [HashPartition],
1529-
reservation: &mut MutableReservation,
1515+
reservation: &mut MemoryReservation,
15301516
context: &Arc<TaskContext>,
15311517
metrics: &GraceHashJoinMetrics,
15321518
scratch: &mut ScratchSpace,
@@ -1540,8 +1526,6 @@ async fn partition_build_side(
15401526
metrics.build_input_batches.add(1);
15411527
metrics.build_input_rows.add(batch.num_rows());
15421528

1543-
// Track total batch size once, estimate per-partition proportionally
1544-
let total_batch_size = batch.get_array_memory_size();
15451529
let total_rows = batch.num_rows();
15461530

15471531
scratch.compute_partitions(&batch, keys, num_partitions, 0)?;
@@ -1558,11 +1542,7 @@ async fn partition_build_side(
15581542
} else {
15591543
scratch.take_partition(&batch, part_idx)?.unwrap()
15601544
};
1561-
let batch_size = if total_rows > 0 {
1562-
(total_batch_size as u64 * sub_rows as u64 / total_rows as u64) as usize
1563-
} else {
1564-
0
1565-
};
1545+
let batch_size = sub_batch.get_array_memory_size();
15661546

15671547
if partitions[part_idx].build_spilled() {
15681548
// This partition is already spilled; append incrementally
@@ -1613,7 +1593,7 @@ fn spill_largest_partition(
16131593
partitions: &mut [HashPartition],
16141594
schema: &SchemaRef,
16151595
context: &Arc<TaskContext>,
1616-
reservation: &mut MutableReservation,
1596+
reservation: &mut MemoryReservation,
16171597
metrics: &GraceHashJoinMetrics,
16181598
) -> DFResult<()> {
16191599
// Find the largest non-spilled partition
@@ -1642,7 +1622,7 @@ fn spill_partition_build(
16421622
partition: &mut HashPartition,
16431623
schema: &SchemaRef,
16441624
context: &Arc<TaskContext>,
1645-
reservation: &mut MutableReservation,
1625+
reservation: &mut MemoryReservation,
16461626
metrics: &GraceHashJoinMetrics,
16471627
) -> DFResult<()> {
16481628
let temp_file = context
@@ -1672,7 +1652,7 @@ fn spill_partition_probe(
16721652
partition: &mut HashPartition,
16731653
schema: &SchemaRef,
16741654
context: &Arc<TaskContext>,
1675-
reservation: &mut MutableReservation,
1655+
reservation: &mut MemoryReservation,
16761656
metrics: &GraceHashJoinMetrics,
16771657
) -> DFResult<()> {
16781658
if partition.probe_batches.is_empty() && partition.probe_spill_writer.is_some() {
@@ -1708,7 +1688,7 @@ fn spill_partition_both_sides(
17081688
probe_schema: &SchemaRef,
17091689
build_schema: &SchemaRef,
17101690
context: &Arc<TaskContext>,
1711-
reservation: &mut MutableReservation,
1691+
reservation: &mut MemoryReservation,
17121692
metrics: &GraceHashJoinMetrics,
17131693
) -> DFResult<()> {
17141694
if !partition.build_spilled() {
@@ -1734,7 +1714,7 @@ async fn partition_probe_side(
17341714
num_partitions: usize,
17351715
schema: &SchemaRef,
17361716
partitions: &mut [HashPartition],
1737-
reservation: &mut MutableReservation,
1717+
reservation: &mut MemoryReservation,
17381718
build_schema: &SchemaRef,
17391719
context: &Arc<TaskContext>,
17401720
metrics: &GraceHashJoinMetrics,
@@ -1754,7 +1734,7 @@ async fn partition_probe_side(
17541734
"GraceHashJoin: probe accumulation progress: {} rows, \
17551735
reservation={}, pool reserved={}",
17561736
probe_rows_accumulated,
1757-
reservation.0.size(),
1737+
reservation.size(),
17581738
context.runtime_env().memory_pool.reserved(),
17591739
);
17601740
}
@@ -2600,6 +2580,7 @@ mod tests {
26002580
4, // Use 4 partitions for testing
26012581
true,
26022582
10 * 1024 * 1024, // 10 MB fast path threshold
2583+
2, // max_concurrent_partitions
26032584
)?;
26042585

26052586
let stream = grace_join.execute(0, task_ctx)?;
@@ -2654,6 +2635,7 @@ mod tests {
26542635
4,
26552636
true,
26562637
10 * 1024 * 1024, // 10 MB fast path threshold
2638+
2, // max_concurrent_partitions
26572639
)?;
26582640

26592641
let stream = grace_join.execute(0, task_ctx)?;
@@ -2761,6 +2743,7 @@ mod tests {
27612743
16,
27622744
true, // build_left
27632745
0, // fast_path_threshold = 0 (disabled)
2746+
2, // max_concurrent_partitions
27642747
)?;
27652748

27662749
let stream = grace_join.execute(0, task_ctx)?;
@@ -2824,6 +2807,7 @@ mod tests {
28242807
16,
28252808
false, // build_left=false → right is build side
28262809
0, // fast_path_threshold = 0 (disabled)
2810+
2, // max_concurrent_partitions
28272811
)?;
28282812

28292813
let stream = grace_join.execute(0, task_ctx)?;

0 commit comments

Comments
 (0)