Skip to content

Commit 635b511

Browse files
committed
fix: use spill writer's schema instead of the first batch schema for spill files
1 parent bc2b36c commit 635b511

File tree

2 files changed

+273
-1
lines changed

2 files changed

+273
-1
lines changed
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use arrow::array::{Array, Int64Array, RecordBatch};
21+
use arrow::compute::SortOptions;
22+
use arrow::datatypes::{DataType, Field, Schema};
23+
use datafusion::datasource::memory::MemorySourceConfig;
24+
use datafusion_execution::config::SessionConfig;
25+
use datafusion_execution::memory_pool::FairSpillPool;
26+
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
27+
use datafusion_physical_expr::expressions::col;
28+
use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr};
29+
use datafusion_physical_plan::repartition::RepartitionExec;
30+
use datafusion_physical_plan::sorts::sort::sort_batch;
31+
use datafusion_physical_plan::union::UnionExec;
32+
use datafusion_physical_plan::{ExecutionPlan, Partitioning};
33+
use futures::StreamExt;
34+
35+
const NUM_BATCHES: usize = 200;
36+
const ROWS_PER_BATCH: usize = 10;
37+
38+
fn non_nullable_schema() -> Arc<Schema> {
39+
Arc::new(Schema::new(vec![
40+
Field::new("key", DataType::Int64, false),
41+
Field::new("val", DataType::Int64, false),
42+
]))
43+
}
44+
45+
fn nullable_schema() -> Arc<Schema> {
46+
Arc::new(Schema::new(vec![
47+
Field::new("key", DataType::Int64, false),
48+
Field::new("val", DataType::Int64, true),
49+
]))
50+
}
51+
52+
fn non_nullable_batches() -> Vec<RecordBatch> {
53+
(0..NUM_BATCHES)
54+
.map(|i| {
55+
let start = (i * ROWS_PER_BATCH) as i64;
56+
let keys: Vec<i64> = (start..start + ROWS_PER_BATCH as i64).collect();
57+
RecordBatch::try_new(
58+
non_nullable_schema(),
59+
vec![
60+
Arc::new(Int64Array::from(keys)),
61+
Arc::new(Int64Array::from(vec![0i64; ROWS_PER_BATCH])),
62+
],
63+
)
64+
.unwrap()
65+
})
66+
.collect()
67+
}
68+
69+
fn nullable_batches() -> Vec<RecordBatch> {
70+
(0..NUM_BATCHES)
71+
.map(|i| {
72+
let start = (i * ROWS_PER_BATCH) as i64;
73+
let keys: Vec<i64> = (start..start + ROWS_PER_BATCH as i64).collect();
74+
let vals: Vec<Option<i64>> = (0..ROWS_PER_BATCH)
75+
.map(|j| if j % 3 == 1 { None } else { Some(j as i64) })
76+
.collect();
77+
RecordBatch::try_new(
78+
nullable_schema(),
79+
vec![
80+
Arc::new(Int64Array::from(keys)),
81+
Arc::new(Int64Array::from(vals)),
82+
],
83+
)
84+
.unwrap()
85+
})
86+
.collect()
87+
}
88+
89+
fn build_task_ctx(pool_size: usize) -> Arc<datafusion_execution::TaskContext> {
90+
let session_config = SessionConfig::new().with_batch_size(2);
91+
let runtime = RuntimeEnvBuilder::new()
92+
.with_memory_pool(Arc::new(FairSpillPool::new(pool_size)))
93+
.build_arc()
94+
.unwrap();
95+
Arc::new(
96+
datafusion_execution::TaskContext::default()
97+
.with_session_config(session_config)
98+
.with_runtime(runtime),
99+
)
100+
}
101+
102+
/// Exercises spilling through UnionExec -> RepartitionExec where union children
103+
/// have mismatched nullability (one child's `val` is non-nullable, the other's
104+
/// is nullable with NULLs). A tiny FairSpillPool forces all batches to spill.
105+
///
106+
/// UnionExec returns child streams without schema coercion, so batches from
107+
/// different children carry different per-field nullability into the shared
108+
/// SpillPool. The IPC writer must use the SpillManager's canonical (nullable)
109+
/// schema — not the first batch's schema — so readback batches are valid.
110+
///
111+
/// Otherwise, sort_batch will panic with
112+
/// `Column 'val' is declared as non-nullable but contains null values`
113+
#[tokio::test]
114+
async fn test_sort_union_repartition_spill_mixed_nullability() {
115+
let non_nullable_exec = MemorySourceConfig::try_new_exec(
116+
&[non_nullable_batches()],
117+
non_nullable_schema(),
118+
None,
119+
)
120+
.unwrap();
121+
122+
let nullable_exec =
123+
MemorySourceConfig::try_new_exec(&[nullable_batches()], nullable_schema(), None)
124+
.unwrap();
125+
126+
let union_exec = UnionExec::try_new(vec![non_nullable_exec, nullable_exec]).unwrap();
127+
assert!(union_exec.schema().field(1).is_nullable());
128+
129+
let repartition = Arc::new(
130+
RepartitionExec::try_new(union_exec, Partitioning::RoundRobinBatch(1)).unwrap(),
131+
);
132+
133+
let task_ctx = build_task_ctx(200);
134+
let mut stream = repartition.execute(0, task_ctx).unwrap();
135+
136+
let sort_expr = LexOrdering::new(vec![PhysicalSortExpr {
137+
expr: col("key", &nullable_schema()).unwrap(),
138+
options: SortOptions::default(),
139+
}])
140+
.unwrap();
141+
142+
let mut total_rows = 0usize;
143+
let mut total_nulls = 0usize;
144+
while let Some(result) = stream.next().await {
145+
let batch = result.unwrap();
146+
147+
let batch = sort_batch(&batch, &sort_expr, None).unwrap();
148+
149+
total_rows += batch.num_rows();
150+
total_nulls += batch.column(1).null_count();
151+
}
152+
153+
assert_eq!(
154+
total_rows,
155+
NUM_BATCHES * ROWS_PER_BATCH * 2,
156+
"All rows from both UNION branches should be present"
157+
);
158+
assert!(
159+
total_nulls > 0,
160+
"Expected some null values in output (i.e. nullable batches were processed)"
161+
);
162+
}

datafusion/physical-plan/src/spill/in_progress_spill_file.rs

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ impl InProgressSpillFile {
6262
));
6363
}
6464
if self.writer.is_none() {
65-
let schema = batch.schema();
65+
// Use the SpillManager's declared schema rather than the batch's schema.
66+
// Individual batches may have different schemas (e.g., different nullability)
67+
// when they come from different branches of a UnionExec. The SpillManager's
68+
// schema represents the canonical schema that all batches should conform to.
69+
let schema = self.spill_writer.schema();
6670
if let Some(in_progress_file) = &mut self.in_progress_file {
6771
self.writer = Some(IPCStreamWriter::new(
6872
in_progress_file.path(),
@@ -138,3 +142,109 @@ impl InProgressSpillFile {
138142
Ok(self.in_progress_file.take())
139143
}
140144
}
145+
146+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, Int64Array};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use datafusion_physical_expr_common::metrics::{
        ExecutionPlanMetricsSet, SpillMetrics,
    };
    use futures::StreamExt;

    /// Unit-level test: proves that InProgressSpillFile uses the SpillManager's
    /// declared schema for the IPC writer, so readback batches always have the
    /// correct schema even when input batches have mismatched nullability.
    ///
    /// Scenario:
    /// - SpillManager declares schema with nullable `val`
    /// - First appended batch has non-nullable `val` (simulates literal-projection UNION branch)
    /// - Second appended batch has nullable `val` with NULLs (simulates table UNION branch)
    /// - On readback, both batches must have the nullable schema
    #[tokio::test]
    async fn test_spill_file_uses_spill_manager_schema() {
        // Canonical (nullable) schema the SpillManager is declared with, and a
        // stricter non-nullable variant for the first appended batch.
        let nullable_schema = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int64, false),
            Field::new("val", DataType::Int64, true),
        ]));
        let non_nullable_schema = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int64, false),
            Field::new("val", DataType::Int64, false),
        ]));

        let metrics_set = ExecutionPlanMetricsSet::new();
        let spill_metrics = SpillMetrics::new(&metrics_set, 0);
        let runtime = Arc::new(RuntimeEnvBuilder::new().build().unwrap());
        let spill_manager = Arc::new(SpillManager::new(
            runtime,
            spill_metrics,
            Arc::clone(&nullable_schema),
        ));

        let mut spill_file_writer =
            spill_manager.create_in_progress_file("test").unwrap();

        // First batch: non-nullable val (simulates literal-0 UNION branch)
        let strict_batch = RecordBatch::try_new(
            Arc::clone(&non_nullable_schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(Int64Array::from(vec![0, 0, 0])),
            ],
        )
        .unwrap();
        spill_file_writer.append_batch(&strict_batch).unwrap();

        // Second batch: nullable val with NULLs (simulates table UNION branch)
        let null_bearing_batch = RecordBatch::try_new(
            Arc::clone(&nullable_schema),
            vec![
                Arc::new(Int64Array::from(vec![4, 5, 6])),
                Arc::new(Int64Array::from(vec![Some(10), None, Some(30)])),
            ],
        )
        .unwrap();
        spill_file_writer.append_batch(&null_bearing_batch).unwrap();

        let spill_file = spill_file_writer.finish().unwrap().unwrap();

        // Read back
        let mut readback = spill_manager
            .read_spill_as_stream(spill_file, None)
            .unwrap();
        assert!(
            readback.schema().field(1).is_nullable(),
            "Stream schema should be nullable"
        );

        let mut batches = vec![];
        while let Some(result) = readback.next().await {
            batches.push(result.unwrap());
        }
        assert_eq!(batches.len(), 2);

        // Both readback batches must have the SpillManager's nullable schema
        for (i, batch) in batches.iter().enumerate() {
            assert!(
                batch.schema().field(1).is_nullable(),
                "Readback batch {i} should have nullable schema from SpillManager"
            );
        }

        // The second batch must preserve its NULL data
        let val_col = batches[1]
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(val_col.null_count(), 1, "Second batch should have 1 null");

        // Rebuilding the batch with its own schema must succeed (would fail if
        // schema said non-nullable but data contained nulls)
        RecordBatch::try_new(batches[1].schema(), batches[1].columns().to_vec()).expect(
            "Readback batch should be valid: schema should match data nullability",
        );
    }
}

0 commit comments

Comments
 (0)