Skip to content

Commit d0bcb4e

Browse files
committed
fix: use spill writer's schema instead of the first batch schema for spill files
1 parent bc2b36c commit d0bcb4e

File tree

2 files changed

+246
-1
lines changed

2 files changed

+246
-1
lines changed
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use arrow::array::{Array, Int64Array, RecordBatch};
21+
use arrow::compute::SortOptions;
22+
use arrow::datatypes::{DataType, Field, Schema};
23+
use datafusion::datasource::memory::MemorySourceConfig;
24+
use datafusion_execution::config::SessionConfig;
25+
use datafusion_execution::memory_pool::FairSpillPool;
26+
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
27+
use datafusion_physical_expr::expressions::col;
28+
use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr};
29+
use datafusion_physical_plan::repartition::RepartitionExec;
30+
use datafusion_physical_plan::sorts::sort::sort_batch;
31+
use datafusion_physical_plan::union::UnionExec;
32+
use datafusion_physical_plan::{ExecutionPlan, Partitioning};
33+
use futures::StreamExt;
34+
35+
const NUM_BATCHES: usize = 200;
36+
const ROWS_PER_BATCH: usize = 10;
37+
38+
fn non_nullable_schema() -> Arc<Schema> {
39+
Arc::new(Schema::new(vec![
40+
Field::new("key", DataType::Int64, false),
41+
Field::new("val", DataType::Int64, false),
42+
]))
43+
}
44+
45+
fn nullable_schema() -> Arc<Schema> {
46+
Arc::new(Schema::new(vec![
47+
Field::new("key", DataType::Int64, false),
48+
Field::new("val", DataType::Int64, true),
49+
]))
50+
}
51+
52+
fn non_nullable_batches() -> Vec<RecordBatch> {
53+
(0..NUM_BATCHES)
54+
.map(|i| {
55+
let start = (i * ROWS_PER_BATCH) as i64;
56+
let keys: Vec<i64> = (start..start + ROWS_PER_BATCH as i64).collect();
57+
RecordBatch::try_new(
58+
non_nullable_schema(),
59+
vec![
60+
Arc::new(Int64Array::from(keys)),
61+
Arc::new(Int64Array::from(vec![0i64; ROWS_PER_BATCH])),
62+
],
63+
)
64+
.unwrap()
65+
})
66+
.collect()
67+
}
68+
69+
fn nullable_batches() -> Vec<RecordBatch> {
70+
(0..NUM_BATCHES)
71+
.map(|i| {
72+
let start = (i * ROWS_PER_BATCH) as i64;
73+
let keys: Vec<i64> = (start..start + ROWS_PER_BATCH as i64).collect();
74+
let vals: Vec<Option<i64>> = (0..ROWS_PER_BATCH)
75+
.map(|j| if j % 3 == 1 { None } else { Some(j as i64) })
76+
.collect();
77+
RecordBatch::try_new(
78+
nullable_schema(),
79+
vec![
80+
Arc::new(Int64Array::from(keys)),
81+
Arc::new(Int64Array::from(vals)),
82+
],
83+
)
84+
.unwrap()
85+
})
86+
.collect()
87+
}
88+
89+
fn build_task_ctx(pool_size: usize) -> Arc<datafusion_execution::TaskContext> {
90+
let session_config = SessionConfig::new().with_batch_size(2);
91+
let runtime = RuntimeEnvBuilder::new()
92+
.with_memory_pool(Arc::new(FairSpillPool::new(pool_size)))
93+
.build_arc()
94+
.unwrap();
95+
Arc::new(
96+
datafusion_execution::TaskContext::default()
97+
.with_session_config(session_config)
98+
.with_runtime(runtime),
99+
)
100+
}
101+
102+
/// Exercises spilling through UnionExec -> RepartitionExec where union children
103+
/// have mismatched nullability (one child's `val` is non-nullable, the other's
104+
/// is nullable with NULLs). A tiny FairSpillPool forces all batches to spill.
105+
///
106+
/// UnionExec returns child streams without schema coercion, so batches from
107+
/// different children carry different per-field nullability into the shared
108+
/// SpillPool. The IPC writer must use the SpillManager's canonical (nullable)
109+
/// schema — not the first batch's schema — so readback batches are valid.
110+
///
111+
/// Otherwise, sort_batch will panic with
112+
/// `Column 'val' is declared as non-nullable but contains null values`
113+
#[tokio::test]
114+
async fn test_sort_union_repartition_spill_mixed_nullability() {
115+
let non_nullable_exec = MemorySourceConfig::try_new_exec(
116+
&[non_nullable_batches()],
117+
non_nullable_schema(),
118+
None,
119+
)
120+
.unwrap();
121+
122+
let nullable_exec =
123+
MemorySourceConfig::try_new_exec(&[nullable_batches()], nullable_schema(), None)
124+
.unwrap();
125+
126+
let union_exec = UnionExec::try_new(vec![non_nullable_exec, nullable_exec]).unwrap();
127+
assert!(union_exec.schema().field(1).is_nullable());
128+
129+
let repartition = Arc::new(
130+
RepartitionExec::try_new(union_exec, Partitioning::RoundRobinBatch(1)).unwrap(),
131+
);
132+
133+
let task_ctx = build_task_ctx(200);
134+
let mut stream = repartition.execute(0, task_ctx).unwrap();
135+
136+
let sort_expr = LexOrdering::new(vec![PhysicalSortExpr {
137+
expr: col("key", &nullable_schema()).unwrap(),
138+
options: SortOptions::default(),
139+
}])
140+
.unwrap();
141+
142+
let mut total_rows = 0usize;
143+
let mut total_nulls = 0usize;
144+
while let Some(result) = stream.next().await {
145+
let batch = result.unwrap();
146+
147+
let batch = sort_batch(&batch, &sort_expr, None).unwrap();
148+
149+
total_rows += batch.num_rows();
150+
total_nulls += batch.column(1).null_count();
151+
}
152+
153+
assert_eq!(
154+
total_rows,
155+
NUM_BATCHES * ROWS_PER_BATCH * 2,
156+
"All rows from both UNION branches should be present"
157+
);
158+
assert!(
159+
total_nulls > 0,
160+
"Expected some null values in output (i.e. nullable batches were processed)"
161+
);
162+
}

datafusion/physical-plan/src/spill/in_progress_spill_file.rs

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,12 @@ impl InProgressSpillFile {
6262
));
6363
}
6464
if self.writer.is_none() {
65-
let schema = batch.schema();
65+
// Use the SpillManager's declared schema rather than the batch's schema.
66+
// Individual batches may have different schemas (e.g., different nullability)
67+
// when they come from different branches of a UnionExec. The SpillManager's
68+
// schema represents the canonical schema that all batches should conform to.
69+
let schema = self.spill_writer.schema();
70+
//let schema = batch.schema();
6671
if let Some(in_progress_file) = &mut self.in_progress_file {
6772
self.writer = Some(IPCStreamWriter::new(
6873
in_progress_file.path(),
@@ -138,3 +143,81 @@ impl InProgressSpillFile {
138143
Ok(self.in_progress_file.take())
139144
}
140145
}
146+
147+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use datafusion_physical_expr_common::metrics::{
        ExecutionPlanMetricsSet, SpillMetrics,
    };
    use futures::TryStreamExt;

    /// Regression test: a spill file must be written with the SpillManager's
    /// canonical schema, not the schema of the first appended batch, so that
    /// batches with stricter per-field nullability still read back under the
    /// canonical (nullable) schema.
    #[tokio::test]
    async fn test_spill_file_uses_spill_manager_schema() {
        // Two schemas that differ only in the nullability of `val`.
        let nullable_schema = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int64, false),
            Field::new("val", DataType::Int64, true),
        ]));
        let non_nullable_schema = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int64, false),
            Field::new("val", DataType::Int64, false),
        ]));

        // The SpillManager is declared with the nullable (canonical) schema.
        let runtime = Arc::new(RuntimeEnvBuilder::new().build().unwrap());
        let metrics_set = ExecutionPlanMetricsSet::new();
        let spill_metrics = SpillMetrics::new(&metrics_set, 0);
        let spill_manager = Arc::new(SpillManager::new(
            runtime,
            spill_metrics,
            Arc::clone(&nullable_schema),
        ));

        let mut in_progress = spill_manager.create_in_progress_file("test").unwrap();

        // First batch: non-nullable val (simulates literal-0 UNION branch)
        let strict_batch = RecordBatch::try_new(
            Arc::clone(&non_nullable_schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(Int64Array::from(vec![0, 0, 0])),
            ],
        )
        .unwrap();
        in_progress.append_batch(&strict_batch).unwrap();

        // Second batch: nullable val with NULLs (simulates table UNION branch)
        let lax_batch = RecordBatch::try_new(
            Arc::clone(&nullable_schema),
            vec![
                Arc::new(Int64Array::from(vec![4, 5, 6])),
                Arc::new(Int64Array::from(vec![Some(10), None, Some(30)])),
            ],
        )
        .unwrap();
        in_progress.append_batch(&lax_batch).unwrap();

        let spill_file = in_progress.finish().unwrap().unwrap();

        let stream = spill_manager
            .read_spill_as_stream(spill_file, None)
            .unwrap();

        // Stream schema should be nullable
        assert!(stream.schema().field(1).is_nullable());

        let batches = stream.try_collect::<Vec<_>>().await.unwrap();
        assert_eq!(batches.len(), 2);

        // Both batches must have the SpillManager's nullable schema
        assert_eq!(
            batches[0],
            strict_batch
                .with_schema(Arc::clone(&nullable_schema))
                .unwrap()
        );
        assert_eq!(batches[1], lax_batch);
    }
}

0 commit comments

Comments
 (0)