@@ -285,81 +285,30 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
285285 Ok ( is_evaluate)
286286}
287287
288- #[ derive( Clone , Copy ) ]
289- struct JoinInputShape < ' a > {
290- base_plan : & ' a LogicalPlan ,
291- is_derived_relation : bool ,
292- }
293-
294- fn classify_join_input ( plan : & LogicalPlan ) -> JoinInputShape < ' _ > {
288+ fn classify_join_input ( plan : & LogicalPlan ) -> ( bool , bool ) {
295289 match plan {
296290 LogicalPlan :: SubqueryAlias ( subquery_alias) => {
297- let JoinInputShape { base_plan , .. } =
291+ let ( is_scalar_aggregate , _ ) =
298292 classify_join_input ( subquery_alias. input . as_ref ( ) ) ;
299- JoinInputShape {
300- base_plan,
301- is_derived_relation : true ,
302- }
293+ ( is_scalar_aggregate, true )
303294 }
304295 LogicalPlan :: Projection ( projection) => {
305- let shape = classify_join_input ( projection. input . as_ref ( ) ) ;
306- JoinInputShape {
307- is_derived_relation : shape. is_derived_relation ,
308- ..shape
309- }
296+ classify_join_input ( projection. input . as_ref ( ) )
310297 }
311- _ => JoinInputShape {
312- base_plan : plan,
313- is_derived_relation : false ,
314- } ,
298+ LogicalPlan :: Aggregate ( aggregate) => ( aggregate. group_expr . is_empty ( ) , false ) ,
299+ _ => ( false , false ) ,
315300 }
316301}
317302
318- fn is_scalar_aggregate_subquery ( shape : JoinInputShape < ' _ > ) -> bool {
319- matches ! (
320- shape. base_plan,
321- LogicalPlan :: Aggregate ( aggregate) if aggregate. group_expr. is_empty( )
322- )
323- }
324-
325303fn is_scalar_subquery_cross_join ( join : & Join ) -> bool {
326- let left_shape = classify_join_input ( join. left . as_ref ( ) ) ;
327- let right_shape = classify_join_input ( join. right . as_ref ( ) ) ;
304+ let ( left_scalar_aggregate, left_is_derived_relation) =
305+ classify_join_input ( join. left . as_ref ( ) ) ;
306+ let ( right_scalar_aggregate, right_is_derived_relation) =
307+ classify_join_input ( join. right . as_ref ( ) ) ;
328308 join. on . is_empty ( )
329309 && join. filter . is_none ( )
330- && ( ( is_scalar_aggregate_subquery ( left_shape) && right_shape. is_derived_relation )
331- || ( is_scalar_aggregate_subquery ( right_shape)
332- && left_shape. is_derived_relation ) )
333- }
334-
335- // Keep post-join filters above certain scalar-subquery cross joins to preserve
336- // behavior for the window-over-scalar-subquery regression shape.
337- fn should_keep_filter_above_scalar_subquery_cross_join (
338- mut checker : ColumnChecker < ' _ > ,
339- predicate : & Expr ,
340- ) -> bool {
341- !checker. is_left_only ( predicate) && !checker. is_right_only ( predicate)
342- }
343-
344- enum PredicateDestination {
345- Left ,
346- Right ,
347- Keep ,
348- }
349-
350- fn classify_predicate_destination (
351- checker : & mut ColumnChecker < ' _ > ,
352- predicate : & Expr ,
353- allow_left : bool ,
354- allow_right : bool ,
355- ) -> PredicateDestination {
356- if allow_left && checker. is_left_only ( predicate) {
357- PredicateDestination :: Left
358- } else if allow_right && checker. is_right_only ( predicate) {
359- PredicateDestination :: Right
360- } else {
361- PredicateDestination :: Keep
362- }
310+ && ( ( left_scalar_aggregate && right_is_derived_relation)
311+ || ( right_scalar_aggregate && left_is_derived_relation) )
363312}
364313
365314/// examine OR clause to see if any useful clauses can be extracted and push down.
@@ -506,41 +455,28 @@ fn push_down_all_join(
506455 let keep_mixed_scalar_subquery_filters =
507456 is_inner_join && is_scalar_subquery_cross_join ( & join) ;
508457 for predicate in predicates {
509- match classify_predicate_destination (
510- & mut checker,
511- & predicate,
512- left_preserved,
513- right_preserved,
514- ) {
515- PredicateDestination :: Left => left_push. push ( predicate) ,
516- PredicateDestination :: Right => right_push. push ( predicate) ,
517- PredicateDestination :: Keep => {
518- let should_keep_above_join = keep_mixed_scalar_subquery_filters
519- && should_keep_filter_above_scalar_subquery_cross_join (
520- ColumnChecker :: new ( left_schema, right_schema) ,
521- & predicate,
522- ) ;
523-
524- if is_inner_join
525- && !should_keep_above_join
526- && can_evaluate_as_join_condition ( & predicate) ?
527- {
528- // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
529- // and convert to the join on condition
530- join_conditions. push ( predicate) ;
531- } else {
532- keep_predicates. push ( predicate) ;
533- }
534- }
458+ if left_preserved && checker. is_left_only ( & predicate) {
459+ left_push. push ( predicate) ;
460+ } else if right_preserved && checker. is_right_only ( & predicate) {
461+ right_push. push ( predicate) ;
462+ } else if is_inner_join
463+ && !keep_mixed_scalar_subquery_filters
464+ && can_evaluate_as_join_condition ( & predicate) ?
465+ {
466+ // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
467+ // and convert to the join on condition
468+ join_conditions. push ( predicate) ;
469+ } else {
470+ keep_predicates. push ( predicate) ;
535471 }
536472 }
537473
538474 // Push predicates inferred from the join expression
539475 for predicate in inferred_join_predicates {
540- match classify_predicate_destination ( & mut checker , & predicate, true , true ) {
541- PredicateDestination :: Left => left_push. push ( predicate) ,
542- PredicateDestination :: Right => right_push . push ( predicate) ,
543- PredicateDestination :: Keep => { }
476+ if checker . is_left_only ( & predicate) {
477+ left_push. push ( predicate) ;
478+ } else if checker . is_right_only ( & predicate) {
479+ right_push . push ( predicate ) ;
544480 }
545481 }
546482
@@ -549,15 +485,12 @@ fn push_down_all_join(
549485
550486 if !on_filter. is_empty ( ) {
551487 for on in on_filter {
552- match classify_predicate_destination (
553- & mut checker,
554- & on,
555- on_left_preserved,
556- on_right_preserved,
557- ) {
558- PredicateDestination :: Left => left_push. push ( on) ,
559- PredicateDestination :: Right => right_push. push ( on) ,
560- PredicateDestination :: Keep => on_filter_join_conditions. push ( on) ,
488+ if on_left_preserved && checker. is_left_only ( & on) {
489+ left_push. push ( on)
490+ } else if on_right_preserved && checker. is_right_only ( & on) {
491+ right_push. push ( on)
492+ } else {
493+ on_filter_join_conditions. push ( on)
561494 }
562495 }
563496 }
@@ -821,24 +754,26 @@ fn infer_join_predicates_impl<
821754 inferred_predicates : & mut InferredPredicates ,
822755) -> Result < ( ) > {
823756 for predicate in input_predicates {
824- let column_refs = predicate . column_refs ( ) ;
757+ let mut join_cols_to_replace = HashMap :: new ( ) ;
825758 let mut saw_non_replaceable_ref = false ;
826- let join_cols_to_replace = column_refs
827- . iter ( )
828- . filter_map ( |& col| {
829- let replacement = join_col_keys. iter ( ) . find_map ( |( l, r) | {
830- if ENABLE_LEFT_TO_RIGHT && col == * l {
831- Some ( ( col, * r) )
832- } else if ENABLE_RIGHT_TO_LEFT && col == * r {
833- Some ( ( col, * l) )
834- } else {
835- None
836- }
837- } ) ;
838- saw_non_replaceable_ref |= replacement. is_none ( ) ;
839- replacement
840- } )
841- . collect :: < HashMap < _ , _ > > ( ) ;
759+
760+ for & col in & predicate. column_refs ( ) {
761+ let replacement = join_col_keys. iter ( ) . find_map ( |( l, r) | {
762+ if ENABLE_LEFT_TO_RIGHT && col == * l {
763+ Some ( ( col, * r) )
764+ } else if ENABLE_RIGHT_TO_LEFT && col == * r {
765+ Some ( ( col, * l) )
766+ } else {
767+ None
768+ }
769+ } ) ;
770+
771+ if let Some ( ( source, target) ) = replacement {
772+ join_cols_to_replace. insert ( source, target) ;
773+ } else {
774+ saw_non_replaceable_ref = true ;
775+ }
776+ }
842777
843778 if join_cols_to_replace. is_empty ( )
844779 || ( !inferred_predicates. is_inner_join && saw_non_replaceable_ref)
0 commit comments