Skip to content

Commit 3bd2b87

Browse files
committed
Fix mixed boolean handling in null_restriction
Update `binary_boolean_value` to return `None` for mixed boolean states instead of asserting that both operands agree. With this change, `is_restrict_null_predicate` falls back to the authoritative evaluator for predicates the syntactic fast path cannot model, rather than panicking on them. Add regression tests for unsupported boolean-wrapper predicates, checking that the syntactic fast path defers and that auto mode and authoritative-only mode return the same result.
1 parent 5989c6e commit 3bd2b87

File tree

2 files changed

+290
-15
lines changed

2 files changed

+290
-15
lines changed

datafusion/optimizer/src/utils.rs

Lines changed: 287 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,58 @@
1717

1818
//! Utility functions leveraged by the query optimizer rules
1919
20+
mod null_restriction;
21+
2022
use std::collections::{BTreeSet, HashMap, HashSet};
2123
use std::sync::Arc;
2224

25+
#[cfg(test)]
26+
use std::cell::Cell;
27+
2328
use crate::analyzer::type_coercion::TypeCoercionRewriter;
29+
30+
/// Null restriction evaluation mode for optimizer tests.
31+
#[cfg(test)]
32+
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
33+
pub(crate) enum NullRestrictionEvalMode {
34+
Auto,
35+
AuthoritativeOnly,
36+
}
37+
38+
#[cfg(test)]
39+
thread_local! {
40+
static NULL_RESTRICTION_EVAL_MODE: Cell<NullRestrictionEvalMode> =
41+
const { Cell::new(NullRestrictionEvalMode::Auto) };
42+
}
43+
44+
#[cfg(test)]
45+
pub(crate) fn set_null_restriction_eval_mode_for_test(mode: NullRestrictionEvalMode) {
46+
NULL_RESTRICTION_EVAL_MODE.with(|eval_mode| eval_mode.set(mode));
47+
}
48+
49+
#[cfg(test)]
50+
fn null_restriction_eval_mode() -> NullRestrictionEvalMode {
51+
NULL_RESTRICTION_EVAL_MODE.with(Cell::get)
52+
}
53+
54+
#[cfg(test)]
55+
pub(crate) fn with_null_restriction_eval_mode_for_test<T>(
56+
mode: NullRestrictionEvalMode,
57+
f: impl FnOnce() -> T,
58+
) -> T {
59+
struct NullRestrictionEvalModeReset(NullRestrictionEvalMode);
60+
61+
impl Drop for NullRestrictionEvalModeReset {
62+
fn drop(&mut self) {
63+
set_null_restriction_eval_mode_for_test(self.0);
64+
}
65+
}
66+
67+
let previous_mode = null_restriction_eval_mode();
68+
set_null_restriction_eval_mode_for_test(mode);
69+
let _reset = NullRestrictionEvalModeReset(previous_mode);
70+
f()
71+
}
2472
use arrow::array::{Array, RecordBatch, new_null_array};
2573
use arrow::datatypes::{DataType, Field, Schema};
2674
use datafusion_common::cast::as_boolean_array;
@@ -38,9 +86,13 @@ pub use datafusion_expr::expr_rewriter::NamePreserver;
3886

3987
/// Returns true if `expr` contains all columns in `schema_cols`
4088
pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet<Column>) -> bool {
41-
expr.column_refs()
89+
let column_refs = expr.column_refs();
90+
// Note: can't use `HashSet::intersection` because the element types differ (owned `Column` vs. `&Column` references)
91+
schema_cols
4292
.iter()
43-
.all(|column| schema_cols.contains(*column))
93+
.filter(|c| column_refs.contains(c))
94+
.count()
95+
== column_refs.len()
4496
}
4597

4698
pub(crate) fn replace_qualified_name(
@@ -78,19 +130,42 @@ pub fn is_restrict_null_predicate<'a>(
78130
// Collect join columns so they can be used in both the fast-path check and the
79131
// fallback evaluation path below.
80132
let join_cols: HashSet<&Column> = join_cols_of_predicate.into_iter().collect();
133+
let column_refs = predicate.column_refs();
134+
81135
// Fast path: if the predicate references columns outside the join key set,
82136
// `evaluate_expr_with_null_column` would fail because the null schema only
83137
// contains a placeholder for the join key columns. Callers treat such errors as
84138
// non-restricting (false) via `matches!(_, Ok(true))`, so we return false early
85139
// and avoid the expensive physical-expression compilation pipeline entirely.
86-
if !predicate
87-
.column_refs()
88-
.iter()
89-
.all(|column| join_cols.contains(*column))
90-
{
140+
if !null_restriction::all_columns_allowed(&column_refs, &join_cols) {
91141
return Ok(false);
92142
}
93143

144+
#[cfg(test)]
145+
if matches!(
146+
null_restriction_eval_mode(),
147+
NullRestrictionEvalMode::AuthoritativeOnly
148+
) {
149+
return authoritative_restrict_null_predicate(predicate, join_cols);
150+
}
151+
152+
if let Some(is_restricting) =
153+
null_restriction::syntactic_restrict_null_predicate(&predicate, &join_cols)
154+
{
155+
#[cfg(debug_assertions)]
156+
{
157+
let authoritative = authoritative_restrict_null_predicate(
158+
predicate.clone(),
159+
join_cols.iter().copied(),
160+
)?;
161+
debug_assert_eq!(
162+
is_restricting, authoritative,
163+
"syntactic fast path disagrees with authoritative null-restriction evaluation for predicate: {predicate}"
164+
);
165+
}
166+
return Ok(is_restricting);
167+
}
168+
94169
authoritative_restrict_null_predicate(predicate, join_cols)
95170
}
96171

@@ -147,11 +222,14 @@ fn authoritative_restrict_null_predicate<'a>(
147222
) -> Result<bool> {
148223
Ok(
149224
match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? {
150-
ColumnarValue::Array(array) if array.len() == 1 => {
151-
let boolean_array = as_boolean_array(&array)?;
152-
boolean_array.is_null(0) || !boolean_array.value(0)
225+
ColumnarValue::Array(array) => {
226+
if array.len() == 1 {
227+
let boolean_array = as_boolean_array(&array)?;
228+
boolean_array.is_null(0) || !boolean_array.value(0)
229+
} else {
230+
false
231+
}
153232
}
154-
ColumnarValue::Array(_) => false,
155233
ColumnarValue::Scalar(scalar) => matches!(
156234
scalar,
157235
ScalarValue::Boolean(None)
@@ -170,6 +248,8 @@ fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
170248
#[cfg(test)]
171249
mod tests {
172250
use super::*;
251+
use std::panic::{AssertUnwindSafe, catch_unwind};
252+
173253
use datafusion_expr::{
174254
Operator, binary_expr, case, col, in_list, is_null, lit, when,
175255
};
@@ -316,4 +396,200 @@ mod tests {
316396

317397
Ok(())
318398
}
399+
400+
#[test]
401+
fn syntactic_fast_path_matches_authoritative_evaluator() -> Result<()> {
402+
let test_cases = vec![
403+
is_null(col("a")),
404+
Expr::IsNotNull(Box::new(col("a"))),
405+
binary_expr(col("a"), Operator::Gt, lit(8i64)),
406+
binary_expr(col("a"), Operator::Eq, lit(ScalarValue::Null)),
407+
binary_expr(col("a"), Operator::And, lit(true)),
408+
binary_expr(col("a"), Operator::Or, lit(false)),
409+
Expr::Not(Box::new(col("a").is_true())),
410+
col("a").is_true(),
411+
col("a").is_false(),
412+
col("a").is_unknown(),
413+
col("a").is_not_true(),
414+
col("a").is_not_false(),
415+
col("a").is_not_unknown(),
416+
col("a").between(lit(1i64), lit(10i64)),
417+
binary_expr(
418+
when(Expr::IsNotNull(Box::new(col("a"))), col("a"))
419+
.otherwise(col("b"))?,
420+
Operator::Gt,
421+
lit(2i64),
422+
),
423+
case(col("a"))
424+
.when(lit(1i64), lit(true))
425+
.otherwise(lit(false))?,
426+
case(col("a"))
427+
.when(lit(0i64), lit(false))
428+
.otherwise(lit(true))?,
429+
binary_expr(
430+
case(col("a"))
431+
.when(lit(0i64), lit(true))
432+
.otherwise(lit(false))?,
433+
Operator::Or,
434+
lit(false),
435+
),
436+
binary_expr(
437+
case(lit(1i64))
438+
.when(lit(1i64), lit(ScalarValue::Null))
439+
.otherwise(lit(false))?,
440+
Operator::IsNotDistinctFrom,
441+
lit(true),
442+
),
443+
binary_expr(col("a").is_true(), Operator::And, lit(true)),
444+
binary_expr(col("a").is_false(), Operator::Or, lit(false)),
445+
binary_expr(col("a").is_unknown(), Operator::And, is_null(col("a"))),
446+
binary_expr(
447+
Expr::Not(Box::new(col("a").is_not_unknown())),
448+
Operator::Or,
449+
Expr::IsNotNull(Box::new(col("a"))),
450+
),
451+
];
452+
453+
for predicate in test_cases {
454+
let join_cols = predicate.column_refs();
455+
if let Some(syntactic) = null_restriction::syntactic_restrict_null_predicate(
456+
&predicate, &join_cols,
457+
) {
458+
let authoritative = authoritative_restrict_null_predicate(
459+
predicate.clone(),
460+
join_cols.iter().copied(),
461+
)
462+
.unwrap_or_else(|error| {
463+
panic!(
464+
"authoritative evaluator failed for predicate `{predicate}`: {error}"
465+
)
466+
});
467+
assert_eq!(
468+
syntactic, authoritative,
469+
"syntactic fast path disagrees with authoritative evaluator for predicate: {predicate}",
470+
);
471+
}
472+
}
473+
474+
Ok(())
475+
}
476+
477+
#[test]
478+
fn unsupported_boolean_wrappers_defer_to_authoritative_evaluator() -> Result<()> {
479+
let predicates = vec![
480+
binary_expr(col("a").is_true(), Operator::And, lit(true)),
481+
binary_expr(col("a").is_false(), Operator::Or, lit(false)),
482+
binary_expr(col("a").is_unknown(), Operator::And, is_null(col("a"))),
483+
binary_expr(
484+
Expr::Not(Box::new(col("a").is_not_unknown())),
485+
Operator::Or,
486+
Expr::IsNotNull(Box::new(col("a"))),
487+
),
488+
];
489+
490+
for predicate in predicates {
491+
let join_cols = predicate.column_refs();
492+
assert!(
493+
null_restriction::syntactic_restrict_null_predicate(
494+
&predicate, &join_cols
495+
)
496+
.is_none(),
497+
"syntactic fast path should defer for predicate: {predicate}",
498+
);
499+
500+
let auto_result = with_null_restriction_eval_mode_for_test(
501+
NullRestrictionEvalMode::Auto,
502+
|| {
503+
is_restrict_null_predicate(
504+
predicate.clone(),
505+
join_cols.iter().copied(),
506+
)
507+
},
508+
)?;
509+
510+
let authoritative_result = with_null_restriction_eval_mode_for_test(
511+
NullRestrictionEvalMode::AuthoritativeOnly,
512+
|| {
513+
is_restrict_null_predicate(
514+
predicate.clone(),
515+
join_cols.iter().copied(),
516+
)
517+
},
518+
)?;
519+
520+
assert_eq!(
521+
auto_result, authoritative_result,
522+
"auto mode should defer to authoritative evaluation for predicate: {predicate}",
523+
);
524+
}
525+
526+
Ok(())
527+
}
528+
529+
#[test]
530+
fn null_restriction_eval_mode_auto_vs_authoritative_only() -> Result<()> {
531+
let predicate = binary_expr(col("a"), Operator::Gt, lit(8i64));
532+
let join_cols_of_predicate = predicate.column_refs();
533+
534+
let auto_result = with_null_restriction_eval_mode_for_test(
535+
NullRestrictionEvalMode::Auto,
536+
|| {
537+
is_restrict_null_predicate(
538+
predicate.clone(),
539+
join_cols_of_predicate.iter().copied(),
540+
)
541+
},
542+
)?;
543+
544+
let authoritative_result = with_null_restriction_eval_mode_for_test(
545+
NullRestrictionEvalMode::AuthoritativeOnly,
546+
|| {
547+
is_restrict_null_predicate(
548+
predicate.clone(),
549+
join_cols_of_predicate.iter().copied(),
550+
)
551+
},
552+
)?;
553+
554+
assert_eq!(auto_result, authoritative_result);
555+
556+
Ok(())
557+
}
558+
559+
#[test]
560+
fn mixed_reference_predicate_remains_fast_pathed_in_authoritative_mode() -> Result<()>
561+
{
562+
let predicate = binary_expr(col("a"), Operator::Gt, col("b"));
563+
let column_a = Column::from_name("a");
564+
565+
let auto_result = with_null_restriction_eval_mode_for_test(
566+
NullRestrictionEvalMode::Auto,
567+
|| is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a)),
568+
)?;
569+
570+
let authoritative_only_result = with_null_restriction_eval_mode_for_test(
571+
NullRestrictionEvalMode::AuthoritativeOnly,
572+
|| is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a)),
573+
)?;
574+
575+
assert!(!auto_result, "{predicate}");
576+
assert!(!authoritative_only_result, "{predicate}");
577+
578+
Ok(())
579+
}
580+
581+
#[test]
582+
fn null_restriction_eval_mode_guard_restores_on_panic() {
583+
set_null_restriction_eval_mode_for_test(NullRestrictionEvalMode::Auto);
584+
585+
let result = catch_unwind(AssertUnwindSafe(|| {
586+
with_null_restriction_eval_mode_for_test(
587+
NullRestrictionEvalMode::AuthoritativeOnly,
588+
|| panic!("intentional panic to verify test mode reset"),
589+
)
590+
}));
591+
592+
assert!(result.is_err());
593+
assert_eq!(null_restriction_eval_mode(), NullRestrictionEvalMode::Auto);
594+
}
319595
}

datafusion/optimizer/src/utils/null_restriction.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,9 @@ fn binary_boolean_value(
8484
| (_, Some(NullSubstitutionValue::NonNull))
8585
| (None, _)
8686
| (_, None) => None,
87-
(left, right) => {
88-
debug_assert_eq!(left, right);
89-
left
90-
}
87+
// Any remaining mixed state is outside the reduced lattice this syntactic
88+
// evaluator can model soundly. Defer to the authoritative evaluator.
89+
_ => None,
9190
}
9291
}
9392

0 commit comments

Comments
 (0)