@@ -37,8 +37,8 @@ use datafusion_expr::logical_plan::{
3737} ;
3838use datafusion_expr:: utils:: { conjunction, expr_to_columns, split_conjunction_owned} ;
3939use datafusion_expr:: {
40- BinaryExpr , Expr , Filter , LogicalPlan , LogicalPlanBuilder , Operator , exists ,
41- in_subquery, lit, not, not_exists, not_in_subquery,
40+ Aggregate , BinaryExpr , Expr , Filter , LogicalPlan , LogicalPlanBuilder , Operator ,
41+ exists , in_subquery, lit, not, not_exists, not_in_subquery,
4242} ;
4343
4444use log:: debug;
@@ -198,10 +198,9 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
198198 }
199199 }
200200
201- let new_plan = LogicalPlanBuilder :: from ( cur_input)
202- . aggregate ( new_group_exprs, new_aggr_exprs) ?
203- . build ( ) ?;
204- return Ok ( Transformed :: yes ( new_plan) ) ;
201+ let new_agg =
202+ Aggregate :: try_new ( Arc :: new ( cur_input) , new_group_exprs, new_aggr_exprs) ?;
203+ return Ok ( Transformed :: yes ( LogicalPlan :: Aggregate ( new_agg) ) ) ;
205204 }
206205
207206 // Handle Projection nodes with subqueries in expressions
@@ -678,45 +677,6 @@ mod tests {
678677 ) )
679678 }
680679
681- /// Aggregation with CASE WHEN ... IN (subquery) should be decorrelated under the Aggregate
682- #[ test]
683- fn aggregate_case_in_subquery ( ) -> Result < ( ) > {
684- let table_scan = test_table_scan_with_name ( "distinct_source" ) ?;
685- use datafusion_expr:: expr_fn:: when;
686- use datafusion_functions_aggregate:: expr_fn:: max as agg_max;
687-
688- let agg_b: Expr = agg_max ( col ( "distinct_source.b" ) ) ;
689- let subq = LogicalPlanBuilder :: from ( table_scan. clone ( ) )
690- . aggregate ( Vec :: < Expr > :: new ( ) , vec ! [ agg_b] ) ?
691- . project ( vec ! [ col( "max(distinct_source.b)" ) ] ) ?
692- . build ( ) ?;
693-
694- let case_expr = when (
695- in_subquery ( col ( "distinct_source.b" ) , Arc :: new ( subq) ) ,
696- lit ( 1 ) ,
697- )
698- . otherwise ( lit ( 0 ) ) ?;
699-
700- let plan = LogicalPlanBuilder :: from ( table_scan)
701- . aggregate (
702- vec ! [ col( "distinct_source.a" ) . alias( "primary_key" ) ] ,
703- vec ! [
704- agg_max( case_expr) . alias( "is_in_most_recent_task" ) ,
705- agg_max( col( "distinct_source.c" ) ) . alias( "max_timestamp" ) ,
706- ] ,
707- ) ?
708- . build ( ) ?;
709-
710- use crate :: { OptimizerContext , OptimizerRule } ;
711- let optimized = DecorrelatePredicateSubquery :: new ( )
712- . rewrite ( plan, & OptimizerContext :: new ( ) ) ?
713- . data ;
714- let lp = optimized. display_indent ( ) . to_string ( ) ;
715- assert ! ( lp. contains( "Aggregate:" ) ) ;
716- assert ! ( lp. contains( "Left" ) ) ;
717- Ok ( ( ) )
718- }
719-
720680 /// Test for several IN subquery expressions
721681 #[ test]
722682 fn in_subquery_multiple ( ) -> Result < ( ) > {
@@ -834,9 +794,10 @@ mod tests {
834794 LeftMark Join: Filter: a = __correlated_sq_1.ua [a:Int32;N, mark:Boolean]
835795 Projection: column1 AS a [a:Int32;N]
836796 Values: (Int32(1)), (Int32(2)) [column1:Int32;N]
837- SubqueryAlias: __correlated_sq_1 [ua:Int32;N]
838- Projection: column1 AS ua [ua:Int32;N]
839- Values: (Int32(2)) [column1:Int32;N]
797+ Projection: __correlated_sq_1.ua [ua:Int32;N]
798+ SubqueryAlias: __correlated_sq_1 [ua:Int32;N]
799+ Projection: column1 AS ua [ua:Int32;N]
800+ Values: (Int32(2)) [column1:Int32;N]
840801 "
841802 )
842803 }
@@ -1924,14 +1885,13 @@ mod tests {
19241885 plan,
19251886 @r"
19261887 Projection: customer.c_custkey [c_custkey:Int64]
1927- Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1928- Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1929- LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1930- TableScan: customer [c_custkey:Int64, c_name:Utf8]
1931- SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1932- Projection: orders.o_custkey [o_custkey:Int64]
1933- Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1934- TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1888+ Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1889+ LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1890+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
1891+ SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1892+ Projection: orders.o_custkey [o_custkey:Int64]
1893+ Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1894+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
19351895 "
19361896 )
19371897 }
0 commit comments