Skip to content

Commit b48c615

Browse files
committed
Decorellate subqueries in IN inside JOIN filter and Aggregates
1 parent 4250bd1 commit b48c615

2 files changed

Lines changed: 315 additions & 60 deletions

File tree

datafusion/optimizer/src/decorrelate_predicate_subquery.rs

Lines changed: 165 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
3030
use datafusion_common::{internal_err, plan_err, Column, Result};
3131
use datafusion_expr::expr::{Exists, InSubquery};
3232
use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
33-
use datafusion_expr::logical_plan::{JoinType, Projection, Subquery};
33+
use datafusion_expr::logical_plan::{Join as LogicalJoin, JoinType, Projection, Subquery};
3434
use datafusion_expr::utils::{conjunction, split_conjunction_owned};
3535
use datafusion_expr::{
3636
exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
@@ -66,82 +66,151 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
6666
})?
6767
.data;
6868

69-
match plan {
70-
LogicalPlan::Filter(filter) => {
71-
if !has_subquery(&filter.predicate) {
72-
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
73-
}
69+
// Handle Filters first (existing behavior)
70+
if let LogicalPlan::Filter(filter) = plan.clone() {
71+
if !has_subquery(&filter.predicate) {
72+
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
73+
}
7474

75-
let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
76-
split_conjunction_owned(filter.predicate)
77-
.into_iter()
78-
.partition(has_subquery);
75+
let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
76+
split_conjunction_owned(filter.predicate)
77+
.into_iter()
78+
.partition(has_subquery);
7979

80-
if with_subqueries.is_empty() {
81-
return internal_err!(
82-
"can not find expected subqueries in DecorrelatePredicateSubquery"
83-
);
84-
}
80+
if with_subqueries.is_empty() {
81+
return internal_err!(
82+
"can not find expected subqueries in DecorrelatePredicateSubquery"
83+
);
84+
}
8585

86-
// iterate through all exists clauses in predicate, turning each into a join
87-
let mut cur_input = Arc::unwrap_or_clone(filter.input);
88-
for subquery_expr in with_subqueries {
89-
match extract_subquery_info(subquery_expr) {
90-
// The subquery expression is at the top level of the filter
91-
SubqueryPredicate::Top(subquery) => {
92-
match build_join_top(
93-
&subquery,
94-
&cur_input,
95-
config.alias_generator(),
96-
)? {
97-
Some(plan) => cur_input = plan,
98-
// If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
99-
None => other_exprs.push(subquery.expr()),
100-
}
101-
}
102-
// The subquery expression is embedded within another expression
103-
SubqueryPredicate::Embedded(expr) => {
104-
let (plan, expr_without_subqueries) =
105-
rewrite_inner_subqueries(cur_input, expr, config)?;
106-
cur_input = plan;
107-
other_exprs.push(expr_without_subqueries);
86+
// iterate through all exists clauses in predicate, turning each into a join
87+
let mut cur_input = Arc::unwrap_or_clone(filter.input);
88+
for subquery_expr in with_subqueries {
89+
match extract_subquery_info(subquery_expr) {
90+
// The subquery expression is at the top level of the filter
91+
SubqueryPredicate::Top(subquery) => {
92+
match build_join_top(&subquery, &cur_input, config.alias_generator())? {
93+
Some(plan) => cur_input = plan,
94+
// If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
95+
None => other_exprs.push(subquery.expr()),
10896
}
10997
}
98+
// The subquery expression is embedded within another expression
99+
SubqueryPredicate::Embedded(expr) => {
100+
let (plan, expr_without_subqueries) =
101+
rewrite_inner_subqueries(cur_input, expr, config)?;
102+
cur_input = plan;
103+
other_exprs.push(expr_without_subqueries);
104+
}
110105
}
106+
}
107+
108+
let expr = conjunction(other_exprs);
109+
if let Some(expr) = expr {
110+
let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;
111+
return Ok(Transformed::yes(LogicalPlan::Filter(new_filter)));
112+
}
113+
return Ok(Transformed::yes(cur_input));
114+
}
111115

112-
let expr = conjunction(other_exprs);
113-
let mut new_plan = cur_input;
114-
if let Some(expr) = expr {
115-
let new_filter = Filter::try_new(expr, Arc::new(new_plan))?;
116-
new_plan = LogicalPlan::Filter(new_filter);
116+
// Additionally handle subqueries embedded in Join.filter expressions
117+
if let LogicalPlan::Join(join) = plan {
118+
if let Some(predicate) = &join.filter {
119+
if has_subquery(predicate) {
120+
let (new_left, new_predicate) =
121+
rewrite_inner_subqueries(Arc::unwrap_or_clone(join.left), predicate.clone(), config)?;
122+
123+
let new_join = LogicalJoin::try_new(
124+
Arc::new(new_left),
125+
Arc::clone(&join.right),
126+
join.on.clone(),
127+
Some(new_predicate),
128+
join.join_type,
129+
join.join_constraint,
130+
join.null_equals_null,
131+
)?;
132+
return Ok(Transformed::yes(LogicalPlan::Join(new_join)));
117133
}
118-
Ok(Transformed::yes(new_plan))
119134
}
120-
LogicalPlan::Projection(proj) => {
121-
// Only proceed if any projection expression contains a subquery
122-
if !proj.expr.iter().any(has_subquery) {
123-
return Ok(Transformed::no(LogicalPlan::Projection(proj)));
135+
return Ok(Transformed::no(LogicalPlan::Join(join)));
136+
}
137+
138+
// Handle subqueries embedded in Aggregate group/aggregate expressions
139+
if let LogicalPlan::Aggregate(aggregate) = plan {
140+
let mut needs_rewrite = false;
141+
for e in &aggregate.group_expr {
142+
if has_subquery(e) { needs_rewrite = true; break; }
143+
}
144+
if !needs_rewrite {
145+
for e in &aggregate.aggr_expr {
146+
if has_subquery(e) { needs_rewrite = true; break; }
124147
}
148+
}
149+
if !needs_rewrite {
150+
return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
151+
}
125152

126-
let mut cur_input = Arc::unwrap_or_clone(proj.input);
127-
let mut new_exprs = Vec::with_capacity(proj.expr.len());
128-
for e in proj.expr {
129-
let old_name = e.schema_name().to_string();
130-
let (plan_after, rewritten) =
131-
rewrite_inner_subqueries(cur_input, e, config)?;
132-
cur_input = plan_after;
133-
let new_name = rewritten.schema_name().to_string();
153+
let mut cur_input = Arc::unwrap_or_clone(aggregate.input);
154+
let mut new_group_exprs = Vec::with_capacity(aggregate.group_expr.len());
155+
for expr in aggregate.group_expr {
156+
if has_subquery(&expr) {
157+
let (next_input, rewritten_expr) = rewrite_inner_subqueries(cur_input, expr, config)?;
158+
cur_input = next_input;
159+
new_group_exprs.push(rewritten_expr);
160+
} else {
161+
new_group_exprs.push(expr);
162+
}
163+
}
164+
let mut new_aggr_exprs = Vec::with_capacity(aggregate.aggr_expr.len());
165+
for expr in aggregate.aggr_expr {
166+
if has_subquery(&expr) {
167+
let old_name = expr.schema_name().to_string();
168+
let (next_input, rewritten_expr) = rewrite_inner_subqueries(cur_input, expr, config)?;
169+
cur_input = next_input;
170+
let new_name = rewritten_expr.schema_name().to_string();
134171
if new_name != old_name {
135-
new_exprs.push(rewritten.alias(old_name));
172+
new_aggr_exprs.push(rewritten_expr.alias(old_name));
136173
} else {
137-
new_exprs.push(rewritten);
174+
new_aggr_exprs.push(rewritten_expr);
138175
}
176+
} else {
177+
new_aggr_exprs.push(expr);
178+
}
179+
}
180+
181+
let new_plan = LogicalPlanBuilder::from(cur_input)
182+
.aggregate(new_group_exprs, new_aggr_exprs)?
183+
.build()?;
184+
return Ok(Transformed::yes(new_plan));
185+
}
186+
187+
// Handle Projection nodes with subqueries in expressions
188+
if let LogicalPlan::Projection(proj) = plan {
189+
// Only proceed if any projection expression contains a subquery
190+
if !proj.expr.iter().any(has_subquery) {
191+
return Ok(Transformed::no(LogicalPlan::Projection(proj)));
192+
}
193+
194+
let mut cur_input = Arc::unwrap_or_clone(proj.input);
195+
let mut new_exprs = Vec::with_capacity(proj.expr.len());
196+
for e in proj.expr {
197+
let old_name = e.schema_name().to_string();
198+
let (plan_after, rewritten) =
199+
rewrite_inner_subqueries(cur_input, e, config)?;
200+
cur_input = plan_after;
201+
let new_name = rewritten.schema_name().to_string();
202+
if new_name != old_name {
203+
new_exprs.push(rewritten.alias(old_name));
204+
} else {
205+
new_exprs.push(rewritten);
139206
}
140-
let new_proj = Projection::try_new(new_exprs, Arc::new(cur_input))?;
141-
Ok(Transformed::yes(LogicalPlan::Projection(new_proj)))
142207
}
143-
other => Ok(Transformed::no(other)),
208+
let new_proj = Projection::try_new(new_exprs, Arc::new(cur_input))?;
209+
return Ok(Transformed::yes(LogicalPlan::Projection(new_proj)));
144210
}
211+
212+
// Other plans unchanged
213+
Ok(Transformed::no(plan))
145214
}
146215

147216
fn name(&self) -> &str {
@@ -477,6 +546,42 @@ mod tests {
477546
))
478547
}
479548

549+
/// Aggregation with CASE WHEN ... IN (subquery) should be decorrelated under the Aggregate
550+
#[test]
551+
fn aggregate_case_in_subquery() -> Result<()> {
552+
let table_scan = test_table_scan_with_name("distinct_source")?;
553+
use datafusion_functions_aggregate::expr_fn::max as agg_max;
554+
use datafusion_expr::expr_fn::when;
555+
556+
let agg_b: Expr = agg_max(col("distinct_source.b"));
557+
let subq = LogicalPlanBuilder::from(table_scan.clone())
558+
.aggregate(Vec::<Expr>::new(), vec![agg_b])?
559+
.project(vec![col("max(distinct_source.b)")])?
560+
.build()?;
561+
562+
let case_expr = when(in_subquery(col("distinct_source.b"), Arc::new(subq)), lit(1))
563+
.otherwise(lit(0))?;
564+
565+
let plan = LogicalPlanBuilder::from(table_scan)
566+
.aggregate(
567+
vec![col("distinct_source.a").alias("primary_key")],
568+
vec![
569+
agg_max(case_expr).alias("is_in_most_recent_task"),
570+
agg_max(col("distinct_source.c")).alias("max_timestamp"),
571+
],
572+
)?
573+
.build()?;
574+
575+
use crate::{OptimizerContext, OptimizerRule};
576+
let optimized = DecorrelatePredicateSubquery::new()
577+
.rewrite(plan, &OptimizerContext::new())?
578+
.data;
579+
let lp = optimized.display_indent().to_string();
580+
assert!(lp.contains("Aggregate:"));
581+
assert!(lp.contains("Left"));
582+
Ok(())
583+
}
584+
480585
/// Test for several IN subquery expressions
481586
#[test]
482587
fn in_subquery_multiple() -> Result<()> {

0 commit comments

Comments
 (0)