@@ -23,7 +23,7 @@ use datafusion_expr::{Expr, Filter, Operator};
2323
2424use crate :: optimizer:: ApplyOrder ;
2525use datafusion_common:: tree_node:: Transformed ;
26- use datafusion_expr:: expr:: { BinaryExpr , Cast , TryCast } ;
26+ use datafusion_expr:: expr:: { BinaryExpr , Cast , InList , TryCast } ;
2727use std:: sync:: Arc ;
2828
2929///
@@ -298,6 +298,23 @@ fn extract_non_nullable_columns(
298298 right_schema,
299299 false ,
300300 ) ,
301+ // IN list and BETWEEN are null-rejecting on the input expression:
302+ // if the input column is NULL, the result is NULL (filtered out),
303+ // regardless of whether the list/range contains NULLs.
304+ Expr :: InList ( InList { expr, .. } ) => extract_non_nullable_columns (
305+ expr,
306+ non_nullable_cols,
307+ left_schema,
308+ right_schema,
309+ false ,
310+ ) ,
311+ Expr :: Between ( between) => extract_non_nullable_columns (
312+ & between. expr ,
313+ non_nullable_cols,
314+ left_schema,
315+ right_schema,
316+ false ,
317+ ) ,
301318 _ => { }
302319 }
303320}
@@ -309,6 +326,7 @@ mod tests {
309326 use crate :: assert_optimized_plan_eq_snapshot;
310327 use crate :: test:: * ;
311328 use arrow:: datatypes:: DataType ;
329+ use datafusion_common:: ScalarValue ;
312330 use datafusion_expr:: {
313331 Operator :: { And , Or } ,
314332 binary_expr, cast, col, lit,
@@ -436,6 +454,221 @@ mod tests {
436454 " )
437455 }
438456
457+ #[ test]
458+ fn eliminate_left_with_in_list ( ) -> Result < ( ) > {
459+ let t1 = test_table_scan_with_name ( "t1" ) ?;
460+ let t2 = test_table_scan_with_name ( "t2" ) ?;
461+
462+ // Null-rejecting filter on the nullable side: when t2.b is NULL,
463+ // `t2.b IN (1, 2, 3)` is NULL and the row is dropped, so Left Join should become Inner Join.
464+ let plan = LogicalPlanBuilder :: from ( t1)
465+ . join (
466+ t2,
467+ JoinType :: Left ,
468+ ( vec ! [ Column :: from_name( "a" ) ] , vec ! [ Column :: from_name( "a" ) ] ) ,
469+ None ,
470+ ) ?
471+ . filter ( col ( "t2.b" ) . in_list ( vec ! [ lit( 1u32 ) , lit( 2u32 ) , lit( 3u32 ) ] , false ) ) ?
472+ . build ( ) ?;
473+
474+ assert_optimized_plan_equal ! ( plan, @r"
475+ Filter: t2.b IN ([UInt32(1), UInt32(2), UInt32(3)])
476+ Inner Join: t1.a = t2.a
477+ TableScan: t1
478+ TableScan: t2
479+ " )
480+ }
481+
482+ #[ test]
483+ fn eliminate_left_with_in_list_containing_null ( ) -> Result < ( ) > {
484+ let t1 = test_table_scan_with_name ( "t1" ) ?;
485+ let t2 = test_table_scan_with_name ( "t2" ) ?;
486+
487+ // A NULL literal in the list does not make the predicate null-tolerant:
488+ // if t2.b is NULL, `NULL IN (1, NULL)` is still NULL and the row is filtered out.
489+ let plan = LogicalPlanBuilder :: from ( t1)
490+ . join (
491+ t2,
492+ JoinType :: Left ,
493+ ( vec ! [ Column :: from_name( "a" ) ] , vec ! [ Column :: from_name( "a" ) ] ) ,
494+ None ,
495+ ) ?
496+ . filter (
497+ col ( "t2.b" )
498+ . in_list ( vec ! [ lit( 1u32 ) , lit( ScalarValue :: UInt32 ( None ) ) ] , false ) ,
499+ ) ?
500+ . build ( ) ?;
501+
502+ assert_optimized_plan_equal ! ( plan, @r"
503+ Filter: t2.b IN ([UInt32(1), UInt32(NULL)])
504+ Inner Join: t1.a = t2.a
505+ TableScan: t1
506+ TableScan: t2
507+ " )
508+ }
509+
510+ #[ test]
511+ fn eliminate_left_with_not_in_list ( ) -> Result < ( ) > {
512+ let t1 = test_table_scan_with_name ( "t1" ) ?;
513+ let t2 = test_table_scan_with_name ( "t2" ) ?;
514+
515+ // NOT IN is null-rejecting as well: if t2.b is NULL, `NULL NOT IN (1, 2)`
516+ // evaluates to NULL and the row is filtered out, so the join should become Inner.
517+ let plan = LogicalPlanBuilder :: from ( t1)
518+ . join (
519+ t2,
520+ JoinType :: Left ,
521+ ( vec ! [ Column :: from_name( "a" ) ] , vec ! [ Column :: from_name( "a" ) ] ) ,
522+ None ,
523+ ) ?
524+ . filter ( col ( "t2.b" ) . in_list ( vec ! [ lit( 1u32 ) , lit( 2u32 ) ] , true ) ) ?
525+ . build ( ) ?;
526+
527+ assert_optimized_plan_equal ! ( plan, @r"
528+ Filter: t2.b NOT IN ([UInt32(1), UInt32(2)])
529+ Inner Join: t1.a = t2.a
530+ TableScan: t1
531+ TableScan: t2
532+ " )
533+ }
534+
535+ #[ test]
536+ fn eliminate_left_with_between ( ) -> Result < ( ) > {
537+ let t1 = test_table_scan_with_name ( "t1" ) ?;
538+ let t2 = test_table_scan_with_name ( "t2" ) ?;
539+
540+ // BETWEEN is null-rejecting on its input: if t2.b is NULL,
541+ // `NULL BETWEEN 1 AND 10` evaluates to NULL and the row is filtered out.
542+ let plan = LogicalPlanBuilder :: from ( t1)
543+ . join (
544+ t2,
545+ JoinType :: Left ,
546+ ( vec ! [ Column :: from_name( "a" ) ] , vec ! [ Column :: from_name( "a" ) ] ) ,
547+ None ,
548+ ) ?
549+ . filter ( col ( "t2.b" ) . between ( lit ( 1u32 ) , lit ( 10u32 ) ) ) ?
550+ . build ( ) ?;
551+
552+ assert_optimized_plan_equal ! ( plan, @r"
553+ Filter: t2.b BETWEEN UInt32(1) AND UInt32(10)
554+ Inner Join: t1.a = t2.a
555+ TableScan: t1
556+ TableScan: t2
557+ " )
558+ }
559+
560+ #[ test]
561+ fn eliminate_right_with_between ( ) -> Result < ( ) > {
562+ let t1 = test_table_scan_with_name ( "t1" ) ?;
563+ let t2 = test_table_scan_with_name ( "t2" ) ?;
564+
565+ // Right join: t1 is the null-padded side, so a BETWEEN filter on t1.b should convert the join to Inner
566+ let plan = LogicalPlanBuilder :: from ( t1)
567+ . join (
568+ t2,
569+ JoinType :: Right ,
570+ ( vec ! [ Column :: from_name( "a" ) ] , vec ! [ Column :: from_name( "a" ) ] ) ,
571+ None ,
572+ ) ?
573+ . filter ( col ( "t1.b" ) . between ( lit ( 1u32 ) , lit ( 10u32 ) ) ) ?
574+ . build ( ) ?;
575+
576+ assert_optimized_plan_equal ! ( plan, @r"
577+ Filter: t1.b BETWEEN UInt32(1) AND UInt32(10)
578+ Inner Join: t1.a = t2.a
579+ TableScan: t1
580+ TableScan: t2
581+ " )
582+ }
583+
584+ #[ test]
585+ fn eliminate_full_with_between ( ) -> Result < ( ) > {
586+ let t1 = test_table_scan_with_name ( "t1" ) ?;
587+ let t2 = test_table_scan_with_name ( "t2" ) ?;
588+
589+ // Full join: null-rejecting BETWEEN filters on both t1.b and t2.b should convert the join to Inner
590+ let plan = LogicalPlanBuilder :: from ( t1)
591+ . join (
592+ t2,
593+ JoinType :: Full ,
594+ ( vec ! [ Column :: from_name( "a" ) ] , vec ! [ Column :: from_name( "a" ) ] ) ,
595+ None ,
596+ ) ?
597+ . filter ( binary_expr (
598+ col ( "t1.b" ) . between ( lit ( 1u32 ) , lit ( 10u32 ) ) ,
599+ And ,
600+ col ( "t2.b" ) . between ( lit ( 5u32 ) , lit ( 20u32 ) ) ,
601+ ) ) ?
602+ . build ( ) ?;
603+
604+ assert_optimized_plan_equal ! ( plan, @r"
605+ Filter: t1.b BETWEEN UInt32(1) AND UInt32(10) AND t2.b BETWEEN UInt32(5) AND UInt32(20)
606+ Inner Join: t1.a = t2.a
607+ TableScan: t1
608+ TableScan: t2
609+ " )
610+ }
611+
612+ #[ test]
613+ fn eliminate_full_with_in_list ( ) -> Result < ( ) > {
614+ let t1 = test_table_scan_with_name ( "t1" ) ?;
615+ let t2 = test_table_scan_with_name ( "t2" ) ?;
616+
617+ // Full join: null-rejecting IN filters on both t1.b and t2.b should convert the join to Inner
618+ let plan = LogicalPlanBuilder :: from ( t1)
619+ . join (
620+ t2,
621+ JoinType :: Full ,
622+ ( vec ! [ Column :: from_name( "a" ) ] , vec ! [ Column :: from_name( "a" ) ] ) ,
623+ None ,
624+ ) ?
625+ . filter ( binary_expr (
626+ col ( "t1.b" ) . in_list ( vec ! [ lit( 1u32 ) , lit( 2u32 ) ] , false ) ,
627+ And ,
628+ col ( "t2.b" ) . in_list ( vec ! [ lit( 3u32 ) , lit( 4u32 ) ] , false ) ,
629+ ) ) ?
630+ . build ( ) ?;
631+
632+ assert_optimized_plan_equal ! ( plan, @r"
633+ Filter: t1.b IN ([UInt32(1), UInt32(2)]) AND t2.b IN ([UInt32(3), UInt32(4)])
634+ Inner Join: t1.a = t2.a
635+ TableScan: t1
636+ TableScan: t2
637+ " )
638+ }
639+
640+ #[ test]
641+ fn no_eliminate_left_with_in_list_or_is_null ( ) -> Result < ( ) > {
642+ let t1 = test_table_scan_with_name ( "t1" ) ?;
643+ let t2 = test_table_scan_with_name ( "t2" ) ?;
644+
645+ // WHERE (t2.b IN (1, 2)) OR (t2.b IS NULL)
646+ // The IS NULL branch makes the whole predicate null-tolerant:
647+ // when t2.b is NULL, `t2.b IS NULL` is true, so the OR is true and the row survives.
648+ // Null-padded rows are kept by the filter, so the outer join must be preserved.
649+ let plan = LogicalPlanBuilder :: from ( t1)
650+ . join (
651+ t2,
652+ JoinType :: Left ,
653+ ( vec ! [ Column :: from_name( "a" ) ] , vec ! [ Column :: from_name( "a" ) ] ) ,
654+ None ,
655+ ) ?
656+ . filter ( binary_expr (
657+ col ( "t2.b" ) . in_list ( vec ! [ lit( 1u32 ) , lit( 2u32 ) ] , false ) ,
658+ Or ,
659+ col ( "t2.b" ) . is_null ( ) ,
660+ ) ) ?
661+ . build ( ) ?;
662+
663+ // Expect the plan to still contain a Left Join: the OR with IS NULL keeps null-padded rows
664+ assert_optimized_plan_equal ! ( plan, @r"
665+ Filter: t2.b IN ([UInt32(1), UInt32(2)]) OR t2.b IS NULL
666+ Left Join: t1.a = t2.a
667+ TableScan: t1
668+ TableScan: t2
669+ " )
670+ }
671+
439672 #[ test]
440673 fn eliminate_full_with_type_cast ( ) -> Result < ( ) > {
441674 let t1 = test_table_scan_with_name ( "t1" ) ?;
0 commit comments