Skip to content

Commit 870e448

Browse files
committed
feat: update recursive CTE and schema alignment for nullability
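- Reverted the oom_recursive_cte test query to use `SELECT 1 as id` as its static term.
- Updated schema alignment (`project_plan_to_schema`) to permit safe nullability widening (non-null input to nullable output) through a same-type `CastExpr`.
- Updated `RecursiveQueryExec` to reconcile output nullability across the static and recursive terms: an output field is nullable if either term's field is nullable, and both plans are projected to that shared output schema.
- Nullability narrowing (nullable input to non-null output) is still rejected.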
1 parent 08b5068 commit 870e448

3 files changed

Lines changed: 77 additions & 21 deletions

File tree

datafusion/core/tests/memory_limit/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ async fn oom_recursive_cte() {
343343
TestCase::new()
344344
.with_query(
345345
"WITH RECURSIVE nodes AS (
346-
SELECT id FROM (VALUES (1), (NULL)) AS t(id) WHERE id IS NOT NULL
346+
SELECT 1 as id
347347
UNION ALL
348348
SELECT UNNEST(RANGE(id+1, id+1000)) as id
349349
FROM nodes

datafusion/physical-plan/src/common.rs

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use std::fs::metadata;
2222
use std::sync::Arc;
2323

2424
use super::SendableRecordBatchStream;
25-
use crate::expressions::Column;
25+
use crate::expressions::{CastExpr, Column};
2626
use crate::projection::{ProjectionExec, ProjectionExpr};
2727
use crate::stream::RecordBatchReceiverStream;
2828
use crate::{ColumnStatistics, ExecutionPlan, Statistics};
@@ -91,18 +91,18 @@ fn build_file_list_recurse(
9191
Ok(())
9292
}
9393

94-
/// Align `input`'s physical plan schema with `expected_schema` when only field names differ.
94+
/// Align `input`'s physical plan schema with `expected_schema`.
9595
///
9696
/// This helper is intended for operators that combine independently planned children but
9797
/// expose a single declared output schema. It returns `input` unchanged when schemas already
9898
/// match exactly. Otherwise, it validates that projection can safely produce the expected
9999
/// schema, then wraps `input` in a [`ProjectionExec`] that keeps columns in their existing
100100
/// positional order and aliases them to `expected_schema`'s field names.
101101
///
102-
/// [`ProjectionExec`] can rename fields but preserves column data types, nullability, field
103-
/// metadata, and schema metadata from the input expressions. Therefore, this helper rejects
104-
/// mismatches in those attributes rather than returning a plan whose schema still differs
105-
/// from `expected_schema`.
102+
/// [`ProjectionExec`] can rename fields. When the expected field is nullable and the input
103+
/// field is not, this helper also widens nullability with a same-type [`CastExpr`]. It rejects
104+
/// differences that projection cannot safely normalize exactly, such as data type, metadata,
105+
/// schema metadata, and nullability narrowing.
106106
pub fn project_plan_to_schema(
107107
input: Arc<dyn ExecutionPlan>,
108108
expected_schema: &SchemaRef,
@@ -134,7 +134,7 @@ pub fn project_plan_to_schema(
134134
.find_map(|(i, (input_field, expected_field))| {
135135
if input_field.data_type() != expected_field.data_type() {
136136
Some((i, input_field, expected_field, "data type"))
137-
} else if input_field.is_nullable() != expected_field.is_nullable() {
137+
} else if input_field.is_nullable() && !expected_field.is_nullable() {
138138
Some((i, input_field, expected_field, "nullability"))
139139
} else if input_field.metadata() != expected_field.metadata() {
140140
Some((i, input_field, expected_field, "metadata"))
@@ -157,9 +157,22 @@ pub fn project_plan_to_schema(
157157
.fields()
158158
.iter()
159159
.enumerate()
160-
.map(|(i, expected_field)| ProjectionExpr {
161-
expr: Arc::new(Column::new(input_schema.field(i).name(), i)),
162-
alias: expected_field.name().clone(),
160+
.map(|(i, expected_field)| {
161+
let input_field = input_schema.field(i);
162+
let column = Arc::new(Column::new(input_field.name(), i));
163+
let expr = if !input_field.is_nullable() && expected_field.is_nullable() {
164+
Arc::new(CastExpr::new_with_target_field(
165+
column,
166+
Arc::clone(expected_field),
167+
None,
168+
)) as _
169+
} else {
170+
column as _
171+
};
172+
ProjectionExpr {
173+
expr,
174+
alias: expected_field.name().clone(),
175+
}
163176
})
164177
.collect::<Vec<_>>();
165178

@@ -484,7 +497,22 @@ mod tests {
484497
}
485498

486499
#[test]
487-
fn project_plan_to_schema_errors_on_nullability_mismatch() {
500+
fn project_plan_to_schema_widens_nullability() -> Result<()> {
501+
let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]);
502+
let expected_schema = Arc::new(Schema::new(vec![Field::new(
503+
"renamed",
504+
DataType::Int32,
505+
true,
506+
)]));
507+
508+
let result = project_plan_to_schema(input, &expected_schema)?;
509+
510+
assert_eq!(result.schema(), expected_schema);
511+
Ok(())
512+
}
513+
514+
#[test]
515+
fn project_plan_to_schema_errors_on_nullability_narrowing() {
488516
let input = empty_exec(vec![Field::new("a", DataType::Int32, true)]);
489517
let expected_schema = Arc::new(Schema::new(vec![Field::new(
490518
"renamed",

datafusion/physical-plan/src/recursive_query.rs

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ use crate::{
3535
};
3636
use arrow::array::{BooleanArray, BooleanBuilder};
3737
use arrow::compute::filter_record_batch;
38-
use arrow::datatypes::SchemaRef;
38+
use arrow::datatypes::{Field, Schema, SchemaRef};
3939
use arrow::record_batch::RecordBatch;
4040
use datafusion_common::tree_node::TreeNodeRecursion;
4141
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
@@ -91,10 +91,12 @@ impl RecursiveQueryExec {
9191
// Each recursive query needs its own work table
9292
let work_table = Arc::new(WorkTable::new(name.clone()));
9393
// Use the same work table for both the WorkTableExec and the recursive term
94+
let output_schema =
95+
recursive_output_schema(&static_term.schema(), &recursive_term.schema());
96+
let static_term = project_plan_to_schema(static_term, &output_schema)?;
9497
let recursive_term = assign_work_table(recursive_term, &work_table)?;
95-
let recursive_term =
96-
project_plan_to_schema(recursive_term, &static_term.schema())?;
97-
let cache = Self::compute_properties(static_term.schema());
98+
let recursive_term = project_plan_to_schema(recursive_term, &output_schema)?;
99+
let cache = Self::compute_properties(output_schema);
98100
Ok(RecursiveQueryExec {
99101
name,
100102
static_term,
@@ -368,6 +370,30 @@ impl RecursiveQueryStream {
368370
}
369371
}
370372

373+
fn recursive_output_schema(
374+
static_schema: &SchemaRef,
375+
recursive_schema: &SchemaRef,
376+
) -> SchemaRef {
377+
let fields = static_schema
378+
.fields()
379+
.iter()
380+
.zip(recursive_schema.fields())
381+
.map(|(static_field, recursive_field)| {
382+
Field::new(
383+
static_field.name(),
384+
static_field.data_type().clone(),
385+
static_field.is_nullable() || recursive_field.is_nullable(),
386+
)
387+
.with_metadata(static_field.metadata().clone())
388+
})
389+
.collect::<Vec<_>>();
390+
391+
Arc::new(Schema::new_with_metadata(
392+
fields,
393+
static_schema.metadata().clone(),
394+
))
395+
}
396+
371397
fn assign_work_table(
372398
plan: Arc<dyn ExecutionPlan>,
373399
work_table: &Arc<WorkTable>,
@@ -528,19 +554,21 @@ mod tests {
528554
}
529555

530556
#[test]
531-
fn recursive_query_exec_rejects_nullability_mismatch() {
557+
fn recursive_query_exec_reconciles_nullability() -> Result<()> {
532558
let static_term = empty_exec(vec![Field::new("value", DataType::Int32, false)]);
533559
let recursive_term =
534560
empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, true)]);
535561

536-
let err = RecursiveQueryExec::try_new(
562+
let exec = RecursiveQueryExec::try_new(
537563
"numbers".to_string(),
538564
static_term,
539565
recursive_term,
540566
false,
541-
)
542-
.unwrap_err();
567+
)?;
543568

544-
assert!(err.to_string().contains("field nullability differs"));
569+
assert!(exec.schema().field(0).is_nullable());
570+
assert!(exec.static_term().schema().field(0).is_nullable());
571+
assert!(exec.recursive_term().schema().field(0).is_nullable());
572+
Ok(())
545573
}
546574
}

0 commit comments

Comments
 (0)