From 4ee03d6193ce0edc1eb73952d7af9ef2f3975b33 Mon Sep 17 00:00:00 2001 From: Christian McArthur Date: Sat, 4 Apr 2026 13:20:00 -0400 Subject: [PATCH 1/5] feat: support InSubquery and Exists in Projection expressions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, `DecorrelatePredicateSubquery` only handled InSubquery and Exists nodes in Filter (WHERE/HAVING). When they appeared in Projection expressions (CASE WHEN, COALESCE, bare boolean SELECT), the physical planner errored with "Physical plan does not support logical expression InSubquery(...)". This adds a Projection match arm following the same pattern as ScalarSubqueryToJoin's Projection handler. InSubquery and Exists in Projection expressions are decorrelated into LeftMark joins, with the mark column aliased to preserve the original schema. When decorrelation fails (e.g., correlated subquery with LIMIT), the plan is returned unchanged — matching ScalarSubqueryToJoin's bail-out semantics. Patterns now supported: - CASE WHEN id IN (SELECT ...) THEN ... ELSE ... END - CASE WHEN EXISTS (SELECT ...) THEN ... END - NOT IN / NOT EXISTS in CASE - id IN (SELECT ...) as bare boolean in SELECT - Correlated IN/EXISTS in expressions - Multiple subqueries in one expression - COALESCE with IN subquery Includes sqllogictest coverage (8 test cases) and plan-shape unit test for the bail-out path. --- .../src/decorrelate_predicate_subquery.rs | 387 ++++++++++++++++-- .../test_files/in_subquery_projection.slt | 148 +++++++ 2 files changed, 495 insertions(+), 40 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/in_subquery_projection.slt diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index a4c5d8c38549d..83c6477c49c61 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -35,6 +35,7 @@ use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::{conjunction, expr_to_columns, split_conjunction_owned}; +use datafusion_expr::logical_plan::Projection; use datafusion_expr::{ BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, exists, in_subquery, lit, not, not_exists, not_in_subquery, @@ -69,53 +70,113 @@ impl OptimizerRule for DecorrelatePredicateSubquery { })? .data; - let LogicalPlan::Filter(filter) = plan else { - return Ok(Transformed::no(plan)); - }; - - if !has_subquery(&filter.predicate) { - return Ok(Transformed::no(LogicalPlan::Filter(filter))); - } + match plan { + LogicalPlan::Filter(filter) => { + if !has_subquery(&filter.predicate) { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } - let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = - split_conjunction_owned(filter.predicate) - .into_iter() - .partition(has_subquery); + let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = + split_conjunction_owned(filter.predicate) + .into_iter() + .partition(has_subquery); + + assert_or_internal_err!( + !with_subqueries.is_empty(), + "can not find expected subqueries in DecorrelatePredicateSubquery" + ); + + // iterate through all exists clauses in predicate, turning each into a join + let mut cur_input = Arc::unwrap_or_clone(filter.input); + for subquery_expr in with_subqueries { + match extract_subquery_info(subquery_expr) { + // The subquery expression is at the top level of the filter + SubqueryPredicate::Top(subquery) => { + match build_join_top( + &subquery, + &cur_input, + config.alias_generator(), + )? { + Some(plan) => cur_input = plan, + // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter + None => other_exprs.push(subquery.expr()), + } + } + // The subquery expression is embedded within another expression + SubqueryPredicate::Embedded(expr) => { + let (plan, expr_without_subqueries) = + rewrite_inner_subqueries(cur_input, expr, config)?; + cur_input = plan; + other_exprs.push(expr_without_subqueries); + } + } + } - assert_or_internal_err!( - !with_subqueries.is_empty(), - "can not find expected subqueries in DecorrelatePredicateSubquery" - ); + let expr = conjunction(other_exprs); + if let Some(expr) = expr { + let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; + cur_input = LogicalPlan::Filter(new_filter); + } + Ok(Transformed::yes(cur_input)) + } + LogicalPlan::Projection(projection) => { + // Skip if no predicate subqueries in any projection expression + if !projection.expr.iter().any(has_subquery) { + return Ok(Transformed::no(LogicalPlan::Projection(projection))); + } - // iterate through all exists clauses in predicate, turning each into a join - let mut cur_input = Arc::unwrap_or_clone(filter.input); - for subquery_expr in with_subqueries { - match extract_subquery_info(subquery_expr) { - // The subquery expression is at the top level of the filter - SubqueryPredicate::Top(subquery) => { - match build_join_top(&subquery, &cur_input, config.alias_generator())? - { - Some(plan) => cur_input = plan, - // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter - None => other_exprs.push(subquery.expr()), + // Keep an Arc clone of the original input so we can reconstruct + // the Projection if decorrelation fails for any expression. + let original_input = Arc::clone(&projection.input); + let mut cur_input = Arc::unwrap_or_clone(projection.input); + let mut new_exprs = Vec::with_capacity(projection.expr.len()); + + for expr in &projection.expr { + if has_subquery(expr) { + let (plan, rewritten) = + rewrite_inner_subqueries(cur_input, expr.clone(), config)?; + cur_input = plan; + new_exprs.push(rewritten); + } else { + new_exprs.push(expr.clone()); } } - // The subquery expression is embedded within another expression - SubqueryPredicate::Embedded(expr) => { - let (plan, expr_without_subqueries) = - rewrite_inner_subqueries(cur_input, expr, config)?; - cur_input = plan; - other_exprs.push(expr_without_subqueries); + + // If any expression still contains a subquery after rewriting, + // decorrelation failed — bail out and return the original plan + // unchanged (same pattern as ScalarSubqueryToJoin). + if new_exprs.iter().any(has_subquery) { + let original = Projection::try_new_with_schema( + projection.expr, + original_input, + projection.schema, + )?; + return Ok(Transformed::no(LogicalPlan::Projection(original))); } - } - } - let expr = conjunction(other_exprs); - if let Some(expr) = expr { - let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; - cur_input = LogicalPlan::Filter(new_filter); + // Preserve original column names via aliases where the rewrite changed them + let proj_exprs: Vec = projection + .expr + .iter() + .zip(new_exprs) + .map(|(old, new)| { + let old_name = old.schema_name().to_string(); + let new_name = new.schema_name().to_string(); + if old_name != new_name { + new.alias(old_name) + } else { + new + } + }) + .collect(); + + let new_plan = LogicalPlanBuilder::from(cur_input) + .project(proj_exprs)? + .build()?; + Ok(Transformed::yes(new_plan)) + } + plan => Ok(Transformed::no(plan)), } - Ok(Transformed::yes(cur_input)) } fn name(&self) -> &str { @@ -538,7 +599,9 @@ mod tests { use crate::assert_optimized_plan_eq_display_indent_snapshot; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::builder::table_source; - use datafusion_expr::{and, binary_expr, col, out_ref_col, table_scan}; + use datafusion_expr::{ + and, binary_expr, col, not_in_subquery, out_ref_col, table_scan, when, + }; macro_rules! assert_optimized_plan_equal { ( @@ -2114,4 +2177,248 @@ mod tests { " ) } + + // ----------------------------------------------------------------------- + // Tests for InSubquery / Exists in Projection expressions + // ----------------------------------------------------------------------- + + /// IN subquery inside CASE WHEN in a projection expression + #[test] + fn in_subquery_in_case_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let sq = test_subquery_with_name("sq")?; + + let case_expr = + when(in_subquery(col("c"), sq), lit("yes")).otherwise(lit("no"))?; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), case_expr])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r#" + Projection: test.a, CASE WHEN __correlated_sq_1.mark THEN Utf8("yes") ELSE Utf8("no") END AS CASE WHEN IN THEN Utf8("yes") ELSE Utf8("no") END [a:UInt32, CASE WHEN IN THEN Utf8("yes") ELSE Utf8("no") END:Utf8] + LeftMark Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + Projection: __correlated_sq_1.c [c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + "# + ) + } + + /// EXISTS subquery inside CASE WHEN in a projection expression + #[test] + fn exists_in_case_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter( + col("orders.o_custkey") + .eq(out_ref_col(DataType::UInt32, "test.a")), + )? + .project(vec![lit(1)])? + .build()?, + ); + + let case_expr = + when(exists(sq), lit("has_orders")).otherwise(lit("no_orders"))?; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), case_expr])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r#" + Projection: test.a, CASE WHEN __correlated_sq_1.mark THEN Utf8("has_orders") ELSE Utf8("no_orders") END AS CASE WHEN EXISTS THEN Utf8("has_orders") ELSE Utf8("no_orders") END [a:UInt32, CASE WHEN EXISTS THEN Utf8("has_orders") ELSE Utf8("no_orders") END:Utf8] + LeftMark Join: Filter: __correlated_sq_1.o_custkey = test.a [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + Projection: __correlated_sq_1.o_custkey [o_custkey:Int64] + SubqueryAlias: __correlated_sq_1 [Int32(1):Int32, o_custkey:Int64] + Projection: Int32(1), orders.o_custkey [Int32(1):Int32, o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + "# + ) + } + + /// NOT IN subquery inside CASE WHEN in a projection expression + #[test] + fn not_in_subquery_in_case_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let sq = test_subquery_with_name("sq")?; + + let case_expr = when(not_in_subquery(col("c"), sq), lit("excluded")) + .otherwise(lit("included"))?; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), case_expr])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r#" + Projection: test.a, CASE WHEN NOT __correlated_sq_1.mark THEN Utf8("excluded") ELSE Utf8("included") END AS CASE WHEN NOT IN THEN Utf8("excluded") ELSE Utf8("included") END [a:UInt32, CASE WHEN NOT IN THEN Utf8("excluded") ELSE Utf8("included") END:Utf8] + LeftMark Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + Projection: __correlated_sq_1.c [c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + "# + ) + } + + /// IN subquery as bare boolean in SELECT (no CASE wrapper) + #[test] + fn in_subquery_bare_bool_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let sq = test_subquery_with_name("sq")?; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), in_subquery(col("c"), sq)])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, __correlated_sq_1.mark AS IN [a:UInt32, IN:Boolean] + LeftMark Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + Projection: __correlated_sq_1.c [c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) + } + + /// Correlated IN subquery inside CASE WHEN in a projection expression + #[test] + fn correlated_in_subquery_in_case_projection() -> Result<()> { + let orders = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter( + col("orders.o_custkey") + .eq(out_ref_col(DataType::Int64, "customer.c_custkey")), + )? + .project(vec![col("orders.o_custkey")])? + .build()?, + ); + + let case_expr = + when(in_subquery(col("customer.c_custkey"), orders), lit("active")) + .otherwise(lit("inactive"))?; + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .project(vec![col("customer.c_custkey"), case_expr])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r#" + Projection: customer.c_custkey, CASE WHEN __correlated_sq_1.mark THEN Utf8("active") ELSE Utf8("inactive") END AS CASE WHEN IN THEN Utf8("active") ELSE Utf8("inactive") END [c_custkey:Int64, CASE WHEN IN THEN Utf8("active") ELSE Utf8("inactive") END:Utf8] + LeftMark Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, mark:Boolean] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Projection: __correlated_sq_1.o_custkey [o_custkey:Int64] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + "# + ) + } + + /// Multiple subqueries in one projection expression + #[test] + fn multiple_subqueries_in_one_projection_expr() -> Result<()> { + let table_scan = test_table_scan()?; + let sq1 = test_subquery_with_name("sq_1")?; + let sq2 = test_subquery_with_name("sq_2")?; + + // CASE WHEN a IN (sq1) THEN 'a_match' + // WHEN b IN (sq2) THEN 'b_match' + // ELSE 'none' END + let case_expr = when(in_subquery(col("a"), sq1), lit("a_match")) + .when(in_subquery(col("b"), sq2), lit("b_match")) + .otherwise(lit("none"))?; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), case_expr])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r#" + Projection: test.a, CASE WHEN __correlated_sq_1.mark THEN Utf8("a_match") WHEN __correlated_sq_2.mark THEN Utf8("b_match") ELSE Utf8("none") END AS CASE WHEN IN THEN Utf8("a_match") WHEN IN THEN Utf8("b_match") ELSE Utf8("none") END [a:UInt32, CASE WHEN IN THEN Utf8("a_match") WHEN IN THEN Utf8("b_match") ELSE Utf8("none") END:Utf8] + LeftMark Join: Filter: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32, mark:Boolean, mark:Boolean] + LeftMark Join: Filter: test.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + Projection: __correlated_sq_1.c [c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq_1.c [c:UInt32] + TableScan: sq_1 [a:UInt32, b:UInt32, c:UInt32] + Projection: __correlated_sq_2.c [c:UInt32] + SubqueryAlias: __correlated_sq_2 [c:UInt32] + Projection: sq_2.c [c:UInt32] + TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32] + "# + ) + } + + /// Projection with no subquery is not modified + #[test] + fn projection_without_subquery_unchanged() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.b [a:UInt32, b:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) + } + + /// When a correlated IN subquery inside a projection cannot be decorrelated + /// (e.g. LIMIT in a correlated subquery), the plan is returned unchanged. + #[test] + fn projection_in_subquery_cannot_decorrelate_bails_out() -> Result<()> { + let table_scan = test_table_scan()?; + + // Build a correlated subquery with LIMIT — LIMIT prevents decorrelation + // for IN subqueries (can_pull_up becomes false). + let sq = Arc::new( + LogicalPlanBuilder::from(test_table_scan_with_name("sq")?) + .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))? + .project(vec![col("sq.c")])? + .limit(0, Some(1))? + .build()?, + ); + + let case_expr = when(in_subquery(col("c"), sq), lit("yes")) + .otherwise(lit("no"))?; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), case_expr])? + .build()?; + + // Decorrelation fails, plan should be returned unchanged + assert_optimized_plan_equal!( + plan, + @r#" + Projection: test.a, CASE WHEN test.c IN () THEN Utf8("yes") ELSE Utf8("no") END [a:UInt32, CASE WHEN IN THEN Utf8("yes") ELSE Utf8("no") END:Utf8] + Subquery: [c:UInt32] + Limit: skip=0, fetch=1 [c:UInt32] + Projection: sq.c [c:UInt32] + Filter: outer_ref(test.a) = sq.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + "# + ) + } } diff --git a/datafusion/sqllogictest/test_files/in_subquery_projection.slt b/datafusion/sqllogictest/test_files/in_subquery_projection.slt new file mode 100644 index 0000000000000..843a59012c87a --- /dev/null +++ b/datafusion/sqllogictest/test_files/in_subquery_projection.slt @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +############################################### +## Tests for IN/EXISTS subqueries in projection +## expressions (CASE, COALESCE, bare boolean) +############################################### + +# Setup +statement ok +CREATE TABLE t1(id INT, name TEXT) AS VALUES +(1, 'alice'), +(2, 'bob'), +(3, 'carol'), +(4, 'dave'); + +statement ok +CREATE TABLE ref_table(id INT, label TEXT) AS VALUES +(1, 'x'), +(2, 'y'), +(5, 'z'); + +# 1. IN subquery inside CASE WHEN +query TT rowsort +SELECT + name, + CASE WHEN id IN (SELECT id FROM ref_table) THEN 'matched' ELSE 'unmatched' END AS status +FROM t1 +---- +alice matched +bob matched +carol unmatched +dave unmatched + +# 2. NOT IN subquery inside CASE WHEN +query TT rowsort +SELECT + name, + CASE WHEN id NOT IN (SELECT id FROM ref_table) THEN 'excluded' ELSE 'included' END AS status +FROM t1 +---- +alice included +bob included +carol excluded +dave excluded + +# 3. EXISTS inside CASE WHEN (correlated) +query TT rowsort +SELECT + name, + CASE WHEN EXISTS (SELECT 1 FROM ref_table WHERE ref_table.id = t1.id) + THEN 'has_ref' ELSE 'no_ref' END AS ref_status +FROM t1 +---- +alice has_ref +bob has_ref +carol no_ref +dave no_ref + +# 4. NOT EXISTS inside CASE WHEN (correlated) +query TT rowsort +SELECT + name, + CASE WHEN NOT EXISTS (SELECT 1 FROM ref_table WHERE ref_table.id = t1.id) + THEN 'missing' ELSE 'present' END AS ref_status +FROM t1 +---- +alice present +bob present +carol missing +dave missing + +# 5. IN subquery as bare boolean in SELECT +query TB rowsort +SELECT + name, + id IN (SELECT id FROM ref_table) AS is_in_ref +FROM t1 +---- +alice true +bob true +carol false +dave false + +# 6. Correlated IN subquery inside CASE WHEN +query TT rowsort +SELECT + name, + CASE WHEN id IN (SELECT ref_table.id FROM ref_table WHERE ref_table.label = 'x') + THEN 'label_x' ELSE 'other' END AS label_status +FROM t1 +---- +alice label_x +bob other +carol other +dave other + +# 7. Multiple subqueries in one CASE expression +query TT rowsort +SELECT + name, + CASE + WHEN id IN (SELECT id FROM ref_table WHERE label = 'x') THEN 'x_match' + WHEN id IN (SELECT id FROM ref_table WHERE label = 'y') THEN 'y_match' + ELSE 'none' + END AS multi_status +FROM t1 +---- +alice x_match +bob y_match +carol none +dave none + +# 8. COALESCE with IN subquery +query TT rowsort +SELECT + name, + COALESCE( + CASE WHEN id IN (SELECT id FROM ref_table) THEN 'found' ELSE NULL END, + 'default' + ) AS coalesce_status +FROM t1 +---- +alice found +bob found +carol default +dave default + +# Cleanup +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE ref_table; From f57048ff883bdb2cbacccc096c0196a4ac89957d Mon Sep 17 00:00:00 2001 From: Christian McArthur Date: Mon, 6 Apr 2026 20:47:51 -0400 Subject: [PATCH 2/5] fix: cargo fmt formatting for CI Co-Authored-By: Claude Opus 4.6 (1M context) --- .../src/decorrelate_predicate_subquery.rs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 83c6477c49c61..dc897266b4ef5 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -33,9 +33,9 @@ use datafusion_common::{ }; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; +use datafusion_expr::logical_plan::Projection; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::{conjunction, expr_to_columns, split_conjunction_owned}; -use datafusion_expr::logical_plan::Projection; use datafusion_expr::{ BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, exists, in_subquery, lit, not, not_exists, not_in_subquery, @@ -2216,8 +2216,7 @@ mod tests { let sq = Arc::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( - col("orders.o_custkey") - .eq(out_ref_col(DataType::UInt32, "test.a")), + col("orders.o_custkey").eq(out_ref_col(DataType::UInt32, "test.a")), )? .project(vec![lit(1)])? .build()?, @@ -2308,9 +2307,11 @@ mod tests { .build()?, ); - let case_expr = - when(in_subquery(col("customer.c_custkey"), orders), lit("active")) - .otherwise(lit("inactive"))?; + let case_expr = when( + in_subquery(col("customer.c_custkey"), orders), + lit("active"), + ) + .otherwise(lit("inactive"))?; let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) .project(vec![col("customer.c_custkey"), case_expr])? @@ -2400,8 +2401,8 @@ mod tests { .build()?, ); - let case_expr = when(in_subquery(col("c"), sq), lit("yes")) - .otherwise(lit("no"))?; + let case_expr = + when(in_subquery(col("c"), sq), lit("yes")).otherwise(lit("no"))?; let plan = LogicalPlanBuilder::from(table_scan) .project(vec![col("a"), case_expr])? From 12a12bcd1802496f63d3bda17129faa361cd8c0c Mon Sep 17 00:00:00 2001 From: Christian McArthur Date: Tue, 7 Apr 2026 15:43:09 -0400 Subject: [PATCH 3/5] Address review feedback from neilconway - Per-expression early bail-out: check has_subquery(&rewritten) immediately after each rewrite instead of processing all exprs then checking at end. Preserves all-or-nothing semantics with earlier termination. - Add "Optimization:" prefix to alias preservation comment to clarify it is not required for correctness. - Add test projection_mixed_decorrelatable_and_non_bails_out covering one decorrelatable + one non-decorrelatable subquery in the same projection (exercises the partial-rewrite-then-bail path). Co-Authored-By: Claude Opus 4.6 (1M context) --- .../src/decorrelate_predicate_subquery.rs | 99 +++++++++++++++---- 1 file changed, 80 insertions(+), 19 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index dc897266b4ef5..a68288ed675a6 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -131,30 +131,34 @@ impl OptimizerRule for DecorrelatePredicateSubquery { let mut cur_input = Arc::unwrap_or_clone(projection.input); let mut new_exprs = Vec::with_capacity(projection.expr.len()); + // Rewrite each expression in turn. Check per-expression whether + // decorrelation succeeded, and bail out early on the first failure + // (same all-or-nothing semantics as ScalarSubqueryToJoin). for expr in &projection.expr { - if has_subquery(expr) { - let (plan, rewritten) = - rewrite_inner_subqueries(cur_input, expr.clone(), config)?; - cur_input = plan; - new_exprs.push(rewritten); - } else { + if !has_subquery(expr) { new_exprs.push(expr.clone()); + continue; } + let (plan, rewritten) = + rewrite_inner_subqueries(cur_input, expr.clone(), config)?; + if has_subquery(&rewritten) { + // Decorrelation failed for this expression — bail out for + // the whole projection and return the original plan. + let original = Projection::try_new_with_schema( + projection.expr, + original_input, + projection.schema, + )?; + return Ok(Transformed::no(LogicalPlan::Projection(original))); + } + cur_input = plan; + new_exprs.push(rewritten); } - // If any expression still contains a subquery after rewriting, - // decorrelation failed — bail out and return the original plan - // unchanged (same pattern as ScalarSubqueryToJoin). - if new_exprs.iter().any(has_subquery) { - let original = Projection::try_new_with_schema( - projection.expr, - original_input, - projection.schema, - )?; - return Ok(Transformed::no(LogicalPlan::Projection(original))); - } - - // Preserve original column names via aliases where the rewrite changed them + // Optimization: preserve original column names via aliases where + // the rewrite changed them. Not required for correctness — the + // rewritten plan is still valid without aliases — but keeps the + // output column names stable for downstream consumers. let proj_exprs: Vec = projection .expr .iter() @@ -2422,4 +2426,61 @@ mod tests { "# ) } + + /// When a projection contains one decorrelatable subquery and one that + /// cannot be decorrelated, the whole projection should bail out and leave + /// both subqueries in place — neither should be partially rewritten. + #[test] + fn projection_mixed_decorrelatable_and_non_bails_out() -> Result<()> { + let table_scan = test_table_scan()?; + + // Decorrelatable subquery: simple correlated IN with no LIMIT + let sq_ok = Arc::new( + LogicalPlanBuilder::from(test_table_scan_with_name("sq_ok")?) + .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq_ok.a")))? + .project(vec![col("sq_ok.c")])? + .build()?, + ); + + // Non-decorrelatable subquery: LIMIT prevents decorrelation + let sq_bad = Arc::new( + LogicalPlanBuilder::from(test_table_scan_with_name("sq_bad")?) + .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq_bad.a")))? + .project(vec![col("sq_bad.c")])? + .limit(0, Some(1))? + .build()?, + ); + + let case_ok = + when(in_subquery(col("c"), sq_ok), lit("ok")).otherwise(lit("not_ok"))?; + let case_bad = + when(in_subquery(col("c"), sq_bad), lit("bad")).otherwise(lit("not_bad"))?; + + // case_ok is listed before case_bad so it would decorrelate first + // if we processed each expression independently. The per-expression + // bail-out must detect case_bad's failure and discard case_ok's + // rewrite too. + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), case_ok, case_bad])? + .build()?; + + // Both subqueries remain — bail-out is all-or-nothing. + // The decorrelatable one is NOT rewritten because its sibling can't be. + assert_optimized_plan_equal!( + plan, + @r#" + Projection: test.a, CASE WHEN test.c IN () THEN Utf8("ok") ELSE Utf8("not_ok") END, CASE WHEN test.c IN () THEN Utf8("bad") ELSE Utf8("not_bad") END [a:UInt32, CASE WHEN IN THEN Utf8("ok") ELSE Utf8("not_ok") END:Utf8, CASE WHEN IN THEN Utf8("bad") ELSE Utf8("not_bad") END:Utf8] + Subquery: [c:UInt32] + Projection: sq_ok.c [c:UInt32] + Filter: outer_ref(test.a) = sq_ok.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: sq_ok [a:UInt32, b:UInt32, c:UInt32] + Subquery: [c:UInt32] + Limit: skip=0, fetch=1 [c:UInt32] + Projection: sq_bad.c [c:UInt32] + Filter: outer_ref(test.a) = sq_bad.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: sq_bad [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + "# + ) + } } From 92884480ac993648cd779813ba4ce03d40bd7a4a Mon Sep 17 00:00:00 2001 From: Christian McArthur Date: Tue, 7 Apr 2026 17:22:56 -0400 Subject: [PATCH 4/5] Update stale negative test in predicates.slt The test at predicates.slt:845 asserted that `NULL IN (SELECT ...)` in a projection would fail with "Physical plan does not support logical expression InSubquery". With decorrelation of InSubquery in projections now supported, the query correctly returns (false, true) instead of erroring. Co-Authored-By: Claude Opus 4.6 (1M context) --- datafusion/sqllogictest/test_files/predicates.slt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index 8bc2ca465e9b6..86d8034ee8549 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -843,12 +843,15 @@ explain select x from t where x NOT IN (1,2,3,4,5) AND x IN (1,2,3); logical_plan EmptyRelation: rows=0 physical_plan EmptyExec -query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression InSubquery\(InSubquery \{ expr: Literal\(Int64\(NULL\), None\), subquery: , negated: false \}\) +# IN and NOT IN against empty relation in a projection (decorrelated via DecorrelatePredicateSubquery) +query BB WITH empty AS (SELECT 10 WHERE false) SELECT NULL IN (SELECT * FROM empty), -- should be false, as the right side is empty relation NULL NOT IN (SELECT * FROM empty) -- should be true, as the right side is empty relation FROM (SELECT 1) t; +---- +false true query I WITH empty AS (SELECT 10 WHERE false) From 93341e03317392c837e44fd258320a55d18d38bb Mon Sep 17 00:00:00 2001 From: Christian McArthur Date: Wed, 8 Apr 2026 15:05:15 -0400 Subject: [PATCH 5/5] Address review feedback round 2 from neilconway Previous push addressed the per-expression bail-out correctly but got two items wrong: - Added "Optimization:" prefix to the alias preservation comment instead of the skip check at the top of the Projection arm (the one neilconway actually pointed at). - Did not apply the ScalarSubqueryToJoin-style cloning simplification. This commit fixes both: - Moved the "Optimization:" prefix to the correct line and removed the stray prefix on the alias comment. - Replaced Arc::clone + Arc::unwrap_or_clone with projection.input.as_ref().clone(), mirroring ScalarSubqueryToJoin. - Dropped the original_input variable and the Projection::try_new_with_schema reconstruction on bail-out; the bail-out now returns Ok(Transformed::no(LogicalPlan::Projection( projection))) directly since projection is still fully owned. - Removed the now-unused Projection import. The loop clones projection.expr upfront so it can consume owned expressions without holding a borrow on projection, which lets the bail-out move projection. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../src/decorrelate_predicate_subquery.rs | 40 ++++++++----------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index a68288ed675a6..9c12e0194c34b 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -33,7 +33,6 @@ use datafusion_common::{ }; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; -use datafusion_expr::logical_plan::Projection; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::{conjunction, expr_to_columns, split_conjunction_owned}; use datafusion_expr::{ @@ -120,45 +119,40 @@ impl OptimizerRule for DecorrelatePredicateSubquery { Ok(Transformed::yes(cur_input)) } LogicalPlan::Projection(projection) => { - // Skip if no predicate subqueries in any projection expression + // Optimization: skip if no predicate subqueries in any projection expression if !projection.expr.iter().any(has_subquery) { return Ok(Transformed::no(LogicalPlan::Projection(projection))); } - // Keep an Arc clone of the original input so we can reconstruct - // the Projection if decorrelation fails for any expression. - let original_input = Arc::clone(&projection.input); - let mut cur_input = Arc::unwrap_or_clone(projection.input); - let mut new_exprs = Vec::with_capacity(projection.expr.len()); + // Clone projection.expr up front so we can iterate without + // holding a borrow on `projection`, which lets us move it + // directly on bail-out (mirrors ScalarSubqueryToJoin). + let projection_exprs = projection.expr.clone(); + let mut cur_input = projection.input.as_ref().clone(); + let mut new_exprs = Vec::with_capacity(projection_exprs.len()); // Rewrite each expression in turn. Check per-expression whether // decorrelation succeeded, and bail out early on the first failure // (same all-or-nothing semantics as ScalarSubqueryToJoin). - for expr in &projection.expr { - if !has_subquery(expr) { - new_exprs.push(expr.clone()); + for expr in projection_exprs { + if !has_subquery(&expr) { + new_exprs.push(expr); continue; } let (plan, rewritten) = - rewrite_inner_subqueries(cur_input, expr.clone(), config)?; + rewrite_inner_subqueries(cur_input, expr, config)?; if has_subquery(&rewritten) { - // Decorrelation failed for this expression — bail out for - // the whole projection and return the original plan. - let original = Projection::try_new_with_schema( - projection.expr, - original_input, - projection.schema, - )?; - return Ok(Transformed::no(LogicalPlan::Projection(original))); + // Decorrelation failed for this expression — bail out + // for the whole projection. + return Ok(Transformed::no(LogicalPlan::Projection(projection))); } cur_input = plan; new_exprs.push(rewritten); } - // Optimization: preserve original column names via aliases where - // the rewrite changed them. Not required for correctness — the - // rewritten plan is still valid without aliases — but keeps the - // output column names stable for downstream consumers. + // Preserve original column names via aliases where the rewrite + // changed them — keeps output column names stable for + // downstream consumers. let proj_exprs: Vec = projection .expr .iter()