@@ -30,7 +30,9 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
3030use datafusion_common:: { internal_err, plan_err, Column , Result } ;
3131use datafusion_expr:: expr:: { Exists , InSubquery } ;
3232use datafusion_expr:: expr_rewriter:: create_col_from_scalar_expr;
33- use datafusion_expr:: logical_plan:: { JoinType , Projection , Subquery } ;
33+ use datafusion_expr:: logical_plan:: {
34+ Join as LogicalJoin , JoinType , Projection , Subquery ,
35+ } ;
3436use datafusion_expr:: utils:: { conjunction, split_conjunction_owned} ;
3537use datafusion_expr:: {
3638 exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr , Expr , Filter ,
@@ -66,82 +68,166 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
6668 } ) ?
6769 . data ;
6870
69- match plan {
70- LogicalPlan :: Filter ( filter) => {
71- if !has_subquery ( & filter. predicate ) {
72- return Ok ( Transformed :: no ( LogicalPlan :: Filter ( filter) ) ) ;
73- }
71+ // Handle Filters first (existing behavior)
72+ if let LogicalPlan :: Filter ( filter) = plan . clone ( ) {
73+ if !has_subquery ( & filter. predicate ) {
74+ return Ok ( Transformed :: no ( LogicalPlan :: Filter ( filter) ) ) ;
75+ }
7476
75- let ( with_subqueries, mut other_exprs) : ( Vec < _ > , Vec < _ > ) =
76- split_conjunction_owned ( filter. predicate )
77- . into_iter ( )
78- . partition ( has_subquery) ;
77+ let ( with_subqueries, mut other_exprs) : ( Vec < _ > , Vec < _ > ) =
78+ split_conjunction_owned ( filter. predicate )
79+ . into_iter ( )
80+ . partition ( has_subquery) ;
7981
80- if with_subqueries. is_empty ( ) {
81- return internal_err ! (
82- "can not find expected subqueries in DecorrelatePredicateSubquery"
83- ) ;
84- }
82+ if with_subqueries. is_empty ( ) {
83+ return internal_err ! (
84+ "can not find expected subqueries in DecorrelatePredicateSubquery"
85+ ) ;
86+ }
8587
86- // iterate through all exists clauses in predicate, turning each into a join
87- let mut cur_input = Arc :: unwrap_or_clone ( filter. input ) ;
88- for subquery_expr in with_subqueries {
89- match extract_subquery_info ( subquery_expr) {
90- // The subquery expression is at the top level of the filter
91- SubqueryPredicate :: Top ( subquery) => {
92- match build_join_top (
93- & subquery,
94- & cur_input,
95- config. alias_generator ( ) ,
96- ) ? {
97- Some ( plan) => cur_input = plan,
98- // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
99- None => other_exprs. push ( subquery. expr ( ) ) ,
100- }
101- }
102- // The subquery expression is embedded within another expression
103- SubqueryPredicate :: Embedded ( expr) => {
104- let ( plan, expr_without_subqueries) =
105- rewrite_inner_subqueries ( cur_input, expr, config) ?;
106- cur_input = plan;
107- other_exprs. push ( expr_without_subqueries) ;
88+ // iterate through all exists clauses in predicate, turning each into a join
89+ let mut cur_input = Arc :: unwrap_or_clone ( filter. input ) ;
90+ for subquery_expr in with_subqueries {
91+ match extract_subquery_info ( subquery_expr) {
92+ // The subquery expression is at the top level of the filter
93+ SubqueryPredicate :: Top ( subquery) => {
94+ match build_join_top (
95+ & subquery,
96+ & cur_input,
97+ config. alias_generator ( ) ,
98+ ) ? {
99+ Some ( plan) => cur_input = plan,
100+ // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
101+ None => other_exprs. push ( subquery. expr ( ) ) ,
108102 }
109103 }
104+ // The subquery expression is embedded within another expression
105+ SubqueryPredicate :: Embedded ( expr) => {
106+ let ( plan, expr_without_subqueries) =
107+ rewrite_inner_subqueries ( cur_input, expr, config) ?;
108+ cur_input = plan;
109+ other_exprs. push ( expr_without_subqueries) ;
110+ }
110111 }
112+ }
113+
114+ let expr = conjunction ( other_exprs) ;
115+ if let Some ( expr) = expr {
116+ let new_filter = Filter :: try_new ( expr, Arc :: new ( cur_input) ) ?;
117+ return Ok ( Transformed :: yes ( LogicalPlan :: Filter ( new_filter) ) ) ;
118+ }
119+ return Ok ( Transformed :: yes ( cur_input) ) ;
120+ }
111121
112- let expr = conjunction ( other_exprs) ;
113- let mut new_plan = cur_input;
114- if let Some ( expr) = expr {
115- let new_filter = Filter :: try_new ( expr, Arc :: new ( new_plan) ) ?;
116- new_plan = LogicalPlan :: Filter ( new_filter) ;
122+ // Additionally handle subqueries embedded in Join.filter expressions
123+ if let LogicalPlan :: Join ( join) = plan {
124+ if let Some ( predicate) = & join. filter {
125+ if has_subquery ( predicate) {
126+ let ( new_left, new_predicate) = rewrite_inner_subqueries (
127+ Arc :: unwrap_or_clone ( join. left ) ,
128+ predicate. clone ( ) ,
129+ config,
130+ ) ?;
131+
132+ let new_join = LogicalJoin :: try_new (
133+ Arc :: new ( new_left) ,
134+ Arc :: clone ( & join. right ) ,
135+ join. on . clone ( ) ,
136+ Some ( new_predicate) ,
137+ join. join_type ,
138+ join. join_constraint ,
139+ join. null_equals_null ,
140+ ) ?;
141+ return Ok ( Transformed :: yes ( LogicalPlan :: Join ( new_join) ) ) ;
117142 }
118- Ok ( Transformed :: yes ( new_plan) )
119143 }
120- LogicalPlan :: Projection ( proj) => {
121- // Only proceed if any projection expression contains a subquery
122- if !proj. expr . iter ( ) . any ( has_subquery) {
123- return Ok ( Transformed :: no ( LogicalPlan :: Projection ( proj) ) ) ;
144+ return Ok ( Transformed :: no ( LogicalPlan :: Join ( join) ) ) ;
145+ }
146+
147+ // Handle subqueries embedded in Aggregate group/aggregate expressions
148+ if let LogicalPlan :: Aggregate ( aggregate) = plan {
149+ let mut needs_rewrite = false ;
150+ for e in & aggregate. group_expr {
151+ if has_subquery ( e) {
152+ needs_rewrite = true ;
153+ break ;
154+ }
155+ }
156+ if !needs_rewrite {
157+ for e in & aggregate. aggr_expr {
158+ if has_subquery ( e) {
159+ needs_rewrite = true ;
160+ break ;
161+ }
124162 }
163+ }
164+ if !needs_rewrite {
165+ return Ok ( Transformed :: no ( LogicalPlan :: Aggregate ( aggregate) ) ) ;
166+ }
125167
126- let mut cur_input = Arc :: unwrap_or_clone ( proj. input ) ;
127- let mut new_exprs = Vec :: with_capacity ( proj. expr . len ( ) ) ;
128- for e in proj. expr {
129- let old_name = e. schema_name ( ) . to_string ( ) ;
130- let ( plan_after, rewritten) =
131- rewrite_inner_subqueries ( cur_input, e, config) ?;
132- cur_input = plan_after;
133- let new_name = rewritten. schema_name ( ) . to_string ( ) ;
168+ let mut cur_input = Arc :: unwrap_or_clone ( aggregate. input ) ;
169+ let mut new_group_exprs = Vec :: with_capacity ( aggregate. group_expr . len ( ) ) ;
170+ for expr in aggregate. group_expr {
171+ if has_subquery ( & expr) {
172+ let ( next_input, rewritten_expr) =
173+ rewrite_inner_subqueries ( cur_input, expr, config) ?;
174+ cur_input = next_input;
175+ new_group_exprs. push ( rewritten_expr) ;
176+ } else {
177+ new_group_exprs. push ( expr) ;
178+ }
179+ }
180+ let mut new_aggr_exprs = Vec :: with_capacity ( aggregate. aggr_expr . len ( ) ) ;
181+ for expr in aggregate. aggr_expr {
182+ if has_subquery ( & expr) {
183+ let old_name = expr. schema_name ( ) . to_string ( ) ;
184+ let ( next_input, rewritten_expr) =
185+ rewrite_inner_subqueries ( cur_input, expr, config) ?;
186+ cur_input = next_input;
187+ let new_name = rewritten_expr. schema_name ( ) . to_string ( ) ;
134188 if new_name != old_name {
135- new_exprs . push ( rewritten . alias ( old_name) ) ;
189+ new_aggr_exprs . push ( rewritten_expr . alias ( old_name) ) ;
136190 } else {
137- new_exprs . push ( rewritten ) ;
191+ new_aggr_exprs . push ( rewritten_expr ) ;
138192 }
193+ } else {
194+ new_aggr_exprs. push ( expr) ;
139195 }
140- let new_proj = Projection :: try_new ( new_exprs, Arc :: new ( cur_input) ) ?;
141- Ok ( Transformed :: yes ( LogicalPlan :: Projection ( new_proj) ) )
142196 }
143- other => Ok ( Transformed :: no ( other) ) ,
197+
198+ let new_plan = LogicalPlanBuilder :: from ( cur_input)
199+ . aggregate ( new_group_exprs, new_aggr_exprs) ?
200+ . build ( ) ?;
201+ return Ok ( Transformed :: yes ( new_plan) ) ;
144202 }
203+
204+ // Handle Projection nodes with subqueries in expressions
205+ if let LogicalPlan :: Projection ( proj) = plan {
206+ // Only proceed if any projection expression contains a subquery
207+ if !proj. expr . iter ( ) . any ( has_subquery) {
208+ return Ok ( Transformed :: no ( LogicalPlan :: Projection ( proj) ) ) ;
209+ }
210+
211+ let mut cur_input = Arc :: unwrap_or_clone ( proj. input ) ;
212+ let mut new_exprs = Vec :: with_capacity ( proj. expr . len ( ) ) ;
213+ for e in proj. expr {
214+ let old_name = e. schema_name ( ) . to_string ( ) ;
215+ let ( plan_after, rewritten) =
216+ rewrite_inner_subqueries ( cur_input, e, config) ?;
217+ cur_input = plan_after;
218+ let new_name = rewritten. schema_name ( ) . to_string ( ) ;
219+ if new_name != old_name {
220+ new_exprs. push ( rewritten. alias ( old_name) ) ;
221+ } else {
222+ new_exprs. push ( rewritten) ;
223+ }
224+ }
225+ let new_proj = Projection :: try_new ( new_exprs, Arc :: new ( cur_input) ) ?;
226+ return Ok ( Transformed :: yes ( LogicalPlan :: Projection ( new_proj) ) ) ;
227+ }
228+
229+ // Other plans unchanged
230+ Ok ( Transformed :: no ( plan) )
145231 }
146232
147233 fn name ( & self ) -> & str {
@@ -477,6 +563,45 @@ mod tests {
477563 ) )
478564 }
479565
566+ /// Aggregation with CASE WHEN ... IN (subquery) should be decorrelated under the Aggregate
567+ #[ test]
568+ fn aggregate_case_in_subquery ( ) -> Result < ( ) > {
569+ let table_scan = test_table_scan_with_name ( "distinct_source" ) ?;
570+ use datafusion_expr:: expr_fn:: when;
571+ use datafusion_functions_aggregate:: expr_fn:: max as agg_max;
572+
573+ let agg_b: Expr = agg_max ( col ( "distinct_source.b" ) ) ;
574+ let subq = LogicalPlanBuilder :: from ( table_scan. clone ( ) )
575+ . aggregate ( Vec :: < Expr > :: new ( ) , vec ! [ agg_b] ) ?
576+ . project ( vec ! [ col( "max(distinct_source.b)" ) ] ) ?
577+ . build ( ) ?;
578+
579+ let case_expr = when (
580+ in_subquery ( col ( "distinct_source.b" ) , Arc :: new ( subq) ) ,
581+ lit ( 1 ) ,
582+ )
583+ . otherwise ( lit ( 0 ) ) ?;
584+
585+ let plan = LogicalPlanBuilder :: from ( table_scan)
586+ . aggregate (
587+ vec ! [ col( "distinct_source.a" ) . alias( "primary_key" ) ] ,
588+ vec ! [
589+ agg_max( case_expr) . alias( "is_in_most_recent_task" ) ,
590+ agg_max( col( "distinct_source.c" ) ) . alias( "max_timestamp" ) ,
591+ ] ,
592+ ) ?
593+ . build ( ) ?;
594+
595+ use crate :: { OptimizerContext , OptimizerRule } ;
596+ let optimized = DecorrelatePredicateSubquery :: new ( )
597+ . rewrite ( plan, & OptimizerContext :: new ( ) ) ?
598+ . data ;
599+ let lp = optimized. display_indent ( ) . to_string ( ) ;
600+ assert ! ( lp. contains( "Aggregate:" ) ) ;
601+ assert ! ( lp. contains( "Left" ) ) ;
602+ Ok ( ( ) )
603+ }
604+
480605 /// Test for several IN subquery expressions
481606 #[ test]
482607 fn in_subquery_multiple ( ) -> Result < ( ) > {
0 commit comments