Skip to content

Commit b17847d

Browse files
Track spill read-back memory in SMJ (#22103)
## Which issue does this PR close? Follow-up to [#21962](#21962). ## Rationale for this change After #21962, the memory pool accurately tracks residual `join_arrays` memory that remains after a `BufferedBatch` is spilled to disk. However, when spilled batches are **read back** from disk during output materialization in `materialize_right_columns`, the deserialized data temporarily exists in memory without any pool reservation. - **Single-source path**: one full batch loaded without reservation - **Multi-source interleave path**: ALL referenced spilled batches loaded simultaneously — N × batch_size untracked The pool thinks these batches cost 0 bytes during read-back. Under memory pressure (the reason they were spilled), other operators see stale headroom and may over-allocate, risking OOM. ## What changes are included in this PR? Changed `materialize_right_columns` from `&self` to `&mut self` and added `grow/shrink` at the exact points where spilled data is read from disk: **Path A (single source spilled):** - `grow(size_estimation)` immediately before `fetch_right_columns_by_idxs` - `shrink(size_estimation)` immediately after **Path B (multi-source interleave):** - Sum `size_estimation` for all spilled sources - `grow(total)` before `source_data` loading - `shrink(total)` after interleave completes Uses unconditional `grow()` because the data must be read to produce output — there is no fallback. Same rationale as #21962: if memory physically exists, the pool must reflect it. ## Are these changes tested? Yes — two new tests: - `spill_read_back_memory_accounting`: multiple buffered batches for same key (multi-source Path B) — verifies `peak_mem_used >= size_estimation` and `pool.reserved() == 0` at end - `spill_read_back_single_source`: distinct keys with one batch per group (single-source Path A) — same assertions ## Are there any user-facing changes? No. --------- Co-authored-by: Kumar Ujjawal <ujjawalpathak6@gmail.com>
1 parent b4a6eb1 commit b17847d

2 files changed

Lines changed: 260 additions & 16 deletions

File tree

datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1540,7 +1540,7 @@ impl MaterializingSortMergeJoinStream {
15401540
/// gathers columns across sources. A null-row sentinel at source index 0
15411541
/// handles null right indices (unmatched streamed rows).
15421542
fn materialize_right_columns(
1543-
&self,
1543+
&mut self,
15441544
matched_chunks: &[(usize, UInt64Array, UInt64Array)],
15451545
total_matched_rows: usize,
15461546
) -> Result<Vec<ArrayRef>> {
@@ -1555,6 +1555,19 @@ impl MaterializingSortMergeJoinStream {
15551555
matched_chunks.iter().map(|c| &c.2 as &dyn Array).collect();
15561556
as_uint64_array(&compute::concat(&refs)?)?.clone()
15571557
};
1558+
1559+
let spill_reservation = self.reservation.new_empty();
1560+
if matches!(
1561+
&self.buffered_data.batches[first_batch_idx].batch,
1562+
BufferedBatchState::Spilled(_)
1563+
) {
1564+
spill_reservation
1565+
.grow(self.buffered_data.batches[first_batch_idx].size_estimation);
1566+
self.join_metrics
1567+
.peak_mem_used()
1568+
.set_max(self.reservation.size() + spill_reservation.size());
1569+
}
1570+
15581571
return fetch_right_columns_by_idxs(
15591572
&self.buffered_data,
15601573
first_batch_idx,
@@ -1588,24 +1601,33 @@ impl MaterializingSortMergeJoinStream {
15881601
}
15891602

15901603
let num_right_cols = self.buffered_schema.fields().len();
1591-
let mut right_columns = Vec::with_capacity(num_right_cols);
15921604

15931605
// Read each source batch once (spilled batches require disk I/O).
1594-
let source_data: Vec<Option<RecordBatch>> = source_batches
1595-
.iter()
1596-
.map(|&idx| {
1597-
let bb = &self.buffered_data.batches[idx];
1598-
match &bb.batch {
1599-
BufferedBatchState::InMemory(batch) => Some(batch.clone()),
1600-
BufferedBatchState::Spilled(spill_file) => {
1601-
let file = BufReader::new(File::open(spill_file.path()).ok()?);
1602-
let reader = StreamReader::try_new(file, None).ok()?;
1603-
reader.into_iter().next()?.ok()
1604-
}
1606+
// Track memory for each spilled batch at the point of deserialization
1607+
// so the pool reflects actual usage as it grows.
1608+
let spill_reservation = self.reservation.new_empty();
1609+
let mut source_data: Vec<Option<RecordBatch>> =
1610+
Vec::with_capacity(source_batches.len());
1611+
for &idx in &source_batches {
1612+
let bb = &self.buffered_data.batches[idx];
1613+
match &bb.batch {
1614+
BufferedBatchState::InMemory(batch) => {
1615+
source_data.push(Some(batch.clone()));
16051616
}
1606-
})
1607-
.collect();
1617+
BufferedBatchState::Spilled(spill_file) => {
1618+
spill_reservation.grow(bb.size_estimation);
1619+
self.join_metrics
1620+
.peak_mem_used()
1621+
.set_max(self.reservation.size() + spill_reservation.size());
1622+
1623+
let file = BufReader::new(File::open(spill_file.path())?);
1624+
let reader = StreamReader::try_new(file, None)?;
1625+
source_data.push(reader.into_iter().next().transpose()?);
1626+
}
1627+
}
1628+
}
16081629

1630+
let mut right_columns = Vec::with_capacity(num_right_cols);
16091631
for col_idx in 0..num_right_cols {
16101632
let dtype = self.buffered_schema.field(col_idx).data_type();
16111633
let null_array = new_null_array(dtype, 1);
@@ -1624,7 +1646,6 @@ impl MaterializingSortMergeJoinStream {
16241646
}
16251647
}
16261648
}
1627-
16281649
right_columns.push(interleave(&source_arrays, &interleave_indices)?);
16291650
}
16301651

datafusion/physical-plan/src/joins/sort_merge_join/tests.rs

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4724,3 +4724,226 @@ async fn spill_filtered_boundary_loses_outer_rows() -> Result<()> {
47244724

47254725
Ok(())
47264726
}
4727+
4728+
/// Verifies that `peak_mem_used` reflects spill read-back memory during
4729+
/// output materialization (multi-source path).
4730+
///
4731+
/// When spilled buffered batches are read back from disk to produce join
4732+
/// output, a scoped `MemoryReservation` (via `new_empty()`) tracks the
4733+
/// transient memory. Its `Drop` guarantees the pool is balanced on every
4734+
/// exit path — normal return or early `?` error.
4735+
#[tokio::test]
4736+
async fn spill_read_back_memory_accounting() -> Result<()> {
4737+
use arrow::array::Array;
4738+
4739+
let left_batch = build_table_i32(
4740+
("a1", &vec![0, 1]),
4741+
("b1", &vec![1, 1]),
4742+
("c1", &vec![4, 5]),
4743+
);
4744+
let size_estimation = left_batch.get_array_memory_size()
4745+
+ Int32Array::from(vec![1, 1]).get_array_memory_size()
4746+
+ 2usize.next_power_of_two() * size_of::<usize>()
4747+
+ size_of::<std::ops::Range<usize>>()
4748+
+ size_of::<usize>();
4749+
4750+
// Memory limit too small for a full batch — forces spilling.
4751+
let memory_limit = size_estimation / 2;
4752+
4753+
// All rows share the same join key (b=1) to force multiple buffered
4754+
// batches in the same key group — triggering spill read-back during
4755+
// output materialization.
4756+
let left_batches: Vec<RecordBatch> = (0..4)
4757+
.map(|i| {
4758+
build_table_i32(
4759+
("a1", &vec![i * 2, i * 2 + 1]),
4760+
("b1", &vec![1, 1]),
4761+
("c1", &vec![100 + i, 101 + i]),
4762+
)
4763+
})
4764+
.collect();
4765+
let left = build_table_from_batches(left_batches);
4766+
4767+
let right_batches: Vec<RecordBatch> = (0..4)
4768+
.map(|i| {
4769+
build_table_i32(
4770+
("a2", &vec![i * 2, i * 2 + 1]),
4771+
("b2", &vec![1, 1]),
4772+
("c2", &vec![200 + i, 201 + i]),
4773+
)
4774+
})
4775+
.collect();
4776+
let right = build_table_from_batches(right_batches);
4777+
4778+
let on = vec![(
4779+
Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
4780+
Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
4781+
)];
4782+
let sort_options = vec![SortOptions::default(); on.len()];
4783+
4784+
let runtime = RuntimeEnvBuilder::new()
4785+
.with_memory_limit(memory_limit, 1.0)
4786+
.with_disk_manager_builder(
4787+
DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory),
4788+
)
4789+
.build_arc()?;
4790+
4791+
let session_config = SessionConfig::default().with_batch_size(50);
4792+
let task_ctx = Arc::new(
4793+
TaskContext::default()
4794+
.with_session_config(session_config)
4795+
.with_runtime(Arc::clone(&runtime)),
4796+
);
4797+
4798+
let join = join_with_options(
4799+
Arc::clone(&left),
4800+
Arc::clone(&right),
4801+
on.clone(),
4802+
Inner,
4803+
sort_options,
4804+
NullEquality::NullEqualsNothing,
4805+
)?;
4806+
4807+
let stream = join.execute(0, task_ctx)?;
4808+
let result = common::collect(stream).await.unwrap();
4809+
4810+
assert!(!result.is_empty(), "Expected non-empty join result");
4811+
4812+
let metrics = join.metrics().unwrap();
4813+
assert!(
4814+
metrics.spill_count().unwrap() > 0,
4815+
"Expected spilling to occur"
4816+
);
4817+
4818+
// peak_mem_used should reflect the spill read-back: when buffered
4819+
// batches are read from disk during output materialization, grow()
4820+
// temporarily reserves size_estimation. This pushes peak above what
4821+
// join_arrays_mem alone would show.
4822+
let peak_mem = metrics
4823+
.sum_by_name("peak_mem_used")
4824+
.map(|m| m.as_usize())
4825+
.unwrap_or(0);
4826+
assert!(
4827+
peak_mem >= size_estimation,
4828+
"peak_mem_used ({peak_mem}) should be >= size_estimation ({size_estimation}) \
4829+
because spill read-back temporarily loads full batch into memory"
4830+
);
4831+
4832+
// All memory must be released (grow/shrink balanced)
4833+
assert_eq!(
4834+
runtime.memory_pool.reserved(),
4835+
0,
4836+
"All memory should be released after join completes"
4837+
);
4838+
4839+
Ok(())
4840+
}
4841+
4842+
/// Verifies spill read-back memory tracking for the single-source path.
4843+
///
4844+
/// When only ONE buffered batch exists for a key group and it's spilled,
4845+
/// `fetch_right_columns_by_idxs` reads it back. A scoped `MemoryReservation`
4846+
/// (via `new_empty()`) tracks the transient memory and releases it on drop.
4847+
#[tokio::test]
4848+
async fn spill_read_back_single_source() -> Result<()> {
4849+
use arrow::array::Array;
4850+
4851+
let left_batch = build_table_i32(
4852+
("a1", &vec![0, 1]),
4853+
("b1", &vec![1, 1]),
4854+
("c1", &vec![4, 5]),
4855+
);
4856+
let size_estimation = left_batch.get_array_memory_size()
4857+
+ Int32Array::from(vec![1, 1]).get_array_memory_size()
4858+
+ 2usize.next_power_of_two() * size_of::<usize>()
4859+
+ size_of::<std::ops::Range<usize>>()
4860+
+ size_of::<usize>();
4861+
4862+
// Memory limit too small for a full batch — forces spilling.
4863+
let memory_limit = size_estimation / 2;
4864+
4865+
// Multiple distinct keys so each key group has exactly ONE buffered batch.
4866+
// This ensures the single-source path is exercised.
4867+
let left_batches: Vec<RecordBatch> = (0..4)
4868+
.map(|i| {
4869+
build_table_i32(
4870+
("a1", &vec![i * 2, i * 2 + 1]),
4871+
("b1", &vec![i, i]),
4872+
("c1", &vec![100 + i, 101 + i]),
4873+
)
4874+
})
4875+
.collect();
4876+
let left = build_table_from_batches(left_batches);
4877+
4878+
// One batch per key — each key group has single source
4879+
let right_batches: Vec<RecordBatch> = (0..4)
4880+
.map(|i| {
4881+
build_table_i32(
4882+
("a2", &vec![i * 2, i * 2 + 1]),
4883+
("b2", &vec![i, i]),
4884+
("c2", &vec![200 + i, 201 + i]),
4885+
)
4886+
})
4887+
.collect();
4888+
let right = build_table_from_batches(right_batches);
4889+
4890+
let on = vec![(
4891+
Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
4892+
Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
4893+
)];
4894+
let sort_options = vec![SortOptions::default(); on.len()];
4895+
4896+
let runtime = RuntimeEnvBuilder::new()
4897+
.with_memory_limit(memory_limit, 1.0)
4898+
.with_disk_manager_builder(
4899+
DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory),
4900+
)
4901+
.build_arc()?;
4902+
4903+
let session_config = SessionConfig::default().with_batch_size(50);
4904+
let task_ctx = Arc::new(
4905+
TaskContext::default()
4906+
.with_session_config(session_config)
4907+
.with_runtime(Arc::clone(&runtime)),
4908+
);
4909+
4910+
let join = join_with_options(
4911+
Arc::clone(&left),
4912+
Arc::clone(&right),
4913+
on.clone(),
4914+
Inner,
4915+
sort_options,
4916+
NullEquality::NullEqualsNothing,
4917+
)?;
4918+
4919+
let stream = join.execute(0, task_ctx)?;
4920+
let result = common::collect(stream).await.unwrap();
4921+
4922+
assert!(!result.is_empty(), "Expected non-empty join result");
4923+
4924+
let metrics = join.metrics().unwrap();
4925+
assert!(
4926+
metrics.spill_count().unwrap() > 0,
4927+
"Expected spilling to occur"
4928+
);
4929+
4930+
// peak_mem_used should reflect the single-batch read-back
4931+
let peak_mem = metrics
4932+
.sum_by_name("peak_mem_used")
4933+
.map(|m| m.as_usize())
4934+
.unwrap_or(0);
4935+
assert!(
4936+
peak_mem >= size_estimation,
4937+
"peak_mem_used ({peak_mem}) should be >= size_estimation ({size_estimation}) \
4938+
because single-source spill read-back loads full batch"
4939+
);
4940+
4941+
// All memory must be released
4942+
assert_eq!(
4943+
runtime.memory_pool.reserved(),
4944+
0,
4945+
"All memory should be released after join completes"
4946+
);
4947+
4948+
Ok(())
4949+
}

0 commit comments

Comments (0)