@@ -484,6 +484,8 @@ pub struct GraceHashJoinExec {
484484 build_left : bool ,
485485 /// Maximum build-side bytes for the fast path (0 = disabled)
486486 fast_path_threshold : usize ,
487+ /// Maximum number of partitions to join concurrently in Phase 3
488+ max_concurrent_partitions : usize ,
487489 /// Output schema
488490 schema : SchemaRef ,
489491 /// Plan properties cache
@@ -503,6 +505,7 @@ impl GraceHashJoinExec {
503505 num_partitions : usize ,
504506 build_left : bool ,
505507 fast_path_threshold : usize ,
508+ max_concurrent_partitions : usize ,
506509 ) -> DFResult < Self > {
507510 // Build the output schema using HashJoinExec's logic.
508511 // HashJoinExec expects left=build, right=probe. When build_left=false,
@@ -540,6 +543,7 @@ impl GraceHashJoinExec {
540543 } ,
541544 build_left,
542545 fast_path_threshold,
546+ max_concurrent_partitions : max_concurrent_partitions. max ( 1 ) ,
543547 schema,
544548 cache,
545549 metrics : ExecutionPlanMetricsSet :: new ( ) ,
@@ -596,6 +600,7 @@ impl ExecutionPlan for GraceHashJoinExec {
596600 self . num_partitions ,
597601 self . build_left ,
598602 self . fast_path_threshold ,
603+ self . max_concurrent_partitions ,
599604 ) ?) )
600605 }
601606
@@ -657,6 +662,7 @@ impl ExecutionPlan for GraceHashJoinExec {
657662 let num_partitions = self . num_partitions ;
658663 let build_left = self . build_left ;
659664 let fast_path_threshold = self . fast_path_threshold ;
665+ let max_concurrent_partitions = self . max_concurrent_partitions ;
660666
661667 let result_stream = futures:: stream:: once ( async move {
662668 execute_grace_hash_join (
@@ -670,6 +676,7 @@ impl ExecutionPlan for GraceHashJoinExec {
670676 num_partitions,
671677 build_left,
672678 fast_path_threshold,
679+ max_concurrent_partitions,
673680 build_schema,
674681 probe_schema,
675682 context,
@@ -750,6 +757,7 @@ async fn execute_grace_hash_join(
750757 num_partitions : usize ,
751758 build_left : bool ,
752759 fast_path_threshold : usize ,
760+ max_concurrent_partitions : usize ,
753761 build_schema : SchemaRef ,
754762 probe_schema : SchemaRef ,
755763 context : Arc < TaskContext > ,
@@ -758,11 +766,9 @@ async fn execute_grace_hash_join(
758766 let ghj_id = GHJ_INSTANCE_COUNTER . fetch_add ( 1 , std:: sync:: atomic:: Ordering :: Relaxed ) ;
759767
760768 // Set up memory reservation (shared across build and probe phases)
761- let mut reservation = MutableReservation (
762- MemoryConsumer :: new ( "GraceHashJoinExec" )
763- . with_can_spill ( true )
764- . register ( & context. runtime_env ( ) . memory_pool ) ,
765- ) ;
769+ let mut reservation = MemoryConsumer :: new ( "GraceHashJoinExec" )
770+ . with_can_spill ( true )
771+ . register ( & context. runtime_env ( ) . memory_pool ) ;
766772
767773 info ! (
768774 "GHJ#{}: started. build_left={}, join_type={:?}, pool reserved={}" ,
@@ -903,6 +909,7 @@ async fn execute_grace_hash_join(
903909 join_type,
904910 num_partitions,
905911 build_left,
912+ max_concurrent_partitions,
906913 build_schema,
907914 probe_schema,
908915 context,
@@ -948,6 +955,7 @@ async fn execute_grace_hash_join(
948955 join_type,
949956 num_partitions,
950957 build_left,
958+ max_concurrent_partitions,
951959 build_schema,
952960 probe_schema,
953961 context,
@@ -971,7 +979,7 @@ enum BuildBufferResult {
971979/// or signals memory pressure with the partially-buffered data and remaining stream.
972980async fn buffer_build_optimistic (
973981 mut input : SendableRecordBatchStream ,
974- reservation : & mut MutableReservation ,
982+ reservation : & mut MemoryReservation ,
975983 metrics : & GraceHashJoinMetrics ,
976984) -> DFResult < BuildBufferResult > {
977985 let mut batches = Vec :: new ( ) ;
@@ -1011,7 +1019,7 @@ fn partition_from_buffer(
10111019 num_partitions : usize ,
10121020 schema : & SchemaRef ,
10131021 partitions : & mut [ HashPartition ] ,
1014- reservation : & mut MutableReservation ,
1022+ reservation : & mut MemoryReservation ,
10151023 context : & Arc < TaskContext > ,
10161024 metrics : & GraceHashJoinMetrics ,
10171025 scratch : & mut ScratchSpace ,
@@ -1021,7 +1029,6 @@ fn partition_from_buffer(
10211029 continue ;
10221030 }
10231031
1024- let total_batch_size = batch. get_array_memory_size ( ) ;
10251032 let total_rows = batch. num_rows ( ) ;
10261033
10271034 scratch. compute_partitions ( & batch, keys, num_partitions, 0 ) ?;
@@ -1038,11 +1045,7 @@ fn partition_from_buffer(
10381045 } else {
10391046 scratch. take_partition ( & batch, part_idx) ?. unwrap ( )
10401047 } ;
1041- let batch_size = if total_rows > 0 {
1042- ( total_batch_size as u64 * sub_rows as u64 / total_rows as u64 ) as usize
1043- } else {
1044- 0
1045- } ;
1048+ let batch_size = sub_batch. get_array_memory_size ( ) ;
10461049
10471050 if partitions[ part_idx] . build_spilled ( ) {
10481051 if let Some ( ref mut writer) = partitions[ part_idx] . build_spill_writer {
@@ -1173,11 +1176,12 @@ async fn execute_slow_path(
11731176 join_type : JoinType ,
11741177 num_partitions : usize ,
11751178 build_left : bool ,
1179+ max_concurrent_partitions : usize ,
11761180 build_schema : SchemaRef ,
11771181 probe_schema : SchemaRef ,
11781182 context : Arc < TaskContext > ,
11791183 metrics : GraceHashJoinMetrics ,
1180- mut reservation : MutableReservation ,
1184+ mut reservation : MemoryReservation ,
11811185 mut scratch : ScratchSpace ,
11821186) -> DFResult < impl Stream < Item = DFResult < RecordBatch > > > {
11831187 let build_spilled = partitions. iter ( ) . any ( |p| p. build_spilled ( ) ) ;
@@ -1241,7 +1245,7 @@ async fn execute_slow_path(
12411245 total_probe_rows,
12421246 total_probe_bytes,
12431247 probe_spilled,
1244- reservation. 0 . size( ) ,
1248+ reservation. size( ) ,
12451249 context. runtime_env( ) . memory_pool. reserved( ) ,
12461250 ) ;
12471251 }
@@ -1279,19 +1283,17 @@ async fn execute_slow_path(
12791283 info ! (
12801284 "GHJ#{}: freeing reservation ({} bytes) before Phase 3. pool reserved={}" ,
12811285 ghj_id,
1282- reservation. 0 . size( ) ,
1286+ reservation. size( ) ,
12831287 context. runtime_env( ) . memory_pool. reserved( ) ,
12841288 ) ;
12851289 reservation. free ( ) ;
12861290
1287- // Phase 3: Join partitions sequentially.
1288- // We use a concurrency limit of 1 to avoid creating multiple simultaneous
1289- // HashJoinInput reservations per task. With multiple Spark tasks sharing
1290- // the same memory pool, even modest build sides (e.g. 22 MB) can exhaust
1291- // memory when many tasks run concurrent hash table builds simultaneously.
1292- const MAX_CONCURRENT_PARTITIONS : usize = 1 ;
1293- let semaphore = Arc :: new ( tokio:: sync:: Semaphore :: new ( MAX_CONCURRENT_PARTITIONS ) ) ;
1294- let ( tx, rx) = mpsc:: channel :: < DFResult < RecordBatch > > ( MAX_CONCURRENT_PARTITIONS * 2 ) ;
1291+ // Phase 3: Join partitions with bounded concurrency. Keeping this low
1292+ // avoids creating many simultaneous HashJoinInput reservations per task
1293+ // when multiple Spark tasks share the same memory pool.
1294+ let max_concurrent_partitions = max_concurrent_partitions. max ( 1 ) ;
1295+ let semaphore = Arc :: new ( tokio:: sync:: Semaphore :: new ( max_concurrent_partitions) ) ;
1296+ let ( tx, rx) = mpsc:: channel :: < DFResult < RecordBatch > > ( max_concurrent_partitions * 2 ) ;
12951297
12961298 for partition in finished_partitions {
12971299 let tx = tx. clone ( ) ;
@@ -1365,23 +1367,6 @@ async fn execute_slow_path(
13651367 Ok ( result_stream. boxed ( ) )
13661368}
13671369
1368- /// Wraps MemoryReservation to allow mutation through reference.
1369- struct MutableReservation ( MemoryReservation ) ;
1370-
1371- impl MutableReservation {
1372- fn try_grow ( & mut self , additional : usize ) -> DFResult < ( ) > {
1373- self . 0 . try_grow ( additional)
1374- }
1375-
1376- fn shrink ( & mut self , amount : usize ) {
1377- self . 0 . shrink ( amount) ;
1378- }
1379-
1380- fn free ( & mut self ) -> usize {
1381- self . 0 . free ( )
1382- }
1383- }
1384-
13851370// ---------------------------------------------------------------------------
13861371// ScratchSpace: reusable buffers for efficient hash partitioning
13871372// ---------------------------------------------------------------------------
@@ -1421,10 +1406,11 @@ impl ScratchSpace {
14211406 . map ( |expr| expr. evaluate ( batch) . and_then ( |cv| cv. into_array ( num_rows) ) )
14221407 . collect :: < DFResult < Vec < _ > > > ( ) ?;
14231408
1424- // Hash
1409+ // Hash. `create_hashes` XORs into the existing values, so the buffer
1410+ // must be zeroed. `clear()` + `resize()` produces a fresh zeroed buffer
1411+ // of the right length regardless of its previous size.
1412+ self . hashes . clear ( ) ;
14251413 self . hashes . resize ( num_rows, 0 ) ;
1426- self . hashes . truncate ( num_rows) ;
1427- self . hashes . fill ( 0 ) ;
14281414 let random_state = partition_random_state ( recursion_level) ;
14291415 create_hashes ( & key_columns, & random_state, & mut self . hashes ) ?;
14301416
@@ -1526,7 +1512,7 @@ async fn partition_build_side(
15261512 num_partitions : usize ,
15271513 schema : & SchemaRef ,
15281514 partitions : & mut [ HashPartition ] ,
1529- reservation : & mut MutableReservation ,
1515+ reservation : & mut MemoryReservation ,
15301516 context : & Arc < TaskContext > ,
15311517 metrics : & GraceHashJoinMetrics ,
15321518 scratch : & mut ScratchSpace ,
@@ -1540,8 +1526,6 @@ async fn partition_build_side(
15401526 metrics. build_input_batches . add ( 1 ) ;
15411527 metrics. build_input_rows . add ( batch. num_rows ( ) ) ;
15421528
1543- // Track total batch size once, estimate per-partition proportionally
1544- let total_batch_size = batch. get_array_memory_size ( ) ;
15451529 let total_rows = batch. num_rows ( ) ;
15461530
15471531 scratch. compute_partitions ( & batch, keys, num_partitions, 0 ) ?;
@@ -1558,11 +1542,7 @@ async fn partition_build_side(
15581542 } else {
15591543 scratch. take_partition ( & batch, part_idx) ?. unwrap ( )
15601544 } ;
1561- let batch_size = if total_rows > 0 {
1562- ( total_batch_size as u64 * sub_rows as u64 / total_rows as u64 ) as usize
1563- } else {
1564- 0
1565- } ;
1545+ let batch_size = sub_batch. get_array_memory_size ( ) ;
15661546
15671547 if partitions[ part_idx] . build_spilled ( ) {
15681548 // This partition is already spilled; append incrementally
@@ -1613,7 +1593,7 @@ fn spill_largest_partition(
16131593 partitions : & mut [ HashPartition ] ,
16141594 schema : & SchemaRef ,
16151595 context : & Arc < TaskContext > ,
1616- reservation : & mut MutableReservation ,
1596+ reservation : & mut MemoryReservation ,
16171597 metrics : & GraceHashJoinMetrics ,
16181598) -> DFResult < ( ) > {
16191599 // Find the largest non-spilled partition
@@ -1642,7 +1622,7 @@ fn spill_partition_build(
16421622 partition : & mut HashPartition ,
16431623 schema : & SchemaRef ,
16441624 context : & Arc < TaskContext > ,
1645- reservation : & mut MutableReservation ,
1625+ reservation : & mut MemoryReservation ,
16461626 metrics : & GraceHashJoinMetrics ,
16471627) -> DFResult < ( ) > {
16481628 let temp_file = context
@@ -1672,7 +1652,7 @@ fn spill_partition_probe(
16721652 partition : & mut HashPartition ,
16731653 schema : & SchemaRef ,
16741654 context : & Arc < TaskContext > ,
1675- reservation : & mut MutableReservation ,
1655+ reservation : & mut MemoryReservation ,
16761656 metrics : & GraceHashJoinMetrics ,
16771657) -> DFResult < ( ) > {
16781658 if partition. probe_batches . is_empty ( ) && partition. probe_spill_writer . is_some ( ) {
@@ -1708,7 +1688,7 @@ fn spill_partition_both_sides(
17081688 probe_schema : & SchemaRef ,
17091689 build_schema : & SchemaRef ,
17101690 context : & Arc < TaskContext > ,
1711- reservation : & mut MutableReservation ,
1691+ reservation : & mut MemoryReservation ,
17121692 metrics : & GraceHashJoinMetrics ,
17131693) -> DFResult < ( ) > {
17141694 if !partition. build_spilled ( ) {
@@ -1734,7 +1714,7 @@ async fn partition_probe_side(
17341714 num_partitions : usize ,
17351715 schema : & SchemaRef ,
17361716 partitions : & mut [ HashPartition ] ,
1737- reservation : & mut MutableReservation ,
1717+ reservation : & mut MemoryReservation ,
17381718 build_schema : & SchemaRef ,
17391719 context : & Arc < TaskContext > ,
17401720 metrics : & GraceHashJoinMetrics ,
@@ -1754,7 +1734,7 @@ async fn partition_probe_side(
17541734 "GraceHashJoin: probe accumulation progress: {} rows, \
17551735 reservation={}, pool reserved={}",
17561736 probe_rows_accumulated,
1757- reservation. 0 . size( ) ,
1737+ reservation. size( ) ,
17581738 context. runtime_env( ) . memory_pool. reserved( ) ,
17591739 ) ;
17601740 }
@@ -2600,6 +2580,7 @@ mod tests {
26002580 4 , // Use 4 partitions for testing
26012581 true ,
26022582 10 * 1024 * 1024 , // 10 MB fast path threshold
2583+ 2 , // max_concurrent_partitions
26032584 ) ?;
26042585
26052586 let stream = grace_join. execute ( 0 , task_ctx) ?;
@@ -2654,6 +2635,7 @@ mod tests {
26542635 4 ,
26552636 true ,
26562637 10 * 1024 * 1024 , // 10 MB fast path threshold
2638+ 2 , // max_concurrent_partitions
26572639 ) ?;
26582640
26592641 let stream = grace_join. execute ( 0 , task_ctx) ?;
@@ -2761,6 +2743,7 @@ mod tests {
27612743 16 ,
27622744 true , // build_left
27632745 0 , // fast_path_threshold = 0 (disabled)
2746+ 2 , // max_concurrent_partitions
27642747 ) ?;
27652748
27662749 let stream = grace_join. execute ( 0 , task_ctx) ?;
@@ -2824,6 +2807,7 @@ mod tests {
28242807 16 ,
28252808 false , // build_left=false → right is build side
28262809 0 , // fast_path_threshold = 0 (disabled)
2810+ 2 , // max_concurrent_partitions
28272811 ) ?;
28282812
28292813 let stream = grace_join. execute ( 0 , task_ctx) ?;
0 commit comments