@@ -285,27 +285,26 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
285285 Ok ( is_evaluate)
286286}
287287
288- fn is_scalar_subquery_cross_join ( join : & Join ) -> bool {
289- fn is_scalar_aggregate_or_derived_relation ( plan : & LogicalPlan ) -> ( bool , bool ) {
290- match plan {
291- LogicalPlan :: SubqueryAlias ( subquery_alias) => {
292- let ( is_scalar_aggregate, _) = is_scalar_aggregate_or_derived_relation (
293- subquery_alias. input . as_ref ( ) ,
294- ) ;
295- ( is_scalar_aggregate, true )
296- }
297- LogicalPlan :: Projection ( projection) => {
298- is_scalar_aggregate_or_derived_relation ( projection. input . as_ref ( ) )
299- }
300- LogicalPlan :: Aggregate ( aggregate) => ( aggregate. group_expr . is_empty ( ) , false ) ,
301- _ => ( false , false ) ,
288+ fn classify_join_input ( plan : & LogicalPlan ) -> ( bool , bool ) {
289+ match plan {
290+ LogicalPlan :: SubqueryAlias ( subquery_alias) => {
291+ let ( is_scalar_aggregate, _) =
292+ classify_join_input ( subquery_alias. input . as_ref ( ) ) ;
293+ ( is_scalar_aggregate, true )
294+ }
295+ LogicalPlan :: Projection ( projection) => {
296+ classify_join_input ( projection. input . as_ref ( ) )
302297 }
298+ LogicalPlan :: Aggregate ( aggregate) => ( aggregate. group_expr . is_empty ( ) , false ) ,
299+ _ => ( false , false ) ,
303300 }
301+ }
304302
303+ fn is_scalar_subquery_cross_join ( join : & Join ) -> bool {
305304 let ( left_scalar_aggregate, left_is_derived_relation) =
306- is_scalar_aggregate_or_derived_relation ( join. left . as_ref ( ) ) ;
305+ classify_join_input ( join. left . as_ref ( ) ) ;
307306 let ( right_scalar_aggregate, right_is_derived_relation) =
308- is_scalar_aggregate_or_derived_relation ( join. right . as_ref ( ) ) ;
307+ classify_join_input ( join. right . as_ref ( ) ) ;
309308 join. on . is_empty ( )
310309 && join. filter . is_none ( )
311310 && ( ( left_scalar_aggregate && right_is_derived_relation)
@@ -456,12 +455,9 @@ fn push_down_all_join(
456455 let keep_mixed_scalar_subquery_filters =
457456 is_inner_join && is_scalar_subquery_cross_join ( & join) ;
458457 for predicate in predicates {
459- let left_only = left_preserved && checker. is_left_only ( & predicate) ;
460- let right_only =
461- !left_only && right_preserved && checker. is_right_only ( & predicate) ;
462- if left_only {
458+ if left_preserved && checker. is_left_only ( & predicate) {
463459 left_push. push ( predicate) ;
464- } else if right_only {
460+ } else if right_preserved && checker . is_right_only ( & predicate ) {
465461 right_push. push ( predicate) ;
466462 } else if is_inner_join
467463 && !keep_mixed_scalar_subquery_filters
@@ -487,63 +483,43 @@ fn push_down_all_join(
487483 let mut on_filter_join_conditions = vec ! [ ] ;
488484 let ( on_left_preserved, on_right_preserved) = on_lr_is_preserved ( join. join_type ) ;
489485
490- for on in on_filter {
491- if on_left_preserved && checker. is_left_only ( & on) {
492- left_push. push ( on)
493- } else if on_right_preserved && checker. is_right_only ( & on) {
494- right_push. push ( on)
495- } else {
496- on_filter_join_conditions. push ( on)
486+ if !on_filter. is_empty ( ) {
487+ for on in on_filter {
488+ if on_left_preserved && checker. is_left_only ( & on) {
489+ left_push. push ( on)
490+ } else if on_right_preserved && checker. is_right_only ( & on) {
491+ right_push. push ( on)
492+ } else {
493+ on_filter_join_conditions. push ( on)
494+ }
497495 }
498496 }
499497
500498 // Extract from OR clause, generate new predicates for both side of join if possible.
501499 // We only track the unpushable predicates above.
502- let extend_or_clauses =
503- |target : & mut Vec < Expr > , filters : & [ Expr ] , schema : & DFSchema , preserved| {
504- if preserved {
505- target. extend ( extract_or_clauses_for_join ( filters, schema) ) ;
506- }
507- } ;
508- extend_or_clauses (
509- & mut left_push,
510- & keep_predicates,
511- left_schema,
512- left_preserved,
513- ) ;
514- extend_or_clauses (
515- & mut left_push,
516- & join_conditions,
517- left_schema,
518- left_preserved,
519- ) ;
520- extend_or_clauses (
521- & mut right_push,
522- & keep_predicates,
523- right_schema,
524- right_preserved,
525- ) ;
526- extend_or_clauses (
527- & mut right_push,
528- & join_conditions,
529- right_schema,
530- right_preserved,
531- ) ;
500+ if left_preserved {
501+ left_push. extend ( extract_or_clauses_for_join ( & keep_predicates, left_schema) ) ;
502+ left_push. extend ( extract_or_clauses_for_join ( & join_conditions, left_schema) ) ;
503+ }
504+ if right_preserved {
505+ right_push. extend ( extract_or_clauses_for_join ( & keep_predicates, right_schema) ) ;
506+ right_push. extend ( extract_or_clauses_for_join ( & join_conditions, right_schema) ) ;
507+ }
532508
533509 // For predicates from join filter, we should check with if a join side is preserved
534510 // in term of join filtering.
535- extend_or_clauses (
536- & mut left_push,
537- & on_filter_join_conditions,
538- left_schema,
539- on_left_preserved ,
540- ) ;
541- extend_or_clauses (
542- & mut right_push,
543- & on_filter_join_conditions,
544- right_schema,
545- on_right_preserved ,
546- ) ;
511+ if on_left_preserved {
512+ left_push. extend ( extract_or_clauses_for_join (
513+ & on_filter_join_conditions,
514+ left_schema,
515+ ) ) ;
516+ }
517+ if on_right_preserved {
518+ right_push. extend ( extract_or_clauses_for_join (
519+ & on_filter_join_conditions,
520+ right_schema,
521+ ) ) ;
522+ }
547523
548524 if let Some ( predicate) = conjunction ( left_push) {
549525 join. left = Arc :: new ( LogicalPlan :: Filter ( Filter :: try_new ( predicate, join. left ) ?) ) ;
@@ -778,11 +754,10 @@ fn infer_join_predicates_impl<
778754 inferred_predicates : & mut InferredPredicates ,
779755) -> Result < ( ) > {
780756 for predicate in input_predicates {
781- let column_refs = predicate. column_refs ( ) ;
782757 let mut join_cols_to_replace = HashMap :: new ( ) ;
783758 let mut saw_non_replaceable_ref = false ;
784759
785- for & col in & column_refs {
760+ for & col in & predicate . column_refs ( ) {
786761 let replacement = join_col_keys. iter ( ) . find_map ( |( l, r) | {
787762 if ENABLE_LEFT_TO_RIGHT && col == * l {
788763 Some ( ( col, * r) )
0 commit comments