Commit bc6f95b

SubhamSinghal and Subham Singhal authored
fix: track join_arrays memory in reservation after SMJ spill (#21962)
## Which issue does this PR close?

Related to the TODO at `materializing_stream.rs:283` (from a comment on #17429): spilled `BufferedBatch` join key arrays are not tracked in memory reservation.

## Rationale for this change

When a `BufferedBatch` is spilled to disk in Sort Merge Join, only the `RecordBatch` data is written to the IPC file. The `join_arrays` (evaluated join key columns) remain in memory because the merge-scan comparator needs them to detect key group boundaries.

Before this fix, these in-memory `join_arrays` were **invisible to the memory pool**:

- `allocate_reservation()`:
  - `try_grow(size_estimation)` → FAILS (pool full)
  - spill batch to disk → `join_arrays` still in memory, but reservation was never grown
  - → pool thinks 0 bytes are used for this batch
- `free_reservation()`:
  - if `InMemory` → `shrink(size_estimation)`
  - if `Spilled` → no-op ← correct (nothing was grown), but `join_arrays` are invisible

With many spilled batches for a skewed key (e.g., millions of rows sharing the same join key), the untracked `join_arrays` memory accumulates. The memory pool cannot account for this when making spill decisions for concurrent operators.

## What changes are included in this PR?

**Memory accounting fix** (`materializing_stream.rs`):

- Add `reserved_amount` field to `BufferedBatch`: tracks how much memory was **actually reserved** in the pool for this batch
- Add `join_arrays_mem()` helper: computes total memory of join key arrays
- `allocate_reservation()`: after spilling, calls `try_grow(join_arrays_mem)` to track the remaining in-memory data. If the pool is too tight for even that, `reserved_amount` stays 0 (best-effort, safe)
- `free_reservation()`: shrinks by `reserved_amount` instead of checking `InMemory` variant. Invariant: only shrink by what was actually grown, so there is no underflow risk

| Scenario | `try_grow` | `reserved_amount` | `try_shrink` | Safe? |
|----------|-----------|-------------------|-------------|-------|
| InMemory | Ok(size_estimation) | size_estimation | size_estimation | Yes |
| Spilled, tracked | Ok(join_arrays_mem) | join_arrays_mem | join_arrays_mem | Yes |
| Spilled, pool tight | Err | 0 | 0 (no-op) | Yes |

**Tests** (`tests.rs`):

- `spill_many_batches_same_key`: 10+5 batches all sharing key=1, verifies correctness under heavy spilling
- `spill_string_join_keys`: Utf8 join keys to exercise larger `join_arrays` footprint
- `spill_mixed_keys_some_match`: multiple distinct keys with partial matching, tests Full outer join NULL rows from spilled batches
- `spill_join_arrays_memory_accounting`: verifies memory pool is fully released after join completes (`memory_pool.reserved() == 0`) and `peak_mem_used > 0`

## Are these changes tested?

Yes. Four new tests added covering heavy spilling with same-key batches, string join keys, mixed keys with partial matching, and memory pool accounting verification.

## Are there any user-facing changes?

No.

---------

Co-authored-by: Subham Singhal <subhamsinghal@Subhams-MacBook-Air.local>
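
The three scenarios in the table can be traced with a small standalone model. This is only a sketch: `ToyReservation`, `ToyBatch`, `allocate`, and `free` are hypothetical names that stand in for the real reservation and `BufferedBatch`, and the model follows the best-effort `try_grow` behaviour as the description words it. The point it illustrates is the invariant that `free` shrinks by `reserved_amount` and nothing else.

```rust
/// Hypothetical stand-in for a memory reservation with a hard limit.
/// This is not DataFusion's `MemoryReservation`.
#[derive(Debug)]
pub struct ToyReservation {
    limit: usize,
    used: usize,
}

impl ToyReservation {
    pub fn new(limit: usize) -> Self {
        Self { limit, used: 0 }
    }

    /// Grows the reservation, or fails if the limit would be exceeded.
    pub fn try_grow(&mut self, bytes: usize) -> Result<(), String> {
        if self.used + bytes > self.limit {
            return Err(format!("cannot reserve {bytes} bytes"));
        }
        self.used += bytes;
        Ok(())
    }

    /// Shrinks the reservation; panics on underflow so accounting bugs surface.
    pub fn shrink(&mut self, bytes: usize) {
        self.used = self.used.checked_sub(bytes).expect("shrink underflow");
    }

    pub fn used(&self) -> usize {
        self.used
    }
}

/// Hypothetical stand-in for `BufferedBatch`, reduced to its accounting fields.
pub struct ToyBatch {
    pub size_estimation: usize,
    pub join_arrays_mem: usize,
    pub reserved_amount: usize,
    pub spilled: bool,
}

impl ToyBatch {
    pub fn new(size_estimation: usize, join_arrays_mem: usize) -> Self {
        Self { size_estimation, join_arrays_mem, reserved_amount: 0, spilled: false }
    }
}

/// Reserve the full batch if possible; otherwise spill and try to reserve
/// only the join key arrays. `reserved_amount` records what was really grown.
pub fn allocate(res: &mut ToyReservation, batch: &mut ToyBatch) {
    if res.try_grow(batch.size_estimation).is_ok() {
        batch.reserved_amount = batch.size_estimation;
        return;
    }
    batch.spilled = true;
    if res.try_grow(batch.join_arrays_mem).is_ok() {
        batch.reserved_amount = batch.join_arrays_mem;
    }
    // Pool too tight even for the key arrays: reserved_amount stays 0.
}

/// Release exactly what was reserved for this batch, never more.
pub fn free(res: &mut ToyReservation, batch: &mut ToyBatch) {
    res.shrink(batch.reserved_amount);
    batch.reserved_amount = 0;
}
```

With a limit of 100 bytes and three batches of `size_estimation = 80` and `join_arrays_mem = 15`, the first batch stays in memory, the second spills with its key arrays tracked, and the third spills untracked. Freeing all three returns the reservation to zero without underflow.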
1 parent 3b634aa commit bc6f95b

2 files changed

Lines changed: 258 additions & 11 deletions

datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs

Lines changed: 41 additions & 11 deletions
@@ -235,6 +235,18 @@ pub(super) struct BufferedBatch {
     pub null_joined: Vec<usize>,
     /// Size estimation used for reserving / releasing memory
     pub size_estimation: usize,
+    /// Memory footprint of `join_arrays` cached at construction time.
+    /// Used during spill to track the residual memory that remains after
+    /// the main batch is written to disk.
+    pub join_arrays_mem: usize,
+    /// Actual amount tracked in the memory reservation for this batch.
+    ///
+    /// - `InMemory`: equals `size_estimation` (full batch + join_arrays + metadata)
+    /// - `Spilled`: equals `join_arrays_mem` (join key arrays stay in memory)
+    ///
+    /// Invariant: `free_reservation()` shrinks by exactly this amount, so we never
+    /// shrink by more than we grew.
+    pub reserved_amount: usize,
     /// Tracks filter outcomes for buffered rows in full outer joins.
     /// Indexed by absolute row position within the batch. See [`FilterState`].
     pub join_filter_status: Vec<FilterState>,
@@ -258,11 +270,13 @@ impl BufferedBatch {
         // + worst case null_joined (as vector capacity * element size)
         // + Range size
         // + size of this estimation
+        let join_arrays_mem: usize = join_arrays
+            .iter()
+            .map(|arr| arr.get_array_memory_size())
+            .sum();
+
         let size_estimation = batch.get_array_memory_size()
-            + join_arrays
-                .iter()
-                .map(|arr| arr.get_array_memory_size())
-                .sum::<usize>()
+            + join_arrays_mem
             + batch.num_rows().next_power_of_two() * size_of::<usize>()
             + size_of::<Range<usize>>()
             + size_of::<usize>();
@@ -274,6 +288,8 @@ impl BufferedBatch {
             join_arrays,
             null_joined: vec![],
             size_estimation,
+            join_arrays_mem,
+            reserved_amount: 0,
             join_filter_status: vec![FilterState::Unvisited; num_rows],
             num_rows,
         }
@@ -947,18 +963,16 @@ impl MaterializingSortMergeJoinStream {
         }
     }
 
-    fn free_reservation(&mut self, buffered_batch: &BufferedBatch) -> Result<()> {
-        // Shrink memory usage for in-memory batches only
-        if let BufferedBatchState::InMemory(_) = buffered_batch.batch {
-            self.reservation
-                .try_shrink(buffered_batch.size_estimation)?;
+    fn free_reservation(&mut self, buffered_batch: &BufferedBatch) {
+        if buffered_batch.reserved_amount > 0 {
+            self.reservation.shrink(buffered_batch.reserved_amount);
         }
-        Ok(())
     }
 
     fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> {
         match self.reservation.try_grow(buffered_batch.size_estimation) {
             Ok(_) => {
+                buffered_batch.reserved_amount = buffered_batch.size_estimation;
                 self.join_metrics
                     .peak_mem_used()
                     .set_max(self.reservation.size());
@@ -978,6 +992,22 @@ impl MaterializingSortMergeJoinStream {
                             .unwrap(); // Operation only return None if no batches are spilled, here we ensure that at least one batch is spilled
 
                         buffered_batch.batch = BufferedBatchState::Spilled(spill_file);
+
+                        // Join key arrays remain in memory after the batch is
+                        // spilled — the comparator needs them for key boundary
+                        // detection. Force-grow the reservation so the pool
+                        // reflects actual memory usage even if this pushes
+                        // pool.reserved() above the configured limit. This is
+                        // safe because the memory is physically consumed and
+                        // not tracking it would let other operators over-allocate
+                        // against a stale pool view.
+                        let join_arrays_mem = buffered_batch.join_arrays_mem;
+                        self.reservation.grow(join_arrays_mem);
+                        buffered_batch.reserved_amount = join_arrays_mem;
+                        self.join_metrics
+                            .peak_mem_used()
+                            .set_max(self.reservation.size());
+
                         Ok(())
                     }
                     _ => internal_err!("Buffered batch has empty body"),
@@ -1006,7 +1036,7 @@ impl MaterializingSortMergeJoinStream {
                     self.buffered_data.batches.pop_front()
                 {
                     self.produce_buffered_not_matched(&mut buffered_batch)?;
-                    self.free_reservation(&buffered_batch)?;
+                    self.free_reservation(&buffered_batch);
                     head_changed = true;
                 }
             } else {
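
The spill path in this diff calls `grow` rather than `try_grow`, so the reservation can push the pool above its configured limit. The sketch below models that choice with a hypothetical `ToyPool` (not DataFusion's `MemoryPool`). It assumes only that `try_grow` respects a limit and `grow` does not. It shows that once the resident key arrays are force-tracked, a later limit-respecting request from another consumer is refused until the batch is freed.

```rust
/// Hypothetical pool with a soft limit: `try_grow` respects the limit,
/// `grow` records usage unconditionally. Not DataFusion's `MemoryPool`.
pub struct ToyPool {
    limit: usize,
    reserved: usize,
}

impl ToyPool {
    pub fn new(limit: usize) -> Self {
        Self { limit, reserved: 0 }
    }

    /// Limit-respecting growth: returns false and changes nothing on failure.
    pub fn try_grow(&mut self, bytes: usize) -> bool {
        if self.reserved + bytes > self.limit {
            return false;
        }
        self.reserved += bytes;
        true
    }

    /// Forced growth: always succeeds, may exceed the limit.
    pub fn grow(&mut self, bytes: usize) {
        self.reserved += bytes;
    }

    pub fn shrink(&mut self, bytes: usize) {
        self.reserved = self.reserved.saturating_sub(bytes);
    }

    pub fn reserved(&self) -> usize {
        self.reserved
    }

    pub fn over_limit(&self) -> bool {
        self.reserved > self.limit
    }
}

/// Models the spill arm: the batch body goes to disk, the key arrays stay
/// resident, so their size is force-tracked. Returns the amount to release
/// later (the role `reserved_amount` plays in the diff).
pub fn track_spilled_batch(pool: &mut ToyPool, join_arrays_mem: usize) -> usize {
    pool.grow(join_arrays_mem);
    join_arrays_mem
}
```

The trade-off is that the pool may temporarily report more than its limit, but what it reports matches memory that is really held, which is the rationale given in the code comment above.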

datafusion/physical-plan/src/joins/sort_merge_join/tests.rs

Lines changed: 217 additions & 0 deletions
@@ -2487,6 +2487,223 @@ async fn overallocation_multi_batch_spill() -> Result<()> {
     Ok(())
 }
 
+/// Verifies that `peak_mem_used` reflects join_arrays memory on the spill path.
+///
+/// Uses a memory limit smaller than a single batch's `size_estimation` so that
+/// every batch spills — the `Ok` arm of `allocate_reservation` is never hit.
+/// Before the fix, `peak_mem_used` would stay 0 because `set_max` was only
+/// called in the `Ok` arm. After the fix, the spill path calls
+/// `grow(join_arrays_mem)` + `set_max`, so `peak_mem_used > 0`.
+#[tokio::test]
+async fn spill_join_arrays_memory_accounting() -> Result<()> {
+    use arrow::array::Array;
+
+    let left_batch = build_table_i32(
+        ("a1", &vec![0, 1]),
+        ("b1", &vec![1, 1]),
+        ("c1", &vec![4, 5]),
+    );
+    let size_estimation = left_batch.get_array_memory_size()
+        + Int32Array::from(vec![1, 1]).get_array_memory_size()
+        + 2usize.next_power_of_two() * size_of::<usize>()
+        + size_of::<std::ops::Range<usize>>()
+        + size_of::<usize>();
+    let join_arrays_mem = Int32Array::from(vec![1, 1]).get_array_memory_size();
+
+    // Memory limit: too small for a full batch, large enough for join_arrays.
+    // Every batch hits the Err arm → spills → grow(join_arrays_mem).
+    let memory_limit = (size_estimation + join_arrays_mem) / 2;
+    assert!(
+        memory_limit < size_estimation && memory_limit > join_arrays_mem,
+        "limit {memory_limit} must be between join_arrays_mem {join_arrays_mem} \
+         and size_estimation {size_estimation}"
+    );
+
+    let left_batches: Vec<RecordBatch> = (0..4)
+        .map(|i| {
+            build_table_i32(
+                ("a1", &vec![i * 2, i * 2 + 1]),
+                ("b1", &vec![1, 1]),
+                ("c1", &vec![100 + i, 101 + i]),
+            )
+        })
+        .collect();
+    let left = build_table_from_batches(left_batches);
+
+    let right_batches: Vec<RecordBatch> = (0..2)
+        .map(|i| {
+            build_table_i32(
+                ("a2", &vec![i * 2, i * 2 + 1]),
+                ("b2", &vec![1, 1]),
+                ("c2", &vec![200 + i, 201 + i]),
+            )
+        })
+        .collect();
+    let right = build_table_from_batches(right_batches);
+
+    let on = vec![(
+        Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+        Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
+    )];
+    let sort_options = vec![SortOptions::default(); on.len()];
+
+    let runtime = RuntimeEnvBuilder::new()
+        .with_memory_limit(memory_limit, 1.0)
+        .with_disk_manager_builder(
+            DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory),
+        )
+        .build_arc()?;
+
+    let session_config = SessionConfig::default().with_batch_size(50);
+    let task_ctx = Arc::new(
+        TaskContext::default()
+            .with_session_config(session_config)
+            .with_runtime(Arc::clone(&runtime)),
+    );
+
+    let join = join_with_options(
+        Arc::clone(&left),
+        Arc::clone(&right),
+        on.clone(),
+        Inner,
+        sort_options,
+        NullEquality::NullEqualsNothing,
+    )?;
+
+    let stream = join.execute(0, task_ctx)?;
+    let result = common::collect(stream).await.unwrap();
+
+    assert!(!result.is_empty(), "Expected non-empty join result");
+
+    let metrics = join.metrics().unwrap();
+    assert!(
+        metrics.spill_count().unwrap() > 0,
+        "Expected spilling to occur"
+    );
+
+    // Before the fix, peak_mem_used was 0 here because set_max was only
+    // called in the Ok arm of allocate_reservation, which is never reached
+    // when every batch spills. After the fix, the spill path calls
+    // grow(join_arrays_mem) + set_max unconditionally.
+    let peak_mem = metrics
+        .sum_by_name("peak_mem_used")
+        .map(|m| m.as_usize())
+        .unwrap_or(0);
+    assert!(
+        peak_mem >= join_arrays_mem,
+        "peak_mem_used ({peak_mem}) should be >= join_arrays_mem ({join_arrays_mem})"
+    );
+
+    // All memory must be released (grow/shrink balanced, no underflow)
+    assert_eq!(
+        runtime.memory_pool.reserved(),
+        0,
+        "All memory should be released after join completes"
+    );
+
+    Ok(())
+}
+
+/// Test the no-headroom scenario: pool is so tight that even
+/// join_arrays_mem exceeds the pool limit. With force-grow, the
+/// reservation still tracks the join_arrays unconditionally so the
+/// pool reflects actual memory usage.
+#[tokio::test]
+async fn spill_join_arrays_no_headroom() -> Result<()> {
+    use arrow::array::Array;
+
+    let join_arrays_mem = Int32Array::from(vec![1, 1]).get_array_memory_size();
+
+    // Pool smaller than join_arrays_mem: try_grow(size_estimation) fails → spill.
+    // Force-grow(join_arrays_mem) succeeds unconditionally → reserved_amount > 0.
+    let memory_limit = join_arrays_mem / 2;
+    assert!(
+        memory_limit < join_arrays_mem,
+        "limit {memory_limit} must be smaller than join_arrays_mem {join_arrays_mem}"
+    );
+
+    let left_batches: Vec<RecordBatch> = (0..4)
+        .map(|i| {
+            build_table_i32(
+                ("a1", &vec![i * 2, i * 2 + 1]),
+                ("b1", &vec![1, 1]),
+                ("c1", &vec![100 + i, 101 + i]),
+            )
+        })
+        .collect();
+    let left = build_table_from_batches(left_batches);
+
+    let right_batches: Vec<RecordBatch> = (0..2)
+        .map(|i| {
+            build_table_i32(
+                ("a2", &vec![i * 2, i * 2 + 1]),
+                ("b2", &vec![1, 1]),
+                ("c2", &vec![200 + i, 201 + i]),
+            )
+        })
+        .collect();
+    let right = build_table_from_batches(right_batches);
+
+    let on = vec![(
+        Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+        Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
+    )];
+    let sort_options = vec![SortOptions::default(); on.len()];
+
+    let runtime = RuntimeEnvBuilder::new()
+        .with_memory_limit(memory_limit, 1.0)
+        .with_disk_manager_builder(
+            DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory),
+        )
+        .build_arc()?;
+
+    let session_config = SessionConfig::default().with_batch_size(50);
+    let task_ctx = Arc::new(
+        TaskContext::default()
+            .with_session_config(session_config)
+            .with_runtime(Arc::clone(&runtime)),
+    );
+
+    let join = join_with_options(
+        Arc::clone(&left),
+        Arc::clone(&right),
+        on.clone(),
+        Inner,
+        sort_options,
+        NullEquality::NullEqualsNothing,
+    )?;
+
+    let stream = join.execute(0, task_ctx)?;
+    let result = common::collect(stream).await.unwrap();
+
+    assert!(!result.is_empty(), "Expected non-empty join result");
+
+    let metrics = join.metrics().unwrap();
+    assert!(
+        metrics.spill_count().unwrap() > 0,
+        "Expected spilling to occur"
+    );
+
+    // Force-grow means peak_mem_used is always tracked, even when pool is tight.
+    let peak_mem = metrics
+        .sum_by_name("peak_mem_used")
+        .map(|m| m.as_usize())
+        .unwrap_or(0);
+    assert!(
+        peak_mem >= join_arrays_mem,
+        "peak_mem_used ({peak_mem}) should be >= join_arrays_mem ({join_arrays_mem})"
+    );
+
+    // Pool should be fully released (grow/shrink balanced)
+    assert_eq!(
+        runtime.memory_pool.reserved(),
+        0,
+        "All memory should be released after join completes"
+    );
+
+    Ok(())
+}
+
 /// Build a c1 < c2 filter on the third column of each side.
 fn build_c1_lt_c2_filter(left_schema: &Schema, right_schema: &Schema) -> JoinFilter {
     JoinFilter::new(
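
The first test picks its memory limit as the midpoint between `join_arrays_mem` and `size_estimation`, then asserts that the limit lies strictly between the two. The standalone sketch below reproduces that arithmetic without Arrow. The function names and the byte counts passed in are illustrative assumptions; only the shape of the formula is taken from the diff.

```rust
use std::mem::size_of;
use std::ops::Range;

/// Mirrors the shape of the `size_estimation` formula in `BufferedBatch::new`:
/// batch bytes + join key bytes + worst-case `null_joined` capacity
/// + one `Range<usize>` + one `usize`. The byte counts are caller-supplied
/// here, whereas the real code asks Arrow for them.
pub fn size_estimation(batch_bytes: usize, join_arrays_mem: usize, num_rows: usize) -> usize {
    batch_bytes
        + join_arrays_mem
        + num_rows.next_power_of_two() * size_of::<usize>()
        + size_of::<Range<usize>>()
        + size_of::<usize>()
}

/// Midpoint limit as used by the test. Returns `None` when integer division
/// leaves no value strictly between the two bounds.
pub fn midpoint_limit(size_estimation: usize, join_arrays_mem: usize) -> Option<usize> {
    let limit = (size_estimation + join_arrays_mem) / 2;
    (limit < size_estimation && limit > join_arrays_mem).then_some(limit)
}
```

Because of integer division, the midpoint is strictly above `join_arrays_mem` only when the two values differ by at least 2, which is why the test asserts the bounds explicitly instead of assuming them.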
