Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 6 additions & 81 deletions datafusion/optimizer/src/optimize_projections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,29 +381,12 @@ fn optimize_projections(
// These operators have no inputs, so stop the optimization process.
return Ok(Transformed::no(plan));
}
LogicalPlan::RecursiveQuery(recursive) => {
// Only allow subqueries that reference the current CTE; nested subqueries are not yet
// supported for projection pushdown for simplicity.
// TODO: be able to do projection pushdown on recursive CTEs with subqueries
if plan_contains_other_subqueries(
recursive.static_term.as_ref(),
&recursive.name,
) || plan_contains_other_subqueries(
recursive.recursive_term.as_ref(),
&recursive.name,
) {
return Ok(Transformed::no(plan));
}

plan.inputs()
.into_iter()
.map(|input| {
indices
.clone()
.with_projection_beneficial()
.with_plan_exprs(&plan, input.schema())
})
.collect::<Result<Vec<_>>>()?
LogicalPlan::RecursiveQuery(_) => {
// optimize the static and recursive terms
return plan.map_children(|c| {
let indices = RequiredIndices::new_for_all_exprs(&c);
optimize_projections(c, config, indices)
});
}
LogicalPlan::Join(join) => {
let left_len = join.left.schema().fields().len();
Expand Down Expand Up @@ -900,64 +883,6 @@ pub fn is_projection_unnecessary(
))
}

/// Reports whether `plan` (or anything beneath it) contains a subquery other
/// than a reference to the recursive CTE named `cte_name`.
///
/// A node is considered blocking when either:
/// * it is a [`LogicalPlan::SubqueryAlias`] whose alias is not `cte_name` and
///   whose input does not resolve to the CTE (see
///   `subquery_alias_targets_recursive_cte`), or
/// * one of its expressions contains a scalar, `EXISTS`, or `IN` subquery.
///
/// If the node itself is not blocking, its inputs are checked recursively.
fn plan_contains_other_subqueries(plan: &LogicalPlan, cte_name: &str) -> bool {
    // Relation-level check: an alias that is neither named after the CTE nor
    // a wrapper around it is treated as a foreign subquery.
    let is_foreign_alias = match plan {
        LogicalPlan::SubqueryAlias(alias) => {
            alias.alias.table() != cte_name
                && !subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name)
        }
        _ => false,
    };
    if is_foreign_alias {
        return true;
    }

    // Expression-level check: stop visiting as soon as one subquery expression
    // is seen on this node.
    let mut has_subquery_expr = false;
    plan.apply_expressions(|expr| {
        if !expr_contains_subquery(expr) {
            return Ok(TreeNodeRecursion::Continue);
        }
        has_subquery_expr = true;
        Ok(TreeNodeRecursion::Stop)
    })
    .expect("expression traversal never fails");

    // Otherwise descend into the children.
    has_subquery_expr
        || plan
            .inputs()
            .into_iter()
            .any(|child| plan_contains_other_subqueries(child, cte_name))
}

/// Returns `true` when `expr`, or any expression nested inside it, is a
/// scalar subquery, an `EXISTS` subquery, or an `IN` subquery.
fn expr_contains_subquery(expr: &Expr) -> bool {
    expr.exists(|node| {
        Ok(matches!(
            node,
            Expr::ScalarSubquery(_) | Expr::Exists(_) | Expr::InSubquery(_)
        ))
    })
    // The predicate above only ever returns `Ok`, so this unwrap cannot fail.
    .unwrap()
}

/// Follows the chain of single-input nodes starting at `plan` and reports
/// whether it ends in a table scan whose table name equals `cte_name`.
///
/// Returns `false` as soon as a node with zero or several inputs is reached
/// before any table scan.
fn subquery_alias_targets_recursive_cte(plan: &LogicalPlan, cte_name: &str) -> bool {
    let mut current = plan;
    loop {
        current = match current {
            // Reached a scan: it either is the CTE work table or it is not.
            LogicalPlan::TableScan(scan) => return scan.table_name.table() == cte_name,
            // Nested aliases are transparent; keep looking underneath.
            LogicalPlan::SubqueryAlias(alias) => alias.input.as_ref(),
            // Any other node is followed only when it has exactly one input.
            other => {
                let children = other.inputs();
                match children.as_slice() {
                    [only_child] => *only_child,
                    _ => return false,
                }
            }
        };
    }
}

#[cfg(test)]
mod tests {
use std::cmp::Ordering;
Expand Down
83 changes: 26 additions & 57 deletions datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ fn init() {

#[test]
fn recursive_cte_with_nested_subquery() -> Result<()> {
// Covers bailout path in `plan_contains_other_subqueries`, ensuring nested subqueries
// within recursive CTE branches prevent projection pushdown.
// projection optimization is applied to recursive CTEs even with nested subqueries
let sql = r#"
WITH RECURSIVE numbers(id, level) AS (
SELECT sub.id, sub.level FROM (
Expand All @@ -79,17 +78,16 @@ fn recursive_cte_with_nested_subquery() -> Result<()> {
SubqueryAlias: numbers
Projection: sub.id AS id, sub.level AS level
RecursiveQuery: is_distinct=false
Projection: sub.id, sub.level
SubqueryAlias: sub
Projection: test.col_int32 AS id, Int64(1) AS level
TableScan: test
SubqueryAlias: sub
Projection: test.col_int32 AS id, Int64(1) AS level
TableScan: test projection=[col_int32]
Projection: t.col_int32, numbers.level + Int64(1)
Inner Join: CAST(t.col_int32 AS Int64) = CAST(numbers.id AS Int64) + Int64(1)
SubqueryAlias: t
Filter: CAST(test.col_int32 AS Int64) IS NOT NULL
TableScan: test
TableScan: test projection=[col_int32]
Filter: CAST(numbers.id AS Int64) + Int64(1) IS NOT NULL
TableScan: numbers
TableScan: numbers projection=[id, level]
"
);

Expand Down Expand Up @@ -527,31 +525,30 @@ fn select_correlated_predicate_subquery_with_uppercase_ident() {
"
);
}

#[test]
fn recursive_cte_projection_pushdown() -> Result<()> {
// Test that projection pushdown works with recursive CTEs by ensuring
// only the required columns are projected from the base table, even when
// the CTE definition includes unused columns
fn recursive_cte_outer_projection_pushdown() -> Result<()> {
// projection optimization of a recursive CTE based on the outer query's projected columns is
// not done as this can lead to bugs (see: https://github.com/apache/datafusion/issues/22249).
let sql = "WITH RECURSIVE nodes AS (\
SELECT col_int32 AS id, col_utf8 AS name, col_uint32 AS extra FROM test \
UNION ALL \
SELECT id + 1, name, extra FROM nodes WHERE id < 3\
) SELECT id FROM nodes";
let plan = test_sql(sql)?;

// The optimizer successfully performs projection pushdown by only selecting the needed
// columns from the base table and recursive table, eliminating unused columns
// col_int32, col_utf8, and col_uint32 are all projected from test since they are used in the
// recursive CTE, even though the outer query only requires col_int32
assert_snapshot!(
format!("{plan}"),
@r"
SubqueryAlias: nodes
RecursiveQuery: is_distinct=false
Projection: test.col_int32 AS id
TableScan: test projection=[col_int32]
Projection: CAST(CAST(nodes.id AS Int64) + Int64(1) AS Int32)
Filter: nodes.id < Int32(3)
TableScan: nodes projection=[id]
Projection: id
RecursiveQuery: is_distinct=false
Projection: test.col_int32 AS id, test.col_utf8 AS name, test.col_uint32 AS extra
TableScan: test projection=[col_int32, col_uint32, col_utf8]
Projection: CAST(CAST(nodes.id AS Int64) + Int64(1) AS Int32), nodes.name, nodes.extra
Filter: nodes.id < Int32(3)
TableScan: nodes projection=[id, name, extra]
"
);
Ok(())
Expand All @@ -570,47 +567,19 @@ fn recursive_cte_with_aliased_self_reference() -> Result<()> {
format!("{plan}"),
@r"
SubqueryAlias: nodes
RecursiveQuery: is_distinct=false
Projection: test.col_int32 AS id
TableScan: test projection=[col_int32]
Projection: CAST(CAST(child.id AS Int64) + Int64(1) AS Int32)
SubqueryAlias: child
Filter: nodes.id < Int32(3)
TableScan: nodes projection=[id]
Projection: id
RecursiveQuery: is_distinct=false
Projection: test.col_int32 AS id, test.col_utf8 AS name
TableScan: test projection=[col_int32, col_utf8]
Projection: CAST(CAST(child.id AS Int64) + Int64(1) AS Int32), child.name
SubqueryAlias: child
Filter: nodes.id < Int32(3)
TableScan: nodes projection=[id, name]
",
);
Ok(())
}

#[test]
fn recursive_cte_with_unused_columns() -> Result<()> {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test was very similar to the test now named `recursive_cte_outer_projection_pushdown`.

// Test projection pushdown with a recursive CTE where the base case
// includes columns that are never used in the recursive part or final result
let sql = "WITH RECURSIVE series AS (\
SELECT 1 AS n, col_utf8, col_uint32, col_date32 FROM test WHERE col_int32 = 1 \
UNION ALL \
SELECT n + 1, col_utf8, col_uint32, col_date32 FROM series WHERE n < 3\
) SELECT n FROM series";
let plan = test_sql(sql)?;

// The optimizer successfully performs projection pushdown by eliminating unused columns
// even when they're defined in the CTE but not actually needed
assert_snapshot!(
format!("{plan}"),
@r"
SubqueryAlias: series
RecursiveQuery: is_distinct=false
Projection: Int64(1) AS n
Filter: test.col_int32 = Int32(1)
TableScan: test projection=[col_int32]
Projection: series.n + Int64(1)
Filter: series.n < Int64(3)
TableScan: series projection=[n]
"
);
Ok(())
}

#[test]
/// Asserts the minimal plan shape once projection pushdown succeeds for a recursive CTE.
/// Unlike the previous two tests that retain extra columns in either the base or recursive
Expand Down
Loading
Loading