Skip to content

Commit 8fbd95e

Browse files
committed
fix: use spill writer's schema instead of the first batch schema for spill files
1 parent 1e93a67 commit 8fbd95e

File tree

2 files changed

+241
-1
lines changed
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use arrow::array::{Array, Int64Array, RecordBatch};
21+
use arrow::compute::SortOptions;
22+
use arrow::datatypes::{DataType, Field, Schema};
23+
use datafusion::datasource::memory::MemorySourceConfig;
24+
use datafusion_execution::config::SessionConfig;
25+
use datafusion_execution::memory_pool::FairSpillPool;
26+
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
27+
use datafusion_physical_expr::expressions::col;
28+
use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr};
29+
use datafusion_physical_plan::repartition::RepartitionExec;
30+
use datafusion_physical_plan::sorts::sort::sort_batch;
31+
use datafusion_physical_plan::union::UnionExec;
32+
use datafusion_physical_plan::{ExecutionPlan, Partitioning};
33+
use futures::StreamExt;
34+
35+
const NUM_BATCHES: usize = 200;
36+
const ROWS_PER_BATCH: usize = 10;
37+
38+
fn non_nullable_schema() -> Arc<Schema> {
39+
Arc::new(Schema::new(vec![
40+
Field::new("key", DataType::Int64, false),
41+
Field::new("val", DataType::Int64, false),
42+
]))
43+
}
44+
45+
fn nullable_schema() -> Arc<Schema> {
46+
Arc::new(Schema::new(vec![
47+
Field::new("key", DataType::Int64, false),
48+
Field::new("val", DataType::Int64, true),
49+
]))
50+
}
51+
52+
fn non_nullable_batches() -> Vec<RecordBatch> {
53+
(0..NUM_BATCHES)
54+
.map(|i| {
55+
let start = (i * ROWS_PER_BATCH) as i64;
56+
let keys: Vec<i64> = (start..start + ROWS_PER_BATCH as i64).collect();
57+
RecordBatch::try_new(
58+
non_nullable_schema(),
59+
vec![
60+
Arc::new(Int64Array::from(keys)),
61+
Arc::new(Int64Array::from(vec![0i64; ROWS_PER_BATCH])),
62+
],
63+
)
64+
.unwrap()
65+
})
66+
.collect()
67+
}
68+
69+
fn nullable_batches() -> Vec<RecordBatch> {
70+
(0..NUM_BATCHES)
71+
.map(|i| {
72+
let start = (i * ROWS_PER_BATCH) as i64;
73+
let keys: Vec<i64> = (start..start + ROWS_PER_BATCH as i64).collect();
74+
let vals: Vec<Option<i64>> = (0..ROWS_PER_BATCH)
75+
.map(|j| if j % 3 == 1 { None } else { Some(j as i64) })
76+
.collect();
77+
RecordBatch::try_new(
78+
nullable_schema(),
79+
vec![
80+
Arc::new(Int64Array::from(keys)),
81+
Arc::new(Int64Array::from(vals)),
82+
],
83+
)
84+
.unwrap()
85+
})
86+
.collect()
87+
}
88+
89+
fn build_task_ctx(pool_size: usize) -> Arc<datafusion_execution::TaskContext> {
90+
let session_config = SessionConfig::new().with_batch_size(2);
91+
let runtime = RuntimeEnvBuilder::new()
92+
.with_memory_pool(Arc::new(FairSpillPool::new(pool_size)))
93+
.build_arc()
94+
.unwrap();
95+
Arc::new(
96+
datafusion_execution::TaskContext::default()
97+
.with_session_config(session_config)
98+
.with_runtime(runtime),
99+
)
100+
}
101+
102+
/// Exercises spilling through UnionExec -> RepartitionExec where union children
103+
/// have mismatched nullability (one child's `val` is non-nullable, the other's
104+
/// is nullable with NULLs). A tiny FairSpillPool forces all batches to spill.
105+
///
106+
/// UnionExec returns child streams without schema coercion, so batches from
107+
/// different children carry different per-field nullability into the shared
108+
/// SpillPool. The IPC writer must use the SpillManager's canonical (nullable)
109+
/// schema — not the first batch's schema — so readback batches are valid.
110+
///
111+
/// Otherwise, sort_batch will panic with
112+
/// `Column 'val' is declared as non-nullable but contains null values`
113+
#[tokio::test]
114+
async fn test_sort_union_repartition_spill_mixed_nullability() {
115+
let non_nullable_exec = MemorySourceConfig::try_new_exec(
116+
&[non_nullable_batches()],
117+
non_nullable_schema(),
118+
None,
119+
)
120+
.unwrap();
121+
122+
let nullable_exec =
123+
MemorySourceConfig::try_new_exec(&[nullable_batches()], nullable_schema(), None)
124+
.unwrap();
125+
126+
let union_exec = UnionExec::try_new(vec![non_nullable_exec, nullable_exec]).unwrap();
127+
assert!(union_exec.schema().field(1).is_nullable());
128+
129+
let repartition = Arc::new(
130+
RepartitionExec::try_new(union_exec, Partitioning::RoundRobinBatch(1)).unwrap(),
131+
);
132+
133+
let task_ctx = build_task_ctx(200);
134+
let mut stream = repartition.execute(0, task_ctx).unwrap();
135+
136+
let sort_expr = LexOrdering::new(vec![PhysicalSortExpr {
137+
expr: col("key", &nullable_schema()).unwrap(),
138+
options: SortOptions::default(),
139+
}])
140+
.unwrap();
141+
142+
let mut total_rows = 0usize;
143+
let mut total_nulls = 0usize;
144+
while let Some(result) = stream.next().await {
145+
let batch = result.unwrap();
146+
147+
let batch = sort_batch(&batch, &sort_expr, None).unwrap();
148+
149+
total_rows += batch.num_rows();
150+
total_nulls += batch.column(1).null_count();
151+
}
152+
153+
assert_eq!(
154+
total_rows,
155+
NUM_BATCHES * ROWS_PER_BATCH * 2,
156+
"All rows from both UNION branches should be present"
157+
);
158+
assert!(
159+
total_nulls > 0,
160+
"Expected some null values in output (i.e. nullable batches were processed)"
161+
);
162+
}

datafusion/physical-plan/src/spill/in_progress_spill_file.rs

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ impl InProgressSpillFile {
6262
));
6363
}
6464
if self.writer.is_none() {
65-
let schema = batch.schema();
65+
// Use the SpillManager's declared schema rather than the batch's schema.
66+
// Individual batches may have different schemas (e.g., different nullability)
67+
// when they come from different branches of a UnionExec. The SpillManager's
68+
// schema represents the canonical schema that all batches should conform to.
69+
let schema = self.spill_writer.schema();
6670
if let Some(in_progress_file) = &mut self.in_progress_file {
6771
self.writer = Some(IPCStreamWriter::new(
6872
in_progress_file.path(),
@@ -138,3 +142,77 @@ impl InProgressSpillFile {
138142
Ok(self.in_progress_file.take())
139143
}
140144
}
145+
146+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use datafusion_physical_expr_common::metrics::{
        ExecutionPlanMetricsSet, SpillMetrics,
    };
    use futures::TryStreamExt;

    /// Spilled batches must be written under the SpillManager's canonical
    /// schema so that a mix of nullable and non-nullable inputs reads back
    /// with one uniform (nullable) schema.
    #[tokio::test]
    async fn test_spill_file_uses_spill_manager_schema() -> Result<()> {
        // Canonical schema: `val` admits NULLs.
        let schema_nullable = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int64, false),
            Field::new("val", DataType::Int64, true),
        ]));
        // Stricter variant: `val` declared non-nullable.
        let schema_strict = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int64, false),
            Field::new("val", DataType::Int64, false),
        ]));

        let env = Arc::new(RuntimeEnvBuilder::new().build()?);
        let metrics_set = ExecutionPlanMetricsSet::new();
        let manager = Arc::new(SpillManager::new(
            env,
            SpillMetrics::new(&metrics_set, 0),
            Arc::clone(&schema_nullable),
        ));

        let mut spill_file = manager.create_in_progress_file("test")?;

        // First batch: non-nullable val (simulates literal-0 UNION branch)
        let strict_batch = RecordBatch::try_new(
            Arc::clone(&schema_strict),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(Int64Array::from(vec![0, 0, 0])),
            ],
        )?;
        spill_file.append_batch(&strict_batch)?;

        // Second batch: nullable val with NULLs (simulates table UNION branch)
        let relaxed_batch = RecordBatch::try_new(
            Arc::clone(&schema_nullable),
            vec![
                Arc::new(Int64Array::from(vec![4, 5, 6])),
                Arc::new(Int64Array::from(vec![Some(10), None, Some(30)])),
            ],
        )?;
        spill_file.append_batch(&relaxed_batch)?;

        let finished = spill_file.finish()?.unwrap();
        let stream = manager.read_spill_as_stream(finished, None)?;

        // Readback must advertise the canonical (nullable) schema...
        assert_eq!(stream.schema(), schema_nullable);

        let collected = stream.try_collect::<Vec<_>>().await?;
        assert_eq!(collected.len(), 2);

        // ...and every batch must conform to it.
        assert_eq!(
            collected[0],
            strict_batch.with_schema(Arc::clone(&schema_nullable))?
        );
        assert_eq!(collected[1], relaxed_batch);

        Ok(())
    }
}

0 commit comments

Comments
 (0)