From 25c357e4656e49965741b1865d55f906ccf005f5 Mon Sep 17 00:00:00 2001 From: nuno-faria Date: Sat, 23 May 2026 17:43:48 +0100 Subject: [PATCH 1/3] fix: Optimize projections in recursive CTEs --- .../optimizer/src/optimize_projections/mod.rs | 87 +---------- .../optimizer/tests/optimizer_integration.rs | 83 ++++------ datafusion/sqllogictest/test_files/cte.slt | 142 +++++++++++++++--- 3 files changed, 157 insertions(+), 155 deletions(-) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 59109a822bdbe..f0aa002ac9bde 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -381,29 +381,12 @@ fn optimize_projections( // These operators have no inputs, so stop the optimization process. return Ok(Transformed::no(plan)); } - LogicalPlan::RecursiveQuery(recursive) => { - // Only allow subqueries that reference the current CTE; nested subqueries are not yet - // supported for projection pushdown for simplicity. - // TODO: be able to do projection pushdown on recursive CTEs with subqueries - if plan_contains_other_subqueries( - recursive.static_term.as_ref(), - &recursive.name, - ) || plan_contains_other_subqueries( - recursive.recursive_term.as_ref(), - &recursive.name, - ) { - return Ok(Transformed::no(plan)); - } - - plan.inputs() - .into_iter() - .map(|input| { - indices - .clone() - .with_projection_beneficial() - .with_plan_exprs(&plan, input.schema()) - }) - .collect::>>()? 
+ LogicalPlan::RecursiveQuery(_) => { + // optimize the static and recursive terms + return plan.map_children(|c| { + let indices = RequiredIndices::new_for_all_exprs(&c); + optimize_projections(c, config, indices) + }); } LogicalPlan::Join(join) => { let left_len = join.left.schema().fields().len(); @@ -900,64 +883,6 @@ pub fn is_projection_unnecessary( )) } -/// Returns true if the plan subtree contains any subqueries that are not the -/// CTE reference itself. This treats any non-CTE [`LogicalPlan::SubqueryAlias`] -/// node (including aliased relations) as a blocker, along with expression-level -/// subqueries like scalar, EXISTS, or IN. These cases prevent projection -/// pushdown for now because we cannot safely reason about their column usage. -fn plan_contains_other_subqueries(plan: &LogicalPlan, cte_name: &str) -> bool { - if let LogicalPlan::SubqueryAlias(alias) = plan - && alias.alias.table() != cte_name - && !subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name) - { - return true; - } - - let mut found = false; - plan.apply_expressions(|expr| { - if expr_contains_subquery(expr) { - found = true; - Ok(TreeNodeRecursion::Stop) - } else { - Ok(TreeNodeRecursion::Continue) - } - }) - .expect("expression traversal never fails"); - if found { - return true; - } - - plan.inputs() - .into_iter() - .any(|child| plan_contains_other_subqueries(child, cte_name)) -} - -fn expr_contains_subquery(expr: &Expr) -> bool { - expr.exists(|e| match e { - Expr::ScalarSubquery(_) | Expr::Exists(_) | Expr::InSubquery(_) => Ok(true), - _ => Ok(false), - }) - // Safe unwrap since we are doing a simple boolean check - .unwrap() -} - -fn subquery_alias_targets_recursive_cte(plan: &LogicalPlan, cte_name: &str) -> bool { - match plan { - LogicalPlan::TableScan(scan) => scan.table_name.table() == cte_name, - LogicalPlan::SubqueryAlias(alias) => { - subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name) - } - _ => { - let inputs = plan.inputs(); - if 
inputs.len() == 1 { - subquery_alias_targets_recursive_cte(inputs[0], cte_name) - } else { - false - } - } - } -} - #[cfg(test)] mod tests { use std::cmp::Ordering; diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 4e33bf6b3abcc..84c0166364b98 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -56,8 +56,7 @@ fn init() { #[test] fn recursive_cte_with_nested_subquery() -> Result<()> { - // Covers bailout path in `plan_contains_other_subqueries`, ensuring nested subqueries - // within recursive CTE branches prevent projection pushdown. + // projection optimization is applied to recursive CTEs even with nested subqueries let sql = r#" WITH RECURSIVE numbers(id, level) AS ( SELECT sub.id, sub.level FROM ( @@ -79,17 +78,16 @@ fn recursive_cte_with_nested_subquery() -> Result<()> { SubqueryAlias: numbers Projection: sub.id AS id, sub.level AS level RecursiveQuery: is_distinct=false - Projection: sub.id, sub.level - SubqueryAlias: sub - Projection: test.col_int32 AS id, Int64(1) AS level - TableScan: test + SubqueryAlias: sub + Projection: test.col_int32 AS id, Int64(1) AS level + TableScan: test projection=[col_int32] Projection: t.col_int32, numbers.level + Int64(1) Inner Join: CAST(t.col_int32 AS Int64) = CAST(numbers.id AS Int64) + Int64(1) SubqueryAlias: t Filter: CAST(test.col_int32 AS Int64) IS NOT NULL - TableScan: test + TableScan: test projection=[col_int32] Filter: CAST(numbers.id AS Int64) + Int64(1) IS NOT NULL - TableScan: numbers + TableScan: numbers projection=[id, level] " ); @@ -527,12 +525,10 @@ fn select_correlated_predicate_subquery_with_uppercase_ident() { " ); } - #[test] -fn recursive_cte_projection_pushdown() -> Result<()> { - // Test that projection pushdown works with recursive CTEs by ensuring - // only the required columns are projected from the base table, even when - // the CTE definition 
includes unused columns +fn recursive_cte_outer_projection_pushdown() -> Result<()> { + // projection optimization of a recursive CTE based on the outer query's projected columns is + // not done as this can lead to bugs (see: https://github.com/apache/datafusion/issues/22249). let sql = "WITH RECURSIVE nodes AS (\ SELECT col_int32 AS id, col_utf8 AS name, col_uint32 AS extra FROM test \ UNION ALL \ @@ -540,18 +536,19 @@ fn recursive_cte_projection_pushdown() -> Result<()> { ) SELECT id FROM nodes"; let plan = test_sql(sql)?; - // The optimizer successfully performs projection pushdown by only selecting the needed - // columns from the base table and recursive table, eliminating unused columns + // col_int32, col_utf8, and col_uint32 are projected from test since they are used in the + // recursive CTE, even though the outer query only requires col_int32 assert_snapshot!( format!("{plan}"), @r" SubqueryAlias: nodes - RecursiveQuery: is_distinct=false - Projection: test.col_int32 AS id - TableScan: test projection=[col_int32] - Projection: CAST(CAST(nodes.id AS Int64) + Int64(1) AS Int32) - Filter: nodes.id < Int32(3) - TableScan: nodes projection=[id] + Projection: id + RecursiveQuery: is_distinct=false + Projection: test.col_int32 AS id, test.col_utf8 AS name, test.col_uint32 AS extra + TableScan: test projection=[col_int32, col_uint32, col_utf8] + Projection: CAST(CAST(nodes.id AS Int64) + Int64(1) AS Int32), nodes.name, nodes.extra + Filter: nodes.id < Int32(3) + TableScan: nodes projection=[id, name, extra] " ); Ok(()) @@ -570,47 +567,19 @@ fn recursive_cte_with_aliased_self_reference() -> Result<()> { format!("{plan}"), @r" SubqueryAlias: nodes - RecursiveQuery: is_distinct=false - Projection: test.col_int32 AS id - TableScan: test projection=[col_int32] - Projection: CAST(CAST(child.id AS Int64) + Int64(1) AS Int32) - SubqueryAlias: child - Filter: nodes.id < Int32(3) - TableScan: nodes projection=[id] + Projection: id + RecursiveQuery: is_distinct=false + 
Projection: test.col_int32 AS id, test.col_utf8 AS name + TableScan: test projection=[col_int32, col_utf8] + Projection: CAST(CAST(child.id AS Int64) + Int64(1) AS Int32), child.name + SubqueryAlias: child + Filter: nodes.id < Int32(3) + TableScan: nodes projection=[id, name] ", ); Ok(()) } -#[test] -fn recursive_cte_with_unused_columns() -> Result<()> { - // Test projection pushdown with a recursive CTE where the base case - // includes columns that are never used in the recursive part or final result - let sql = "WITH RECURSIVE series AS (\ - SELECT 1 AS n, col_utf8, col_uint32, col_date32 FROM test WHERE col_int32 = 1 \ - UNION ALL \ - SELECT n + 1, col_utf8, col_uint32, col_date32 FROM series WHERE n < 3\ - ) SELECT n FROM series"; - let plan = test_sql(sql)?; - - // The optimizer successfully performs projection pushdown by eliminating unused columns - // even when they're defined in the CTE but not actually needed - assert_snapshot!( - format!("{plan}"), - @r" - SubqueryAlias: series - RecursiveQuery: is_distinct=false - Projection: Int64(1) AS n - Filter: test.col_int32 = Int32(1) - TableScan: test projection=[col_int32] - Projection: series.n + Int64(1) - Filter: series.n < Int64(3) - TableScan: series projection=[n] - " - ); - Ok(()) -} - #[test] /// Asserts the minimal plan shape once projection pushdown succeeds for a recursive CTE. 
/// Unlike the previous two tests that retain extra columns in either the base or recursive diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index d13e0d4f085e9..3245947a15217 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -842,12 +842,12 @@ logical_plan 03)----Projection: Int64(1) AS val 04)------EmptyRelation: rows=1 05)----Projection: Int64(2) AS val -06)------Cross Join: -07)--------Filter: recursive_cte.val < Int64(2) -08)----------TableScan: recursive_cte -09)--------SubqueryAlias: sub_cte -10)----------Projection: Int64(2) AS val -11)------------EmptyRelation: rows=1 +06)------Cross Join: +07)--------Projection: +08)----------Filter: recursive_cte.val < Int64(2) +09)------------TableScan: recursive_cte projection=[val] +10)--------SubqueryAlias: sub_cte +11)----------EmptyRelation: rows=1 physical_plan 01)RecursiveQueryExec: name=recursive_cte, is_distinct=false 02)--ProjectionExec: expr=[1 as val] @@ -855,11 +855,10 @@ physical_plan 04)--ProjectionExec: expr=[2 as val] 05)----CrossJoinExec 06)------CoalescePartitionsExec -07)--------FilterExec: val@0 < 2 +07)--------FilterExec: val@0 < 2, projection=[] 08)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)------------WorkTableExec: name=recursive_cte -10)------ProjectionExec: expr=[2 as val] -11)--------PlaceholderRowExec +10)------PlaceholderRowExec # Test issue: https://github.com/apache/datafusion/issues/9794 # Non-recursive term and recursive term have different types @@ -1205,14 +1204,13 @@ EXPLAIN WITH RECURSIVE trans AS ( logical_plan 01)SubqueryAlias: trans 02)--RecursiveQuery: is_distinct=true -03)----Projection: closure.start, closure.end -04)------TableScan: closure -05)----Projection: l.start, r.end -06)------Inner Join: l.end = r.start -07)--------SubqueryAlias: l -08)----------TableScan: trans -09)--------SubqueryAlias: r -10)----------TableScan: 
closure +03)----TableScan: closure projection=[start, end] +04)----Projection: l.start, r.end +05)------Inner Join: l.end = r.start +06)--------SubqueryAlias: l +07)----------TableScan: trans projection=[start, end] +08)--------SubqueryAlias: r +09)----------TableScan: closure projection=[start, end] physical_plan 01)RecursiveQueryExec: name=trans, is_distinct=true 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/recursive_cte/closure.csv]]}, projection=[start, end], file_type=csv, has_header=true @@ -1319,3 +1317,113 @@ RESET datafusion.execution.enable_recursive_ctes; statement ok RESET datafusion.sql_parser.enable_ident_normalization; + + +# Test projection optimization in recursive CTEs + +# https://github.com/apache/datafusion/issues/22249 +query I +with recursive t(k, v) as ( + select 1 k, 10 v + union all + select 2, 20 from t where k = 1 +) +select v +from t +order by 1; +---- +10 +20 + +# https://github.com/apache/datafusion/issues/22249 +query I +with recursive t(k, v) as ( + select 1 k, 10 v + union all + select 2, 20 from t where v = 10 +) +select v +from t +order by 1; +---- +10 +20 + +statement ok +copy ( + select i as k, i as v1, i as v2 + from generate_series(1, 3) t(i) +) to 'test_files/scratch/cte/test.parquet'; + +statement ok +create external table test stored as parquet location 'test_files/scratch/cte/test.parquet'; + +# check that both the static and recursive terms are optimized +query TT +explain +with recursive r as ( + select k, v1 -- only needs to project k and v1 from table test + from test + union all + select k * 10, v1 + from r + where k < ( -- only needs to project k and v2 from table test + select v2 + from test + where k = 2 + ) +) +select * +from r +order by 1, 2; +---- +logical_plan +01)Sort: r.k ASC NULLS LAST, r.v1 ASC NULLS LAST +02)--SubqueryAlias: r +03)----RecursiveQuery: is_distinct=false +04)------TableScan: test projection=[k, v1] +05)------Projection: r.k * Int64(10), r.v1 
+06)--------Filter: r.k < () +07)----------Subquery: +08)------------Projection: test.v2 +09)--------------Filter: test.k = Int64(2) +10)----------------TableScan: test projection=[k, v2], partial_filters=[test.k = Int64(2)] +11)----------TableScan: r projection=[k, v1] +physical_plan +01)ScalarSubqueryExec: subqueries=1 +02)--SortExec: expr=[k@0 ASC NULLS LAST, v1@1 ASC NULLS LAST], preserve_partitioning=[false] +03)----RecursiveQueryExec: name=r, is_distinct=false +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/cte/test.parquet]]}, projection=[k, v1], output_ordering=[k@0 ASC NULLS LAST], file_type=parquet +05)------CoalescePartitionsExec +06)--------ProjectionExec: expr=[k@0 * 10 as k, v1@1 as v1] +07)----------FilterExec: k@0 < scalar_subquery() +08)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)--------------WorkTableExec: name=r +10)--FilterExec: k@0 = 2, projection=[v2@1] +11)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +12)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/cte/test.parquet]]}, projection=[k, v2], output_ordering=[k@0 ASC NULLS LAST], file_type=parquet, predicate=k@0 = 2, pruning_predicate=k_null_count@2 != row_count@3 AND k_min@0 <= 2 AND 2 <= k_max@1, required_guarantees=[k in (2)] + +query II +with recursive r as ( + select k, v1 + from test + union all + select k * 10, v1 + from r + where k < ( + select v2 + from test + where k = 2 + ) +) +select * +from r +order by 1, 2; +---- +1 1 +2 2 +3 3 +10 1 + +statement ok +drop table test; From 60ee5c5b7b1687b8069d10a7706b20f193b161f1 Mon Sep 17 00:00:00 2001 From: nuno-faria Date: Sat, 23 May 2026 18:30:01 +0100 Subject: [PATCH 2/3] Fix clippy --- datafusion/optimizer/src/optimize_projections/mod.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs 
b/datafusion/optimizer/src/optimize_projections/mod.rs index f0aa002ac9bde..7e492fc1e4c94 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -35,9 +35,7 @@ use datafusion_expr::{ use crate::optimize_projections::required_indices::RequiredIndices; use crate::utils::NamePreserver; -use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, -}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeContainer}; /// Optimizer rule to prune unnecessary columns from intermediate schemas /// inside the [`LogicalPlan`]. This rule: From 7007847642fd05f4b83426927224642182d70bec Mon Sep 17 00:00:00 2001 From: nuno-faria Date: Mon, 25 May 2026 08:34:29 +0100 Subject: [PATCH 3/3] Add union test Co-authored-by: Bruce Ritchie --- datafusion/sqllogictest/test_files/cte.slt | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index 3245947a15217..7631f6f4d5503 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -1349,6 +1349,21 @@ order by 1; 10 20 +# Keep columns that are not selected by the outer query, but still affect +# recursive UNION distinctness. +query I +with recursive t(k, v) as ( + select 1 k, 10 v + union + select 2, 10 from t where v = 10 +) +select v +from t +order by 1; +---- +10 +10 + statement ok copy ( select i as k, i as v1, i as v2