Skip to content

Commit d13b7a9

Browse files
committed
Decorellate subqueries in IN inside JOIN filter and Aggregates
1 parent 4250bd1 commit d13b7a9

2 files changed

Lines changed: 335 additions & 60 deletions

File tree

datafusion/optimizer/src/decorrelate_predicate_subquery.rs

Lines changed: 185 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
3030
use datafusion_common::{internal_err, plan_err, Column, Result};
3131
use datafusion_expr::expr::{Exists, InSubquery};
3232
use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
33-
use datafusion_expr::logical_plan::{JoinType, Projection, Subquery};
33+
use datafusion_expr::logical_plan::{
34+
Join as LogicalJoin, JoinType, Projection, Subquery,
35+
};
3436
use datafusion_expr::utils::{conjunction, split_conjunction_owned};
3537
use datafusion_expr::{
3638
exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
@@ -66,82 +68,166 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
6668
})?
6769
.data;
6870

69-
match plan {
70-
LogicalPlan::Filter(filter) => {
71-
if !has_subquery(&filter.predicate) {
72-
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
73-
}
71+
// Handle Filters first (existing behavior)
72+
if let LogicalPlan::Filter(filter) = plan.clone() {
73+
if !has_subquery(&filter.predicate) {
74+
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
75+
}
7476

75-
let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
76-
split_conjunction_owned(filter.predicate)
77-
.into_iter()
78-
.partition(has_subquery);
77+
let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
78+
split_conjunction_owned(filter.predicate)
79+
.into_iter()
80+
.partition(has_subquery);
7981

80-
if with_subqueries.is_empty() {
81-
return internal_err!(
82-
"can not find expected subqueries in DecorrelatePredicateSubquery"
83-
);
84-
}
82+
if with_subqueries.is_empty() {
83+
return internal_err!(
84+
"can not find expected subqueries in DecorrelatePredicateSubquery"
85+
);
86+
}
8587

86-
// iterate through all exists clauses in predicate, turning each into a join
87-
let mut cur_input = Arc::unwrap_or_clone(filter.input);
88-
for subquery_expr in with_subqueries {
89-
match extract_subquery_info(subquery_expr) {
90-
// The subquery expression is at the top level of the filter
91-
SubqueryPredicate::Top(subquery) => {
92-
match build_join_top(
93-
&subquery,
94-
&cur_input,
95-
config.alias_generator(),
96-
)? {
97-
Some(plan) => cur_input = plan,
98-
// If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
99-
None => other_exprs.push(subquery.expr()),
100-
}
101-
}
102-
// The subquery expression is embedded within another expression
103-
SubqueryPredicate::Embedded(expr) => {
104-
let (plan, expr_without_subqueries) =
105-
rewrite_inner_subqueries(cur_input, expr, config)?;
106-
cur_input = plan;
107-
other_exprs.push(expr_without_subqueries);
88+
// iterate through all exists clauses in predicate, turning each into a join
89+
let mut cur_input = Arc::unwrap_or_clone(filter.input);
90+
for subquery_expr in with_subqueries {
91+
match extract_subquery_info(subquery_expr) {
92+
// The subquery expression is at the top level of the filter
93+
SubqueryPredicate::Top(subquery) => {
94+
match build_join_top(
95+
&subquery,
96+
&cur_input,
97+
config.alias_generator(),
98+
)? {
99+
Some(plan) => cur_input = plan,
100+
// If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
101+
None => other_exprs.push(subquery.expr()),
108102
}
109103
}
104+
// The subquery expression is embedded within another expression
105+
SubqueryPredicate::Embedded(expr) => {
106+
let (plan, expr_without_subqueries) =
107+
rewrite_inner_subqueries(cur_input, expr, config)?;
108+
cur_input = plan;
109+
other_exprs.push(expr_without_subqueries);
110+
}
110111
}
112+
}
113+
114+
let expr = conjunction(other_exprs);
115+
if let Some(expr) = expr {
116+
let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;
117+
return Ok(Transformed::yes(LogicalPlan::Filter(new_filter)));
118+
}
119+
return Ok(Transformed::yes(cur_input));
120+
}
111121

112-
let expr = conjunction(other_exprs);
113-
let mut new_plan = cur_input;
114-
if let Some(expr) = expr {
115-
let new_filter = Filter::try_new(expr, Arc::new(new_plan))?;
116-
new_plan = LogicalPlan::Filter(new_filter);
122+
// Additionally handle subqueries embedded in Join.filter expressions
123+
if let LogicalPlan::Join(join) = plan {
124+
if let Some(predicate) = &join.filter {
125+
if has_subquery(predicate) {
126+
let (new_left, new_predicate) = rewrite_inner_subqueries(
127+
Arc::unwrap_or_clone(join.left),
128+
predicate.clone(),
129+
config,
130+
)?;
131+
132+
let new_join = LogicalJoin::try_new(
133+
Arc::new(new_left),
134+
Arc::clone(&join.right),
135+
join.on.clone(),
136+
Some(new_predicate),
137+
join.join_type,
138+
join.join_constraint,
139+
join.null_equals_null,
140+
)?;
141+
return Ok(Transformed::yes(LogicalPlan::Join(new_join)));
117142
}
118-
Ok(Transformed::yes(new_plan))
119143
}
120-
LogicalPlan::Projection(proj) => {
121-
// Only proceed if any projection expression contains a subquery
122-
if !proj.expr.iter().any(has_subquery) {
123-
return Ok(Transformed::no(LogicalPlan::Projection(proj)));
144+
return Ok(Transformed::no(LogicalPlan::Join(join)));
145+
}
146+
147+
// Handle subqueries embedded in Aggregate group/aggregate expressions
148+
if let LogicalPlan::Aggregate(aggregate) = plan {
149+
let mut needs_rewrite = false;
150+
for e in &aggregate.group_expr {
151+
if has_subquery(e) {
152+
needs_rewrite = true;
153+
break;
154+
}
155+
}
156+
if !needs_rewrite {
157+
for e in &aggregate.aggr_expr {
158+
if has_subquery(e) {
159+
needs_rewrite = true;
160+
break;
161+
}
124162
}
163+
}
164+
if !needs_rewrite {
165+
return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
166+
}
125167

126-
let mut cur_input = Arc::unwrap_or_clone(proj.input);
127-
let mut new_exprs = Vec::with_capacity(proj.expr.len());
128-
for e in proj.expr {
129-
let old_name = e.schema_name().to_string();
130-
let (plan_after, rewritten) =
131-
rewrite_inner_subqueries(cur_input, e, config)?;
132-
cur_input = plan_after;
133-
let new_name = rewritten.schema_name().to_string();
168+
let mut cur_input = Arc::unwrap_or_clone(aggregate.input);
169+
let mut new_group_exprs = Vec::with_capacity(aggregate.group_expr.len());
170+
for expr in aggregate.group_expr {
171+
if has_subquery(&expr) {
172+
let (next_input, rewritten_expr) =
173+
rewrite_inner_subqueries(cur_input, expr, config)?;
174+
cur_input = next_input;
175+
new_group_exprs.push(rewritten_expr);
176+
} else {
177+
new_group_exprs.push(expr);
178+
}
179+
}
180+
let mut new_aggr_exprs = Vec::with_capacity(aggregate.aggr_expr.len());
181+
for expr in aggregate.aggr_expr {
182+
if has_subquery(&expr) {
183+
let old_name = expr.schema_name().to_string();
184+
let (next_input, rewritten_expr) =
185+
rewrite_inner_subqueries(cur_input, expr, config)?;
186+
cur_input = next_input;
187+
let new_name = rewritten_expr.schema_name().to_string();
134188
if new_name != old_name {
135-
new_exprs.push(rewritten.alias(old_name));
189+
new_aggr_exprs.push(rewritten_expr.alias(old_name));
136190
} else {
137-
new_exprs.push(rewritten);
191+
new_aggr_exprs.push(rewritten_expr);
138192
}
193+
} else {
194+
new_aggr_exprs.push(expr);
139195
}
140-
let new_proj = Projection::try_new(new_exprs, Arc::new(cur_input))?;
141-
Ok(Transformed::yes(LogicalPlan::Projection(new_proj)))
142196
}
143-
other => Ok(Transformed::no(other)),
197+
198+
let new_plan = LogicalPlanBuilder::from(cur_input)
199+
.aggregate(new_group_exprs, new_aggr_exprs)?
200+
.build()?;
201+
return Ok(Transformed::yes(new_plan));
144202
}
203+
204+
// Handle Projection nodes with subqueries in expressions
205+
if let LogicalPlan::Projection(proj) = plan {
206+
// Only proceed if any projection expression contains a subquery
207+
if !proj.expr.iter().any(has_subquery) {
208+
return Ok(Transformed::no(LogicalPlan::Projection(proj)));
209+
}
210+
211+
let mut cur_input = Arc::unwrap_or_clone(proj.input);
212+
let mut new_exprs = Vec::with_capacity(proj.expr.len());
213+
for e in proj.expr {
214+
let old_name = e.schema_name().to_string();
215+
let (plan_after, rewritten) =
216+
rewrite_inner_subqueries(cur_input, e, config)?;
217+
cur_input = plan_after;
218+
let new_name = rewritten.schema_name().to_string();
219+
if new_name != old_name {
220+
new_exprs.push(rewritten.alias(old_name));
221+
} else {
222+
new_exprs.push(rewritten);
223+
}
224+
}
225+
let new_proj = Projection::try_new(new_exprs, Arc::new(cur_input))?;
226+
return Ok(Transformed::yes(LogicalPlan::Projection(new_proj)));
227+
}
228+
229+
// Other plans unchanged
230+
Ok(Transformed::no(plan))
145231
}
146232

147233
fn name(&self) -> &str {
@@ -477,6 +563,45 @@ mod tests {
477563
))
478564
}
479565

566+
/// Aggregation with CASE WHEN ... IN (subquery) should be decorrelated under the Aggregate
567+
#[test]
568+
fn aggregate_case_in_subquery() -> Result<()> {
569+
let table_scan = test_table_scan_with_name("distinct_source")?;
570+
use datafusion_expr::expr_fn::when;
571+
use datafusion_functions_aggregate::expr_fn::max as agg_max;
572+
573+
let agg_b: Expr = agg_max(col("distinct_source.b"));
574+
let subq = LogicalPlanBuilder::from(table_scan.clone())
575+
.aggregate(Vec::<Expr>::new(), vec![agg_b])?
576+
.project(vec![col("max(distinct_source.b)")])?
577+
.build()?;
578+
579+
let case_expr = when(
580+
in_subquery(col("distinct_source.b"), Arc::new(subq)),
581+
lit(1),
582+
)
583+
.otherwise(lit(0))?;
584+
585+
let plan = LogicalPlanBuilder::from(table_scan)
586+
.aggregate(
587+
vec![col("distinct_source.a").alias("primary_key")],
588+
vec![
589+
agg_max(case_expr).alias("is_in_most_recent_task"),
590+
agg_max(col("distinct_source.c")).alias("max_timestamp"),
591+
],
592+
)?
593+
.build()?;
594+
595+
use crate::{OptimizerContext, OptimizerRule};
596+
let optimized = DecorrelatePredicateSubquery::new()
597+
.rewrite(plan, &OptimizerContext::new())?
598+
.data;
599+
let lp = optimized.display_indent().to_string();
600+
assert!(lp.contains("Aggregate:"));
601+
assert!(lp.contains("Left"));
602+
Ok(())
603+
}
604+
480605
/// Test for several IN subquery expressions
481606
#[test]
482607
fn in_subquery_multiple() -> Result<()> {

0 commit comments

Comments
 (0)