@@ -30,7 +30,7 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
3030use datafusion_common:: { internal_err, plan_err, Column , Result } ;
3131use datafusion_expr:: expr:: { Exists , InSubquery } ;
3232use datafusion_expr:: expr_rewriter:: create_col_from_scalar_expr;
33- use datafusion_expr:: logical_plan:: { JoinType , Projection , Subquery } ;
33+ use datafusion_expr:: logical_plan:: { Join as LogicalJoin , JoinType , Projection , Subquery } ;
3434use datafusion_expr:: utils:: { conjunction, split_conjunction_owned} ;
3535use datafusion_expr:: {
3636 exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr , Expr , Filter ,
@@ -66,82 +66,151 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
6666 } ) ?
6767 . data ;
6868
69- match plan {
70- LogicalPlan :: Filter ( filter) => {
71- if !has_subquery ( & filter. predicate ) {
72- return Ok ( Transformed :: no ( LogicalPlan :: Filter ( filter) ) ) ;
73- }
69+ // Handle Filters first (existing behavior)
70+ if let LogicalPlan :: Filter ( filter) = plan . clone ( ) {
71+ if !has_subquery ( & filter. predicate ) {
72+ return Ok ( Transformed :: no ( LogicalPlan :: Filter ( filter) ) ) ;
73+ }
7474
75- let ( with_subqueries, mut other_exprs) : ( Vec < _ > , Vec < _ > ) =
76- split_conjunction_owned ( filter. predicate )
77- . into_iter ( )
78- . partition ( has_subquery) ;
75+ let ( with_subqueries, mut other_exprs) : ( Vec < _ > , Vec < _ > ) =
76+ split_conjunction_owned ( filter. predicate )
77+ . into_iter ( )
78+ . partition ( has_subquery) ;
7979
80- if with_subqueries. is_empty ( ) {
81- return internal_err ! (
82- "can not find expected subqueries in DecorrelatePredicateSubquery"
83- ) ;
84- }
80+ if with_subqueries. is_empty ( ) {
81+ return internal_err ! (
82+ "can not find expected subqueries in DecorrelatePredicateSubquery"
83+ ) ;
84+ }
8585
86- // iterate through all exists clauses in predicate, turning each into a join
87- let mut cur_input = Arc :: unwrap_or_clone ( filter. input ) ;
88- for subquery_expr in with_subqueries {
89- match extract_subquery_info ( subquery_expr) {
90- // The subquery expression is at the top level of the filter
91- SubqueryPredicate :: Top ( subquery) => {
92- match build_join_top (
93- & subquery,
94- & cur_input,
95- config. alias_generator ( ) ,
96- ) ? {
97- Some ( plan) => cur_input = plan,
98- // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
99- None => other_exprs. push ( subquery. expr ( ) ) ,
100- }
101- }
102- // The subquery expression is embedded within another expression
103- SubqueryPredicate :: Embedded ( expr) => {
104- let ( plan, expr_without_subqueries) =
105- rewrite_inner_subqueries ( cur_input, expr, config) ?;
106- cur_input = plan;
107- other_exprs. push ( expr_without_subqueries) ;
86+ // iterate through all exists clauses in predicate, turning each into a join
87+ let mut cur_input = Arc :: unwrap_or_clone ( filter. input ) ;
88+ for subquery_expr in with_subqueries {
89+ match extract_subquery_info ( subquery_expr) {
90+ // The subquery expression is at the top level of the filter
91+ SubqueryPredicate :: Top ( subquery) => {
92+ match build_join_top ( & subquery, & cur_input, config. alias_generator ( ) ) ? {
93+ Some ( plan) => cur_input = plan,
94+ // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
95+ None => other_exprs. push ( subquery. expr ( ) ) ,
10896 }
10997 }
98+ // The subquery expression is embedded within another expression
99+ SubqueryPredicate :: Embedded ( expr) => {
100+ let ( plan, expr_without_subqueries) =
101+ rewrite_inner_subqueries ( cur_input, expr, config) ?;
102+ cur_input = plan;
103+ other_exprs. push ( expr_without_subqueries) ;
104+ }
110105 }
106+ }
107+
108+ let expr = conjunction ( other_exprs) ;
109+ if let Some ( expr) = expr {
110+ let new_filter = Filter :: try_new ( expr, Arc :: new ( cur_input) ) ?;
111+ return Ok ( Transformed :: yes ( LogicalPlan :: Filter ( new_filter) ) ) ;
112+ }
113+ return Ok ( Transformed :: yes ( cur_input) ) ;
114+ }
111115
112- let expr = conjunction ( other_exprs) ;
113- let mut new_plan = cur_input;
114- if let Some ( expr) = expr {
115- let new_filter = Filter :: try_new ( expr, Arc :: new ( new_plan) ) ?;
116- new_plan = LogicalPlan :: Filter ( new_filter) ;
116+ // Additionally handle subqueries embedded in Join.filter expressions
117+ if let LogicalPlan :: Join ( join) = plan {
118+ if let Some ( predicate) = & join. filter {
119+ if has_subquery ( predicate) {
120+ let ( new_left, new_predicate) =
121+ rewrite_inner_subqueries ( Arc :: unwrap_or_clone ( join. left ) , predicate. clone ( ) , config) ?;
122+
123+ let new_join = LogicalJoin :: try_new (
124+ Arc :: new ( new_left) ,
125+ Arc :: clone ( & join. right ) ,
126+ join. on . clone ( ) ,
127+ Some ( new_predicate) ,
128+ join. join_type ,
129+ join. join_constraint ,
130+ join. null_equals_null ,
131+ ) ?;
132+ return Ok ( Transformed :: yes ( LogicalPlan :: Join ( new_join) ) ) ;
117133 }
118- Ok ( Transformed :: yes ( new_plan) )
119134 }
120- LogicalPlan :: Projection ( proj) => {
121- // Only proceed if any projection expression contains a subquery
122- if !proj. expr . iter ( ) . any ( has_subquery) {
123- return Ok ( Transformed :: no ( LogicalPlan :: Projection ( proj) ) ) ;
135+ return Ok ( Transformed :: no ( LogicalPlan :: Join ( join) ) ) ;
136+ }
137+
138+ // Handle subqueries embedded in Aggregate group/aggregate expressions
139+ if let LogicalPlan :: Aggregate ( aggregate) = plan {
140+ let mut needs_rewrite = false ;
141+ for e in & aggregate. group_expr {
142+ if has_subquery ( e) { needs_rewrite = true ; break ; }
143+ }
144+ if !needs_rewrite {
145+ for e in & aggregate. aggr_expr {
146+ if has_subquery ( e) { needs_rewrite = true ; break ; }
124147 }
148+ }
149+ if !needs_rewrite {
150+ return Ok ( Transformed :: no ( LogicalPlan :: Aggregate ( aggregate) ) ) ;
151+ }
125152
126- let mut cur_input = Arc :: unwrap_or_clone ( proj. input ) ;
127- let mut new_exprs = Vec :: with_capacity ( proj. expr . len ( ) ) ;
128- for e in proj. expr {
129- let old_name = e. schema_name ( ) . to_string ( ) ;
130- let ( plan_after, rewritten) =
131- rewrite_inner_subqueries ( cur_input, e, config) ?;
132- cur_input = plan_after;
133- let new_name = rewritten. schema_name ( ) . to_string ( ) ;
153+ let mut cur_input = Arc :: unwrap_or_clone ( aggregate. input ) ;
154+ let mut new_group_exprs = Vec :: with_capacity ( aggregate. group_expr . len ( ) ) ;
155+ for expr in aggregate. group_expr {
156+ if has_subquery ( & expr) {
157+ let ( next_input, rewritten_expr) = rewrite_inner_subqueries ( cur_input, expr, config) ?;
158+ cur_input = next_input;
159+ new_group_exprs. push ( rewritten_expr) ;
160+ } else {
161+ new_group_exprs. push ( expr) ;
162+ }
163+ }
164+ let mut new_aggr_exprs = Vec :: with_capacity ( aggregate. aggr_expr . len ( ) ) ;
165+ for expr in aggregate. aggr_expr {
166+ if has_subquery ( & expr) {
167+ let old_name = expr. schema_name ( ) . to_string ( ) ;
168+ let ( next_input, rewritten_expr) = rewrite_inner_subqueries ( cur_input, expr, config) ?;
169+ cur_input = next_input;
170+ let new_name = rewritten_expr. schema_name ( ) . to_string ( ) ;
134171 if new_name != old_name {
135- new_exprs . push ( rewritten . alias ( old_name) ) ;
172+ new_aggr_exprs . push ( rewritten_expr . alias ( old_name) ) ;
136173 } else {
137- new_exprs . push ( rewritten ) ;
174+ new_aggr_exprs . push ( rewritten_expr ) ;
138175 }
176+ } else {
177+ new_aggr_exprs. push ( expr) ;
178+ }
179+ }
180+
181+ let new_plan = LogicalPlanBuilder :: from ( cur_input)
182+ . aggregate ( new_group_exprs, new_aggr_exprs) ?
183+ . build ( ) ?;
184+ return Ok ( Transformed :: yes ( new_plan) ) ;
185+ }
186+
187+ // Handle Projection nodes with subqueries in expressions
188+ if let LogicalPlan :: Projection ( proj) = plan {
189+ // Only proceed if any projection expression contains a subquery
190+ if !proj. expr . iter ( ) . any ( has_subquery) {
191+ return Ok ( Transformed :: no ( LogicalPlan :: Projection ( proj) ) ) ;
192+ }
193+
194+ let mut cur_input = Arc :: unwrap_or_clone ( proj. input ) ;
195+ let mut new_exprs = Vec :: with_capacity ( proj. expr . len ( ) ) ;
196+ for e in proj. expr {
197+ let old_name = e. schema_name ( ) . to_string ( ) ;
198+ let ( plan_after, rewritten) =
199+ rewrite_inner_subqueries ( cur_input, e, config) ?;
200+ cur_input = plan_after;
201+ let new_name = rewritten. schema_name ( ) . to_string ( ) ;
202+ if new_name != old_name {
203+ new_exprs. push ( rewritten. alias ( old_name) ) ;
204+ } else {
205+ new_exprs. push ( rewritten) ;
139206 }
140- let new_proj = Projection :: try_new ( new_exprs, Arc :: new ( cur_input) ) ?;
141- Ok ( Transformed :: yes ( LogicalPlan :: Projection ( new_proj) ) )
142207 }
143- other => Ok ( Transformed :: no ( other) ) ,
208+ let new_proj = Projection :: try_new ( new_exprs, Arc :: new ( cur_input) ) ?;
209+ return Ok ( Transformed :: yes ( LogicalPlan :: Projection ( new_proj) ) ) ;
144210 }
211+
212+ // Other plans unchanged
213+ Ok ( Transformed :: no ( plan) )
145214 }
146215
147216 fn name ( & self ) -> & str {
@@ -477,6 +546,42 @@ mod tests {
477546 ) )
478547 }
479548
549+ /// Aggregation with CASE WHEN ... IN (subquery) should be decorrelated under the Aggregate
550+ #[ test]
551+ fn aggregate_case_in_subquery ( ) -> Result < ( ) > {
552+ let table_scan = test_table_scan_with_name ( "distinct_source" ) ?;
553+ use datafusion_functions_aggregate:: expr_fn:: max as agg_max;
554+ use datafusion_expr:: expr_fn:: when;
555+
556+ let agg_b: Expr = agg_max ( col ( "distinct_source.b" ) ) ;
557+ let subq = LogicalPlanBuilder :: from ( table_scan. clone ( ) )
558+ . aggregate ( Vec :: < Expr > :: new ( ) , vec ! [ agg_b] ) ?
559+ . project ( vec ! [ col( "max(distinct_source.b)" ) ] ) ?
560+ . build ( ) ?;
561+
562+ let case_expr = when ( in_subquery ( col ( "distinct_source.b" ) , Arc :: new ( subq) ) , lit ( 1 ) )
563+ . otherwise ( lit ( 0 ) ) ?;
564+
565+ let plan = LogicalPlanBuilder :: from ( table_scan)
566+ . aggregate (
567+ vec ! [ col( "distinct_source.a" ) . alias( "primary_key" ) ] ,
568+ vec ! [
569+ agg_max( case_expr) . alias( "is_in_most_recent_task" ) ,
570+ agg_max( col( "distinct_source.c" ) ) . alias( "max_timestamp" ) ,
571+ ] ,
572+ ) ?
573+ . build ( ) ?;
574+
575+ use crate :: { OptimizerContext , OptimizerRule } ;
576+ let optimized = DecorrelatePredicateSubquery :: new ( )
577+ . rewrite ( plan, & OptimizerContext :: new ( ) ) ?
578+ . data ;
579+ let lp = optimized. display_indent ( ) . to_string ( ) ;
580+ assert ! ( lp. contains( "Aggregate:" ) ) ;
581+ assert ! ( lp. contains( "Left" ) ) ;
582+ Ok ( ( ) )
583+ }
584+
480585 /// Test for several IN subquery expressions
481586 #[ test]
482587 fn in_subquery_multiple ( ) -> Result < ( ) > {
0 commit comments