@@ -30,7 +30,7 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
3030use datafusion_common:: { internal_err, plan_err, Column , Result } ;
3131use datafusion_expr:: expr:: { Exists , InSubquery } ;
3232use datafusion_expr:: expr_rewriter:: create_col_from_scalar_expr;
33- use datafusion_expr:: logical_plan:: { JoinType , Subquery } ;
33+ use datafusion_expr:: logical_plan:: { JoinType , Projection , Subquery } ;
3434use datafusion_expr:: utils:: { conjunction, split_conjunction_owned} ;
3535use datafusion_expr:: {
3636 exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr , Expr , Filter ,
@@ -66,54 +66,82 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
6666 } ) ?
6767 . data ;
6868
69- let LogicalPlan :: Filter ( filter) = plan else {
70- return Ok ( Transformed :: no ( plan) ) ;
71- } ;
72-
73- if !has_subquery ( & filter. predicate ) {
74- return Ok ( Transformed :: no ( LogicalPlan :: Filter ( filter) ) ) ;
75- }
69+ match plan {
70+ LogicalPlan :: Filter ( filter) => {
71+ if !has_subquery ( & filter. predicate ) {
72+ return Ok ( Transformed :: no ( LogicalPlan :: Filter ( filter) ) ) ;
73+ }
7674
77- let ( with_subqueries, mut other_exprs) : ( Vec < _ > , Vec < _ > ) =
78- split_conjunction_owned ( filter. predicate )
79- . into_iter ( )
80- . partition ( has_subquery) ;
75+ let ( with_subqueries, mut other_exprs) : ( Vec < _ > , Vec < _ > ) =
76+ split_conjunction_owned ( filter. predicate )
77+ . into_iter ( )
78+ . partition ( has_subquery) ;
8179
82- if with_subqueries. is_empty ( ) {
83- return internal_err ! (
84- "can not find expected subqueries in DecorrelatePredicateSubquery"
85- ) ;
86- }
80+ if with_subqueries. is_empty ( ) {
81+ return internal_err ! (
82+ "can not find expected subqueries in DecorrelatePredicateSubquery"
83+ ) ;
84+ }
8785
88- // iterate through all exists clauses in predicate, turning each into a join
89- let mut cur_input = Arc :: unwrap_or_clone ( filter. input ) ;
90- for subquery_expr in with_subqueries {
91- match extract_subquery_info ( subquery_expr) {
92- // The subquery expression is at the top level of the filter
93- SubqueryPredicate :: Top ( subquery) => {
94- match build_join_top ( & subquery, & cur_input, config. alias_generator ( ) ) ?
95- {
96- Some ( plan) => cur_input = plan,
97- // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
98- None => other_exprs. push ( subquery. expr ( ) ) ,
86+ // iterate through all exists clauses in predicate, turning each into a join
87+ let mut cur_input = Arc :: unwrap_or_clone ( filter. input ) ;
88+ for subquery_expr in with_subqueries {
89+ match extract_subquery_info ( subquery_expr) {
90+ // The subquery expression is at the top level of the filter
91+ SubqueryPredicate :: Top ( subquery) => {
92+ match build_join_top (
93+ & subquery,
94+ & cur_input,
95+ config. alias_generator ( ) ,
96+ ) ? {
97+ Some ( plan) => cur_input = plan,
98+ // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
99+ None => other_exprs. push ( subquery. expr ( ) ) ,
100+ }
101+ }
102+ // The subquery expression is embedded within another expression
103+ SubqueryPredicate :: Embedded ( expr) => {
104+ let ( plan, expr_without_subqueries) =
105+ rewrite_inner_subqueries ( cur_input, expr, config) ?;
106+ cur_input = plan;
107+ other_exprs. push ( expr_without_subqueries) ;
108+ }
99109 }
100110 }
101- // The subquery expression is embedded within another expression
102- SubqueryPredicate :: Embedded ( expr) => {
103- let ( plan , expr_without_subqueries ) =
104- rewrite_inner_subqueries ( cur_input , expr, config ) ? ;
105- cur_input = plan ;
106- other_exprs . push ( expr_without_subqueries ) ;
111+
112+ let expr = conjunction ( other_exprs ) ;
113+ let mut new_plan = cur_input ;
114+ if let Some ( expr ) = expr {
115+ let new_filter = Filter :: try_new ( expr , Arc :: new ( new_plan ) ) ? ;
116+ new_plan = LogicalPlan :: Filter ( new_filter ) ;
107117 }
118+ Ok ( Transformed :: yes ( new_plan) )
108119 }
109- }
120+ LogicalPlan :: Projection ( proj) => {
121+ // Only proceed if any projection expression contains a subquery
122+ if !proj. expr . iter ( ) . any ( has_subquery) {
123+ return Ok ( Transformed :: no ( LogicalPlan :: Projection ( proj) ) ) ;
124+ }
110125
111- let expr = conjunction ( other_exprs) ;
112- if let Some ( expr) = expr {
113- let new_filter = Filter :: try_new ( expr, Arc :: new ( cur_input) ) ?;
114- cur_input = LogicalPlan :: Filter ( new_filter) ;
126+ let mut cur_input = Arc :: unwrap_or_clone ( proj. input ) ;
127+ let mut new_exprs = Vec :: with_capacity ( proj. expr . len ( ) ) ;
128+ for e in proj. expr {
129+ let old_name = e. schema_name ( ) . to_string ( ) ;
130+ let ( plan_after, rewritten) =
131+ rewrite_inner_subqueries ( cur_input, e, config) ?;
132+ cur_input = plan_after;
133+ let new_name = rewritten. schema_name ( ) . to_string ( ) ;
134+ if new_name != old_name {
135+ new_exprs. push ( rewritten. alias ( old_name) ) ;
136+ } else {
137+ new_exprs. push ( rewritten) ;
138+ }
139+ }
140+ let new_proj = Projection :: try_new ( new_exprs, Arc :: new ( cur_input) ) ?;
141+ Ok ( Transformed :: yes ( LogicalPlan :: Projection ( new_proj) ) )
142+ }
143+ other => Ok ( Transformed :: no ( other) ) ,
115144 }
116- Ok ( Transformed :: yes ( cur_input) )
117145 }
118146
119147 fn name ( & self ) -> & str {
@@ -529,6 +557,31 @@ mod tests {
529557 assert_optimized_plan_equal ( plan, expected)
530558 }
531559
/// An `IN (subquery)` expression placed in a projection list is expected to be
/// rewritten into a LeftMark join whose mark column feeds the projection.
#[test]
fn projection_in_subquery_simple() -> Result<()> {
    // Outer relation: literal rows (1) and (2), exposed as column `a`.
    let outer_rows = vec![vec![lit(1_i32)], vec![lit(2_i32)]];
    let outer_plan = LogicalPlanBuilder::values(outer_rows)?
        .project(vec![col("column1").alias("a")])?
        .build()?;

    // Subquery relation: a single literal row (2), exposed as column `ua`.
    let inner_rows = vec![vec![lit(2_i32)]];
    let inner_plan = LogicalPlanBuilder::values(inner_rows)?
        .project(vec![col("column1").alias("ua")])?
        .build()?;

    // Project `a` alongside `a IN (subquery)`, with the membership test named `flag`.
    let membership = in_subquery(col("a"), Arc::new(inner_plan)).alias("flag");
    let plan = LogicalPlanBuilder::from(outer_plan)
        .project(vec![col("a"), membership])?
        .build()?;

    // Expected shape: the top projection reads the join's mark column (aliased
    // back to `flag`) on top of a LeftMark join between the outer relation and
    // the aliased subquery.
    // NOTE(review): confirm the indentation inside this literal matches the
    // plan's indented display output exactly before relying on it.
    let expected = "Projection: a, __correlated_sq_1.mark AS flag [a:Int32;N, flag:Boolean]\n LeftMark Join: Filter: a = __correlated_sq_1.ua [a:Int32;N, mark:Boolean]\n Projection: column1 AS a [a:Int32;N]\n Values: (Int32(1)), (Int32(2)) [column1:Int32;N]\n SubqueryAlias: __correlated_sq_1 [ua:Int32;N]\n Projection: column1 AS ua [ua:Int32;N]\n Values: (Int32(2)) [column1:Int32;N]";

    assert_optimized_plan_equal(plan, expected)
}
584+
532585 /// Test multiple correlated subqueries
533586 /// See subqueries.rs where_in_multiple()
534587 #[ test]
0 commit comments