1717
1818//! Utility functions leveraged by the query optimizer rules
1919
20+ mod null_restriction;
21+
2022use std:: collections:: { BTreeSet , HashMap , HashSet } ;
2123use std:: sync:: Arc ;
2224
25+ #[ cfg( test) ]
26+ use std:: cell:: Cell ;
27+
2328use crate :: analyzer:: type_coercion:: TypeCoercionRewriter ;
29+
/// Null restriction evaluation mode for optimizer tests.
///
/// Only compiled in test builds. The mode is consulted by
/// `is_restrict_null_predicate` to decide which evaluation strategy is used
/// once the predicate has passed the "only references join columns" check.
#[cfg(test)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum NullRestrictionEvalMode {
    /// Default behavior: try the syntactic fast path first and fall back to
    /// the authoritative evaluator when the fast path gives no answer.
    Auto,
    /// Skip the syntactic fast path and always use the authoritative
    /// evaluator.
    AuthoritativeOnly,
}
37+
// Test-only, per-thread evaluation mode. It is a thread-local `Cell`, so a
// test that changes the mode only affects code running on its own thread.
// Every thread starts in `NullRestrictionEvalMode::Auto`.
#[cfg(test)]
thread_local! {
    static NULL_RESTRICTION_EVAL_MODE: Cell<NullRestrictionEvalMode> =
        const { Cell::new(NullRestrictionEvalMode::Auto) };
}
43+
/// Sets the null-restriction evaluation mode for the current thread.
///
/// Test builds only. The new mode stays in effect on this thread until it is
/// set again; nothing here restores the previous value automatically.
#[cfg(test)]
pub(crate) fn set_null_restriction_eval_mode_for_test(mode: NullRestrictionEvalMode) {
    NULL_RESTRICTION_EVAL_MODE.with(|eval_mode| eval_mode.set(mode));
}
48+
/// Returns the null-restriction evaluation mode currently set for this thread.
///
/// Test builds only.
#[cfg(test)]
fn null_restriction_eval_mode() -> NullRestrictionEvalMode {
    NULL_RESTRICTION_EVAL_MODE.with(Cell::get)
}
53+
/// Runs `f` with the current thread's null-restriction evaluation mode set to
/// `mode`, then puts the previous mode back.
///
/// Test builds only. The previous mode is restored by a drop guard, so it is
/// reinstated both when `f` returns normally and when `f` panics and the
/// stack unwinds.
///
/// Returns whatever `f` returns.
#[cfg(test)]
pub(crate) fn with_null_restriction_eval_mode_for_test<T>(
    mode: NullRestrictionEvalMode,
    f: impl FnOnce() -> T,
) -> T {
    // Guard holding the mode to reinstate; the restore happens in `Drop`.
    struct NullRestrictionEvalModeReset(NullRestrictionEvalMode);

    impl Drop for NullRestrictionEvalModeReset {
        fn drop(&mut self) {
            set_null_restriction_eval_mode_for_test(self.0);
        }
    }

    let previous_mode = null_restriction_eval_mode();
    set_null_restriction_eval_mode_for_test(mode);
    // Bound to a named variable (not `_`) so the guard lives until the end of
    // the function, i.e. until after `f` has finished or unwound.
    let _reset = NullRestrictionEvalModeReset(previous_mode);
    f()
}
2472use arrow:: array:: { Array , RecordBatch , new_null_array} ;
2573use arrow:: datatypes:: { DataType , Field , Schema } ;
2674use datafusion_common:: cast:: as_boolean_array;
@@ -38,9 +86,13 @@ pub use datafusion_expr::expr_rewriter::NamePreserver;
3886
3987/// Returns true if `expr` contains all columns in `schema_cols`
4088pub ( crate ) fn has_all_column_refs ( expr : & Expr , schema_cols : & HashSet < Column > ) -> bool {
41- expr. column_refs ( )
89+ let column_refs = expr. column_refs ( ) ;
90+ // note can't use HashSet::intersect because of different types (owned vs References)
91+ schema_cols
4292 . iter ( )
43- . all ( |column| schema_cols. contains ( * column) )
93+ . filter ( |c| column_refs. contains ( c) )
94+ . count ( )
95+ == column_refs. len ( )
4496}
4597
4698pub ( crate ) fn replace_qualified_name (
@@ -78,19 +130,42 @@ pub fn is_restrict_null_predicate<'a>(
78130 // Collect join columns so they can be used in both the fast-path check and the
79131 // fallback evaluation path below.
80132 let join_cols: HashSet < & Column > = join_cols_of_predicate. into_iter ( ) . collect ( ) ;
133+ let column_refs = predicate. column_refs ( ) ;
134+
81135 // Fast path: if the predicate references columns outside the join key set,
82136 // `evaluate_expr_with_null_column` would fail because the null schema only
83137 // contains a placeholder for the join key columns. Callers treat such errors as
84138 // non-restricting (false) via `matches!(_, Ok(true))`, so we return false early
85139 // and avoid the expensive physical-expression compilation pipeline entirely.
86- if !predicate
87- . column_refs ( )
88- . iter ( )
89- . all ( |column| join_cols. contains ( * column) )
90- {
140+ if !null_restriction:: all_columns_allowed ( & column_refs, & join_cols) {
91141 return Ok ( false ) ;
92142 }
93143
144+ #[ cfg( test) ]
145+ if matches ! (
146+ null_restriction_eval_mode( ) ,
147+ NullRestrictionEvalMode :: AuthoritativeOnly
148+ ) {
149+ return authoritative_restrict_null_predicate ( predicate, join_cols) ;
150+ }
151+
152+ if let Some ( is_restricting) =
153+ null_restriction:: syntactic_restrict_null_predicate ( & predicate, & join_cols)
154+ {
155+ #[ cfg( debug_assertions) ]
156+ {
157+ let authoritative = authoritative_restrict_null_predicate (
158+ predicate. clone ( ) ,
159+ join_cols. iter ( ) . copied ( ) ,
160+ ) ?;
161+ debug_assert_eq ! (
162+ is_restricting, authoritative,
163+ "syntactic fast path disagrees with authoritative null-restriction evaluation for predicate: {predicate}"
164+ ) ;
165+ }
166+ return Ok ( is_restricting) ;
167+ }
168+
94169 authoritative_restrict_null_predicate ( predicate, join_cols)
95170}
96171
@@ -147,11 +222,14 @@ fn authoritative_restrict_null_predicate<'a>(
147222) -> Result < bool > {
148223 Ok (
149224 match evaluate_expr_with_null_column ( predicate, join_cols_of_predicate) ? {
150- ColumnarValue :: Array ( array) if array. len ( ) == 1 => {
151- let boolean_array = as_boolean_array ( & array) ?;
152- boolean_array. is_null ( 0 ) || !boolean_array. value ( 0 )
225+ ColumnarValue :: Array ( array) => {
226+ if array. len ( ) == 1 {
227+ let boolean_array = as_boolean_array ( & array) ?;
228+ boolean_array. is_null ( 0 ) || !boolean_array. value ( 0 )
229+ } else {
230+ false
231+ }
153232 }
154- ColumnarValue :: Array ( _) => false ,
155233 ColumnarValue :: Scalar ( scalar) => matches ! (
156234 scalar,
157235 ScalarValue :: Boolean ( None )
@@ -170,6 +248,8 @@ fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
170248#[ cfg( test) ]
171249mod tests {
172250 use super :: * ;
251+ use std:: panic:: { AssertUnwindSafe , catch_unwind} ;
252+
173253 use datafusion_expr:: {
174254 Operator , binary_expr, case, col, in_list, is_null, lit, when,
175255 } ;
@@ -316,4 +396,200 @@ mod tests {
316396
317397 Ok ( ( ) )
318398 }
399+
/// For every predicate the syntactic fast path gives an answer for, that
/// answer must equal the authoritative evaluator's result. Predicates for
/// which the fast path returns `None` are skipped here.
#[test]
fn syntactic_fast_path_matches_authoritative_evaluator() -> Result<()> {
    let test_cases = vec![
        is_null(col("a")),
        Expr::IsNotNull(Box::new(col("a"))),
        binary_expr(col("a"), Operator::Gt, lit(8i64)),
        binary_expr(col("a"), Operator::Eq, lit(ScalarValue::Null)),
        binary_expr(col("a"), Operator::And, lit(true)),
        binary_expr(col("a"), Operator::Or, lit(false)),
        Expr::Not(Box::new(col("a").is_true())),
        col("a").is_true(),
        col("a").is_false(),
        col("a").is_unknown(),
        col("a").is_not_true(),
        col("a").is_not_false(),
        col("a").is_not_unknown(),
        col("a").between(lit(1i64), lit(10i64)),
        binary_expr(
            when(Expr::IsNotNull(Box::new(col("a"))), col("a"))
                .otherwise(col("b"))?,
            Operator::Gt,
            lit(2i64),
        ),
        case(col("a"))
            .when(lit(1i64), lit(true))
            .otherwise(lit(false))?,
        case(col("a"))
            .when(lit(0i64), lit(false))
            .otherwise(lit(true))?,
        binary_expr(
            case(col("a"))
                .when(lit(0i64), lit(true))
                .otherwise(lit(false))?,
            Operator::Or,
            lit(false),
        ),
        binary_expr(
            case(lit(1i64))
                .when(lit(1i64), lit(ScalarValue::Null))
                .otherwise(lit(false))?,
            Operator::IsNotDistinctFrom,
            lit(true),
        ),
        binary_expr(col("a").is_true(), Operator::And, lit(true)),
        binary_expr(col("a").is_false(), Operator::Or, lit(false)),
        binary_expr(col("a").is_unknown(), Operator::And, is_null(col("a"))),
        binary_expr(
            Expr::Not(Box::new(col("a").is_not_unknown())),
            Operator::Or,
            Expr::IsNotNull(Box::new(col("a"))),
        ),
    ];

    for predicate in test_cases {
        // Treat every column the predicate references as a join column.
        let join_cols = predicate.column_refs();
        if let Some(syntactic) = null_restriction::syntactic_restrict_null_predicate(
            &predicate, &join_cols,
        ) {
            let authoritative = authoritative_restrict_null_predicate(
                predicate.clone(),
                join_cols.iter().copied(),
            )
            .unwrap_or_else(|error| {
                panic!(
                    "authoritative evaluator failed for predicate `{predicate}`: {error}"
                )
            });
            assert_eq!(
                syntactic, authoritative,
                "syntactic fast path disagrees with authoritative evaluator for predicate: {predicate}",
            );
        }
    }

    Ok(())
}
476+
/// Predicates the syntactic fast path does not answer (it returns `None`)
/// must give the same result from `is_restrict_null_predicate` in `Auto` mode
/// as in `AuthoritativeOnly` mode.
#[test]
fn unsupported_boolean_wrappers_defer_to_authoritative_evaluator() -> Result<()> {
    let predicates = vec![
        binary_expr(col("a").is_true(), Operator::And, lit(true)),
        binary_expr(col("a").is_false(), Operator::Or, lit(false)),
        binary_expr(col("a").is_unknown(), Operator::And, is_null(col("a"))),
        binary_expr(
            Expr::Not(Box::new(col("a").is_not_unknown())),
            Operator::Or,
            Expr::IsNotNull(Box::new(col("a"))),
        ),
    ];

    for predicate in predicates {
        // Treat every column the predicate references as a join column.
        let join_cols = predicate.column_refs();
        assert!(
            null_restriction::syntactic_restrict_null_predicate(
                &predicate, &join_cols
            )
            .is_none(),
            "syntactic fast path should defer for predicate: {predicate}",
        );

        let auto_result = with_null_restriction_eval_mode_for_test(
            NullRestrictionEvalMode::Auto,
            || {
                is_restrict_null_predicate(
                    predicate.clone(),
                    join_cols.iter().copied(),
                )
            },
        )?;

        let authoritative_result = with_null_restriction_eval_mode_for_test(
            NullRestrictionEvalMode::AuthoritativeOnly,
            || {
                is_restrict_null_predicate(
                    predicate.clone(),
                    join_cols.iter().copied(),
                )
            },
        )?;

        assert_eq!(
            auto_result, authoritative_result,
            "auto mode should defer to authoritative evaluation for predicate: {predicate}",
        );
    }

    Ok(())
}
528+
/// `a > 8` must be classified identically by `is_restrict_null_predicate`
/// whether the evaluation mode is `Auto` or `AuthoritativeOnly`.
#[test]
fn null_restriction_eval_mode_auto_vs_authoritative_only() -> Result<()> {
    let predicate = binary_expr(col("a"), Operator::Gt, lit(8i64));
    let join_cols_of_predicate = predicate.column_refs();

    let auto_result = with_null_restriction_eval_mode_for_test(
        NullRestrictionEvalMode::Auto,
        || {
            is_restrict_null_predicate(
                predicate.clone(),
                join_cols_of_predicate.iter().copied(),
            )
        },
    )?;

    let authoritative_result = with_null_restriction_eval_mode_for_test(
        NullRestrictionEvalMode::AuthoritativeOnly,
        || {
            is_restrict_null_predicate(
                predicate.clone(),
                join_cols_of_predicate.iter().copied(),
            )
        },
    )?;

    assert_eq!(auto_result, authoritative_result);

    Ok(())
}
558+
/// A predicate that references a column outside the join-column set (`a > b`
/// with only `a` as a join column) must be reported as non-restricting in
/// both evaluation modes.
#[test]
fn mixed_reference_predicate_remains_fast_pathed_in_authoritative_mode() -> Result<()> {
    let predicate = binary_expr(col("a"), Operator::Gt, col("b"));
    let column_a = Column::from_name("a");

    let auto_result = with_null_restriction_eval_mode_for_test(
        NullRestrictionEvalMode::Auto,
        || is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a)),
    )?;

    let authoritative_only_result = with_null_restriction_eval_mode_for_test(
        NullRestrictionEvalMode::AuthoritativeOnly,
        || is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a)),
    )?;

    assert!(!auto_result, "{predicate}");
    assert!(!authoritative_only_result, "{predicate}");

    Ok(())
}
580+
/// The scoped mode override must put the previous mode back even when the
/// closure panics.
#[test]
fn null_restriction_eval_mode_guard_restores_on_panic() {
    // Start from a known mode so the final assertion is meaningful.
    set_null_restriction_eval_mode_for_test(NullRestrictionEvalMode::Auto);

    let result = catch_unwind(AssertUnwindSafe(|| {
        with_null_restriction_eval_mode_for_test(
            NullRestrictionEvalMode::AuthoritativeOnly,
            || panic!("intentional panic to verify test mode reset"),
        )
    }));

    assert!(result.is_err());
    assert_eq!(null_restriction_eval_mode(), NullRestrictionEvalMode::Auto);
}
319595}
0 commit comments