@@ -285,30 +285,53 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
285285 Ok ( is_evaluate)
286286}
287287
288- fn classify_join_input ( plan : & LogicalPlan ) -> ( bool , bool ) {
288+ fn strip_plan_wrappers ( plan : & LogicalPlan ) -> ( & LogicalPlan , bool ) { // Peels SubqueryAlias and Projection nodes off `plan`; returns the first node that is neither, plus a flag that is true when at least one SubqueryAlias was peeled.
289289 match plan {
290290 LogicalPlan :: SubqueryAlias ( subquery_alias) => {
291- let ( is_scalar_aggregate, _) =
292- classify_join_input ( subquery_alias. input . as_ref ( ) ) ;
293- ( is_scalar_aggregate, true )
291+ let ( plan, _) = strip_plan_wrappers ( subquery_alias. input . as_ref ( ) ) ; // the input's own flag is discarded: an alias always yields `true` below
292+ ( plan, true )
294293 }
295294 LogicalPlan :: Projection ( projection) => {
296- classify_join_input ( projection. input . as_ref ( ) )
295+ let ( plan, is_derived_relation) =
296+ strip_plan_wrappers ( projection. input . as_ref ( ) ) ;
297+ ( plan, is_derived_relation) // a projection is transparent: forward exactly what its input reported
297298 }
298- LogicalPlan :: Aggregate ( aggregate) => ( aggregate. group_expr . is_empty ( ) , false ) ,
299- _ => ( false , false ) ,
299+ _ => ( plan, false ) , // any other node ends the descent and is returned as-is
300300 }
301301}
302302
303+ fn is_scalar_aggregate_subquery ( plan : & LogicalPlan ) -> bool { // True when `plan`, once its SubqueryAlias/Projection wrappers are stripped, is an Aggregate with an empty `group_expr` list.
304+ matches ! (
305+ strip_plan_wrappers( plan) . 0 ,
306+ LogicalPlan :: Aggregate ( aggregate) if aggregate. group_expr. is_empty( )
307+ )
308+ }
309+
310+ fn is_derived_relation ( plan : & LogicalPlan ) -> bool { // True when `strip_plan_wrappers` reports that a SubqueryAlias wraps `plan` (directly or beneath Projection nodes).
311+ strip_plan_wrappers ( plan) . 1
312+ }
313+
303314fn is_scalar_subquery_cross_join ( join : & Join ) -> bool {
304- let ( left_scalar_aggregate, left_is_derived_relation) =
305- classify_join_input ( join. left . as_ref ( ) ) ;
306- let ( right_scalar_aggregate, right_is_derived_relation) =
307- classify_join_input ( join. right . as_ref ( ) ) ;
308315 join. on . is_empty ( )
309316 && join. filter . is_none ( )
310- && ( ( left_scalar_aggregate && right_is_derived_relation)
311- || ( right_scalar_aggregate && left_is_derived_relation) )
317+ && ( ( is_scalar_aggregate_subquery ( join. left . as_ref ( ) ) // case 1: left is the ungrouped aggregate, right is the aliased relation
318+ && is_derived_relation ( join. right . as_ref ( ) ) )
319+ || ( is_scalar_aggregate_subquery ( join. right . as_ref ( ) ) // case 2: the mirror image, so the input order does not matter
320+ && is_derived_relation ( join. left . as_ref ( ) ) ) )
321+ }
322+
323+ // Decides whether a post-join filter predicate is kept above the join: true only when `join` is a scalar-subquery cross join (see `is_scalar_subquery_cross_join`) and `predicate` is neither left-only nor right-only.
324+ // This preserves behavior for the window-over-scalar-subquery regression shape.
325+ fn should_keep_filter_above_scalar_subquery_cross_join (
326+ join : & Join ,
327+ predicate : & Expr ,
328+ ) -> bool {
329+ if !is_scalar_subquery_cross_join ( join) { // every other join shape is never held back by this check
330+ return false ;
331+ }
332+
333+ let mut checker = ColumnChecker :: new ( join. left . schema ( ) , join. right . schema ( ) ) ;
334+ !checker. is_left_only ( predicate) && !checker. is_right_only ( predicate) // keep only predicates that ColumnChecker classifies as neither left-only nor right-only
312335}
313336
314337/// examine OR clause to see if any useful clauses can be extracted and push down.
@@ -452,15 +475,13 @@ fn push_down_all_join(
452475 let mut keep_predicates = vec ! [ ] ;
453476 let mut join_conditions = vec ! [ ] ;
454477 let mut checker = ColumnChecker :: new ( left_schema, right_schema) ;
455- let keep_mixed_scalar_subquery_filters =
456- is_inner_join && is_scalar_subquery_cross_join ( & join) ;
457478 for predicate in predicates {
458479 if left_preserved && checker. is_left_only ( & predicate) {
459480 left_push. push ( predicate) ;
460481 } else if right_preserved && checker. is_right_only ( & predicate) {
461482 right_push. push ( predicate) ;
462483 } else if is_inner_join
463- && !keep_mixed_scalar_subquery_filters
484+ && !should_keep_filter_above_scalar_subquery_cross_join ( & join , & predicate )
464485 && can_evaluate_as_join_condition ( & predicate) ?
465486 {
466487 // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
@@ -754,33 +775,36 @@ fn infer_join_predicates_impl<
754775 inferred_predicates : & mut InferredPredicates ,
755776) -> Result < ( ) > {
756777 for predicate in input_predicates {
757- let mut join_cols_to_replace = HashMap :: new ( ) ;
758- let mut saw_non_replaceable_ref = false ;
759-
760- for & col in & predicate. column_refs ( ) {
761- let replacement = join_col_keys. iter ( ) . find_map ( |( l, r) | {
762- if ENABLE_LEFT_TO_RIGHT && col == * l {
763- Some ( ( col, * r) )
764- } else if ENABLE_RIGHT_TO_LEFT && col == * r {
765- Some ( ( col, * l) )
766- } else {
767- None
768- }
769- } ) ;
778+ let column_refs = predicate. column_refs ( ) ;
779+ let join_col_replacements: Vec < _ > = column_refs
780+ . iter ( )
781+ . filter_map ( |& col| {
782+ join_col_keys. iter ( ) . find_map ( |( l, r) | {
783+ if ENABLE_LEFT_TO_RIGHT && col == * l {
784+ Some ( ( col, * r) )
785+ } else if ENABLE_RIGHT_TO_LEFT && col == * r {
786+ Some ( ( col, * l) )
787+ } else {
788+ None
789+ }
790+ } )
791+ } )
792+ . collect ( ) ;
770793
771- if let Some ( ( source, target) ) = replacement {
772- join_cols_to_replace. insert ( source, target) ;
773- } else {
774- saw_non_replaceable_ref = true ;
775- }
794+ if join_col_replacements. is_empty ( ) {
795+ continue ;
776796 }
777797
778- if join_cols_to_replace. is_empty ( )
779- || ( !inferred_predicates. is_inner_join && saw_non_replaceable_ref)
798+ // For non-inner joins, predicates that reference any non-replaceable
799+ // columns cannot be inferred on the other side. Skip the null-restriction
800+ // helper entirely in that common mixed-reference case.
801+ if !inferred_predicates. is_inner_join
802+ && join_col_replacements. len ( ) != column_refs. len ( )
780803 {
781804 continue ;
782805 }
783806
807+ let join_cols_to_replace = join_col_replacements. into_iter ( ) . collect ( ) ;
784808 inferred_predicates
785809 . try_build_predicate ( predicate. clone ( ) , & join_cols_to_replace) ?;
786810 }
@@ -1529,53 +1553,6 @@ mod tests {
15291553
15301554 use super :: * ;
15311555
1532- fn scalar_subquery_right_plan ( ) -> Result < LogicalPlan > {
1533- LogicalPlanBuilder :: from ( test_table_scan_with_name ( "test1" ) ?)
1534- . project ( vec ! [ col( "a" ) . alias( "acctbal" ) ] ) ?
1535- . aggregate (
1536- Vec :: < Expr > :: new ( ) ,
1537- vec ! [ avg( col( "acctbal" ) ) . alias( "avg_acctbal" ) ] ,
1538- ) ?
1539- . alias ( "__scalar_sq_1" ) ?
1540- . build ( )
1541- }
1542-
1543- fn row_number_window_expr ( ) -> Expr {
1544- Expr :: from ( WindowFunction :: new (
1545- WindowFunctionDefinition :: WindowUDF (
1546- datafusion_functions_window:: row_number:: row_number_udwf ( ) ,
1547- ) ,
1548- vec ! [ ] ,
1549- ) )
1550- . partition_by ( vec ! [ col( "s.nation" ) ] )
1551- . order_by ( vec ! [ col( "s.acctbal" ) . sort( false , true ) ] )
1552- . build ( )
1553- . unwrap ( )
1554- }
1555-
1556- fn window_over_scalar_subquery_cross_join_plan (
1557- with_project_wrapper : bool ,
1558- ) -> Result < LogicalPlan > {
1559- let left = {
1560- let builder = LogicalPlanBuilder :: from ( test_table_scan ( ) ?)
1561- . project ( vec ! [ col( "a" ) . alias( "nation" ) , col( "b" ) . alias( "acctbal" ) ] ) ?
1562- . alias ( "s" ) ?;
1563- let builder = if with_project_wrapper {
1564- builder. project ( vec ! [ col( "s.nation" ) , col( "s.acctbal" ) ] ) ?
1565- } else {
1566- builder
1567- } ;
1568- builder. build ( ) ?
1569- } ;
1570-
1571- LogicalPlanBuilder :: from ( left)
1572- . cross_join ( scalar_subquery_right_plan ( ) ?) ?
1573- . filter ( col ( "s.acctbal" ) . gt ( col ( "__scalar_sq_1.avg_acctbal" ) ) ) ?
1574- . project ( vec ! [ col( "s.nation" ) , col( "s.acctbal" ) ] ) ?
1575- . window ( vec ! [ row_number_window_expr( ) ] ) ?
1576- . build ( )
1577- }
1578-
15791556 fn observe ( _plan : & LogicalPlan , _rule : & dyn OptimizerRule ) { }
15801557
15811558 macro_rules! assert_optimized_plan_equal {
@@ -2497,7 +2474,36 @@ mod tests {
24972474
24982475 #[ test]
24992476 fn window_over_scalar_subquery_cross_join_keeps_filter_above_join ( ) -> Result < ( ) > {
2500- let plan = window_over_scalar_subquery_cross_join_plan ( false ) ?;
2477+ let left = LogicalPlanBuilder :: from ( test_table_scan ( ) ?)
2478+ . project ( vec ! [ col( "a" ) . alias( "nation" ) , col( "b" ) . alias( "acctbal" ) ] ) ?
2479+ . alias ( "s" ) ?
2480+ . build ( ) ?;
2481+ let right = LogicalPlanBuilder :: from ( test_table_scan_with_name ( "test1" ) ?)
2482+ . project ( vec ! [ col( "a" ) . alias( "acctbal" ) ] ) ?
2483+ . aggregate (
2484+ Vec :: < Expr > :: new ( ) ,
2485+ vec ! [ avg( col( "acctbal" ) ) . alias( "avg_acctbal" ) ] ,
2486+ ) ?
2487+ . alias ( "__scalar_sq_1" ) ?
2488+ . build ( ) ?;
2489+
2490+ let window = Expr :: from ( WindowFunction :: new (
2491+ WindowFunctionDefinition :: WindowUDF (
2492+ datafusion_functions_window:: row_number:: row_number_udwf ( ) ,
2493+ ) ,
2494+ vec ! [ ] ,
2495+ ) )
2496+ . partition_by ( vec ! [ col( "s.nation" ) ] )
2497+ . order_by ( vec ! [ col( "s.acctbal" ) . sort( false , true ) ] )
2498+ . build ( )
2499+ . unwrap ( ) ;
2500+
2501+ let plan = LogicalPlanBuilder :: from ( left)
2502+ . cross_join ( right) ?
2503+ . filter ( col ( "s.acctbal" ) . gt ( col ( "__scalar_sq_1.avg_acctbal" ) ) ) ?
2504+ . project ( vec ! [ col( "s.nation" ) , col( "s.acctbal" ) ] ) ?
2505+ . window ( vec ! [ window] ) ?
2506+ . build ( ) ?;
25012507
25022508 assert_optimized_plan_equal ! (
25032509 plan,
@@ -2520,7 +2526,37 @@ mod tests {
25202526 #[ test]
25212527 fn window_over_scalar_subquery_cross_join_with_project_wrapper_keeps_filter_above_join ( )
25222528 -> Result < ( ) > {
2523- let plan = window_over_scalar_subquery_cross_join_plan ( true ) ?;
2529+ let left = LogicalPlanBuilder :: from ( test_table_scan ( ) ?)
2530+ . project ( vec ! [ col( "a" ) . alias( "nation" ) , col( "b" ) . alias( "acctbal" ) ] ) ?
2531+ . alias ( "s" ) ?
2532+ . project ( vec ! [ col( "s.nation" ) , col( "s.acctbal" ) ] ) ?
2533+ . build ( ) ?;
2534+ let right = LogicalPlanBuilder :: from ( test_table_scan_with_name ( "test1" ) ?)
2535+ . project ( vec ! [ col( "a" ) . alias( "acctbal" ) ] ) ?
2536+ . aggregate (
2537+ Vec :: < Expr > :: new ( ) ,
2538+ vec ! [ avg( col( "acctbal" ) ) . alias( "avg_acctbal" ) ] ,
2539+ ) ?
2540+ . alias ( "__scalar_sq_1" ) ?
2541+ . build ( ) ?;
2542+
2543+ let window = Expr :: from ( WindowFunction :: new (
2544+ WindowFunctionDefinition :: WindowUDF (
2545+ datafusion_functions_window:: row_number:: row_number_udwf ( ) ,
2546+ ) ,
2547+ vec ! [ ] ,
2548+ ) )
2549+ . partition_by ( vec ! [ col( "s.nation" ) ] )
2550+ . order_by ( vec ! [ col( "s.acctbal" ) . sort( false , true ) ] )
2551+ . build ( )
2552+ . unwrap ( ) ;
2553+
2554+ let plan = LogicalPlanBuilder :: from ( left)
2555+ . cross_join ( right) ?
2556+ . filter ( col ( "s.acctbal" ) . gt ( col ( "__scalar_sq_1.avg_acctbal" ) ) ) ?
2557+ . project ( vec ! [ col( "s.nation" ) , col( "s.acctbal" ) ] ) ?
2558+ . window ( vec ! [ window] ) ?
2559+ . build ( ) ?;
25242560
25252561 assert_optimized_plan_equal ! (
25262562 plan,
0 commit comments