From b9828cabc4ac8b26cba2d246f7de52781c20ad8f Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 16 Mar 2026 16:02:27 +0800 Subject: [PATCH 01/63] fix: optimize null predicate evaluation by early exit for non-restricting conditions --- datafusion/optimizer/src/utils.rs | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 7e038d2392022..b4d10954b8b47 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -79,10 +79,27 @@ pub fn is_restrict_null_predicate<'a>( return Ok(true); } + // Collect join columns so they can be used in both the fast-path check and the + // fallback evaluation path below. + let join_cols: HashSet<&Column> = join_cols_of_predicate.into_iter().collect(); + + // Fast path: if the predicate references columns outside the join key set, + // `evaluate_expr_with_null_column` would fail because the null schema only + // contains a placeholder for the join key columns. Callers treat such errors as + // non-restricting (false) via `matches!(_, Ok(true))`, so we return false early + // and avoid the expensive physical-expression compilation pipeline entirely. + if predicate + .column_refs() + .iter() + .any(|c| !join_cols.contains(*c)) + { + return Ok(false); + } + // If result is single `true`, return false; // If result is single `NULL` or `false`, return true; Ok( - match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? { + match evaluate_expr_with_null_column(predicate, join_cols.into_iter())? 
{ ColumnarValue::Array(array) => { if array.len() == 1 { let boolean_array = as_boolean_array(&array)?; From 88e4455c6f1041c84a1056d4b00235c73c3a26d6 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 16 Mar 2026 21:53:38 +0800 Subject: [PATCH 02/63] Add test case for a > b in join key predicates Introduce a test case to assert non-restricting behavior when evaluating the predicate a > b, focusing on join keys that only include a. This directly tests the new early-return branch in the is_restrict_null_predicate function in utils.rs, enhancing overall code coverage. --- datafusion/optimizer/src/utils.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index b4d10954b8b47..c5a31978fa34a 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -195,6 +195,8 @@ mod tests { (binary_expr(col("a"), Operator::Gt, lit(8i64)), true), // a <= 8 (binary_expr(col("a"), Operator::LtEq, lit(8i32)), true), + // a > b (b is outside join key set) + (binary_expr(col("a"), Operator::Gt, col("b")), false), // CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END ( case(col("a")) From df73001666d05a22ae93152a50cd3cc4454e7dd4 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 16 Mar 2026 21:55:07 +0800 Subject: [PATCH 03/63] Refactor column membership check into a helper function Extract the column-membership check into a new helper function called `predicate_uses_only_columns` in utils.rs. Update the current implementation at utils.rs:91 to use this new helper, improving code readability and maintainability. 
--- datafusion/optimizer/src/utils.rs | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index c5a31978fa34a..6b459156f2a6a 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -88,11 +88,7 @@ pub fn is_restrict_null_predicate<'a>( // contains a placeholder for the join key columns. Callers treat such errors as // non-restricting (false) via `matches!(_, Ok(true))`, so we return false early // and avoid the expensive physical-expression compilation pipeline entirely. - if predicate - .column_refs() - .iter() - .any(|c| !join_cols.contains(*c)) - { + if !predicate_uses_only_columns(&predicate, &join_cols) { return Ok(false); } @@ -116,6 +112,16 @@ pub fn is_restrict_null_predicate<'a>( ) } +fn predicate_uses_only_columns( + predicate: &Expr, + allowed_columns: &HashSet<&Column>, +) -> bool { + predicate + .column_refs() + .iter() + .all(|column| allowed_columns.contains(*column)) +} + /// Determines if an expression will always evaluate to null. /// `c0 + 8` return true /// `c0 IS NULL` return false From 144cab304d4be31c93b622f505af1c623494e6e3 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 16 Mar 2026 21:56:38 +0800 Subject: [PATCH 04/63] Clarify null-restricting behavior in filter check Add call-site contract comment in push_down_filter.rs to specify that only Ok(true) is treated as null-restricting. State that both Ok(false) and Err(_) are considered non-restricting and will be skipped during processing. 
--- datafusion/optimizer/src/push_down_filter.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 03a7a0b864177..3410c1f055c32 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -622,6 +622,8 @@ impl InferredPredicates { predicate: Expr, replace_map: &HashMap<&Column, &Column>, ) -> Result<()> { + // Contract: only `Ok(true)` is considered null-restricting for non-inner joins. + // `Ok(false)` and `Err(_)` are both treated as non-restricting and skipped. if self.is_inner_join || matches!( is_restrict_null_predicate( From 9ef45f015677a68d207ad9814613ec1508bf3c07 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 16 Mar 2026 22:01:42 +0800 Subject: [PATCH 05/63] Simplify column check and null-restrict handling Inline iterator predicate in utils.rs and streamline the null-restrict handling in push_down_filter.rs. This reduces indirections and lines of code while maintaining the same logic and behavior. No public interface or behavior changes intended. --- datafusion/optimizer/src/push_down_filter.rs | 12 ++++-------- datafusion/optimizer/src/utils.rs | 16 +++++----------- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 3410c1f055c32..055008ec00fd7 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -622,16 +622,12 @@ impl InferredPredicates { predicate: Expr, replace_map: &HashMap<&Column, &Column>, ) -> Result<()> { - // Contract: only `Ok(true)` is considered null-restricting for non-inner joins. - // `Ok(false)` and `Err(_)` are both treated as non-restricting and skipped. 
if self.is_inner_join - || matches!( - is_restrict_null_predicate( - predicate.clone(), - replace_map.keys().cloned() - ), - Ok(true) + || is_restrict_null_predicate( + predicate.clone(), + replace_map.keys().cloned(), ) + .unwrap_or(false) { self.predicates.push(replace_col(predicate, replace_map)?); } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 6b459156f2a6a..ca1d618a7e0b7 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -88,7 +88,11 @@ pub fn is_restrict_null_predicate<'a>( // contains a placeholder for the join key columns. Callers treat such errors as // non-restricting (false) via `matches!(_, Ok(true))`, so we return false early // and avoid the expensive physical-expression compilation pipeline entirely. - if !predicate_uses_only_columns(&predicate, &join_cols) { + if !predicate + .column_refs() + .iter() + .all(|column| join_cols.contains(*column)) + { return Ok(false); } @@ -112,16 +116,6 @@ pub fn is_restrict_null_predicate<'a>( ) } -fn predicate_uses_only_columns( - predicate: &Expr, - allowed_columns: &HashSet<&Column>, -) -> bool { - predicate - .column_refs() - .iter() - .all(|column| allowed_columns.contains(*column)) -} - /// Determines if an expression will always evaluate to null. 
/// `c0 + 8` return true /// `c0 IS NULL` return false From 3d3945ce07b7015c11b0a4f89f3b456d785b7bdf Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 16 Mar 2026 22:05:07 +0800 Subject: [PATCH 06/63] refactor: streamline null predicate evaluation by introducing predicate_uses_only_columns function --- datafusion/optimizer/src/push_down_filter.rs | 7 ++----- datafusion/optimizer/src/utils.rs | 16 +++++++++++----- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 055008ec00fd7..9562ab818824f 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -623,11 +623,8 @@ impl InferredPredicates { replace_map: &HashMap<&Column, &Column>, ) -> Result<()> { if self.is_inner_join - || is_restrict_null_predicate( - predicate.clone(), - replace_map.keys().cloned(), - ) - .unwrap_or(false) + || is_restrict_null_predicate(predicate.clone(), replace_map.keys().cloned()) + .unwrap_or(false) { self.predicates.push(replace_col(predicate, replace_map)?); } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index ca1d618a7e0b7..6b459156f2a6a 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -88,11 +88,7 @@ pub fn is_restrict_null_predicate<'a>( // contains a placeholder for the join key columns. Callers treat such errors as // non-restricting (false) via `matches!(_, Ok(true))`, so we return false early // and avoid the expensive physical-expression compilation pipeline entirely. 
- if !predicate - .column_refs() - .iter() - .all(|column| join_cols.contains(*column)) - { + if !predicate_uses_only_columns(&predicate, &join_cols) { return Ok(false); } @@ -116,6 +112,16 @@ pub fn is_restrict_null_predicate<'a>( ) } +fn predicate_uses_only_columns( + predicate: &Expr, + allowed_columns: &HashSet<&Column>, +) -> bool { + predicate + .column_refs() + .iter() + .all(|column| allowed_columns.contains(*column)) +} + /// Determines if an expression will always evaluate to null. /// `c0 + 8` return true /// `c0 IS NULL` return false From 17009ae115d737460f82addadb109cad6581d6ee Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 18 Mar 2026 22:51:28 +0800 Subject: [PATCH 07/63] Improve SQL boolean/null semantics handling Implement fast path for syntactic null-restriction in utils.rs to classify predicates without evaluating physical expressions. Enhance SQL boolean handling with a large supporting evaluator, including CASE management. Retain existing branch helper styles and expand test coverage for constant simple CASE and outside-join-key fast paths. Fix correctness edge case for simple CASE indeterminate comparisons, ensuring proper tracking and fallback to Unknown. 
--- datafusion/optimizer/src/utils.rs | 551 +++++++++++++++++++++++++++++- 1 file changed, 544 insertions(+), 7 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 6b459156f2a6a..33d1f9d01d0ac 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -92,6 +92,12 @@ pub fn is_restrict_null_predicate<'a>( return Ok(false); } + if let Some(is_restricting) = + syntactic_restrict_null_predicate(&predicate, &join_cols) + { + return Ok(is_restricting); + } + // If result is single `true`, return false; // If result is single `NULL` or `false`, return true; Ok( @@ -112,6 +118,35 @@ pub fn is_restrict_null_predicate<'a>( ) } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum NullSubstitutionValue { + Unknown, + Null, + NonNull, + Boolean(bool), +} + +impl NullSubstitutionValue { + fn is_null(self) -> bool { + matches!(self, Self::Null) + } + + fn is_definitely_non_null(self) -> bool { + matches!(self, Self::NonNull | Self::Boolean(_)) + } +} + +fn syntactic_restrict_null_predicate( + predicate: &Expr, + join_cols: &HashSet<&Column>, +) -> Option { + match syntactic_null_substitution_value(predicate, join_cols) { + NullSubstitutionValue::Boolean(true) => Some(false), + NullSubstitutionValue::Boolean(false) | NullSubstitutionValue::Null => Some(true), + NullSubstitutionValue::Unknown | NullSubstitutionValue::NonNull => None, + } +} + fn predicate_uses_only_columns( predicate: &Expr, allowed_columns: &HashSet<&Column>, @@ -122,6 +157,468 @@ fn predicate_uses_only_columns( .all(|column| allowed_columns.contains(*column)) } +fn syntactic_null_substitution_value( + expr: &Expr, + join_cols: &HashSet<&Column>, +) -> NullSubstitutionValue { + match expr { + Expr::Alias(alias) => { + syntactic_null_substitution_value(alias.expr.as_ref(), join_cols) + } + Expr::Column(column) => { + if join_cols.contains(column) { + NullSubstitutionValue::Null + } else { + NullSubstitutionValue::Unknown + } + } + 
Expr::Literal(value, _) => scalar_to_null_substitution_value(value), + Expr::BinaryExpr(binary_expr) => syntactic_binary_value(binary_expr, join_cols), + Expr::Not(expr) => { + sql_not(syntactic_null_substitution_value(expr.as_ref(), join_cols)) + } + Expr::IsNull(expr) => { + match syntactic_null_substitution_value(expr.as_ref(), join_cols) { + NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(true), + NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_) => { + NullSubstitutionValue::Boolean(false) + } + NullSubstitutionValue::Unknown => NullSubstitutionValue::Unknown, + } + } + Expr::IsNotNull(expr) => { + match syntactic_null_substitution_value(expr.as_ref(), join_cols) { + NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(false), + NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_) => { + NullSubstitutionValue::Boolean(true) + } + NullSubstitutionValue::Unknown => NullSubstitutionValue::Unknown, + } + } + Expr::IsTrue(expr) => boolean_test_result( + syntactic_null_substitution_value(expr.as_ref(), join_cols), + false, + ), + Expr::IsFalse(expr) => { + match syntactic_null_substitution_value(expr.as_ref(), join_cols) { + NullSubstitutionValue::Boolean(value) => { + NullSubstitutionValue::Boolean(!value) + } + NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(false), + NullSubstitutionValue::Unknown | NullSubstitutionValue::NonNull => { + NullSubstitutionValue::Unknown + } + } + } + Expr::IsUnknown(expr) => { + match syntactic_null_substitution_value(expr.as_ref(), join_cols) { + NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(true), + NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_) => { + NullSubstitutionValue::Boolean(false) + } + NullSubstitutionValue::Unknown => NullSubstitutionValue::Unknown, + } + } + Expr::IsNotTrue(expr) => boolean_test_result( + syntactic_null_substitution_value(expr.as_ref(), join_cols), + true, + ), + Expr::IsNotFalse(expr) => { + match 
syntactic_null_substitution_value(expr.as_ref(), join_cols) { + NullSubstitutionValue::Boolean(value) => { + NullSubstitutionValue::Boolean(value) + } + NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(true), + NullSubstitutionValue::Unknown | NullSubstitutionValue::NonNull => { + NullSubstitutionValue::Unknown + } + } + } + Expr::IsNotUnknown(expr) => { + match syntactic_null_substitution_value(expr.as_ref(), join_cols) { + NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(false), + NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_) => { + NullSubstitutionValue::Boolean(true) + } + NullSubstitutionValue::Unknown => NullSubstitutionValue::Unknown, + } + } + Expr::Between(between) => { + let expr = + syntactic_null_substitution_value(between.expr.as_ref(), join_cols); + let low = syntactic_null_substitution_value(between.low.as_ref(), join_cols); + let high = + syntactic_null_substitution_value(between.high.as_ref(), join_cols); + if expr.is_null() || low.is_null() || high.is_null() { + NullSubstitutionValue::Null + } else { + NullSubstitutionValue::Unknown + } + } + Expr::Case(case) => syntactic_case_value(case, join_cols), + Expr::Cast(cast) => strict_null_passthrough(cast.expr.as_ref(), join_cols), + Expr::TryCast(try_cast) => { + strict_null_passthrough(try_cast.expr.as_ref(), join_cols) + } + Expr::Negative(expr) => strict_null_passthrough(expr.as_ref(), join_cols), + Expr::Like(like) | Expr::SimilarTo(like) => { + let value = syntactic_null_substitution_value(like.expr.as_ref(), join_cols); + let pattern = + syntactic_null_substitution_value(like.pattern.as_ref(), join_cols); + if value.is_null() || pattern.is_null() { + NullSubstitutionValue::Null + } else { + NullSubstitutionValue::Unknown + } + } + Expr::ScalarFunction(function) => { + syntactic_scalar_function_value(function.name(), &function.args, join_cols) + } + Expr::Exists { .. 
} + | Expr::InList(_) + | Expr::InSubquery(_) + | Expr::SetComparison(_) + | Expr::ScalarSubquery(_) + | Expr::OuterReferenceColumn(_, _) + | Expr::Placeholder(_) + | Expr::ScalarVariable(_, _) + | Expr::Unnest(_) + | Expr::GroupingSet(_) + | Expr::WindowFunction(_) => NullSubstitutionValue::Unknown, + Expr::AggregateFunction(_) => NullSubstitutionValue::Unknown, + // TODO: remove the next line after `Expr::Wildcard` is removed + #[expect(deprecated)] + Expr::Wildcard { .. } => NullSubstitutionValue::Unknown, + } +} + +fn scalar_to_null_substitution_value(value: &ScalarValue) -> NullSubstitutionValue { + if value.is_null() { + NullSubstitutionValue::Null + } else if let ScalarValue::Boolean(Some(value)) = value { + NullSubstitutionValue::Boolean(*value) + } else { + NullSubstitutionValue::NonNull + } +} + +fn strict_null_passthrough( + expr: &Expr, + join_cols: &HashSet<&Column>, +) -> NullSubstitutionValue { + if syntactic_null_substitution_value(expr, join_cols).is_null() { + NullSubstitutionValue::Null + } else { + NullSubstitutionValue::Unknown + } +} + +fn boolean_test_result( + value: NullSubstitutionValue, + default_for_null: bool, +) -> NullSubstitutionValue { + match value { + NullSubstitutionValue::Boolean(value) => NullSubstitutionValue::Boolean(value), + NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(default_for_null), + NullSubstitutionValue::Unknown | NullSubstitutionValue::NonNull => { + NullSubstitutionValue::Unknown + } + } +} + +fn sql_not(value: NullSubstitutionValue) -> NullSubstitutionValue { + match value { + NullSubstitutionValue::Boolean(value) => NullSubstitutionValue::Boolean(!value), + NullSubstitutionValue::Null => NullSubstitutionValue::Null, + NullSubstitutionValue::Unknown | NullSubstitutionValue::NonNull => { + NullSubstitutionValue::Unknown + } + } +} + +fn sql_and( + left: NullSubstitutionValue, + right: NullSubstitutionValue, +) -> NullSubstitutionValue { + if matches!(left, NullSubstitutionValue::Boolean(false)) + 
|| matches!(right, NullSubstitutionValue::Boolean(false)) + { + return NullSubstitutionValue::Boolean(false); + } + + match (left, right) { + (NullSubstitutionValue::Boolean(true), value) + | (value, NullSubstitutionValue::Boolean(true)) => value, + (NullSubstitutionValue::Null, NullSubstitutionValue::Null) => { + NullSubstitutionValue::Null + } + (NullSubstitutionValue::Null, NullSubstitutionValue::Unknown) + | (NullSubstitutionValue::Unknown, NullSubstitutionValue::Null) + | (NullSubstitutionValue::Unknown, _) + | (_, NullSubstitutionValue::Unknown) + | (NullSubstitutionValue::NonNull, _) + | (_, NullSubstitutionValue::NonNull) => NullSubstitutionValue::Unknown, + (left, right) => { + debug_assert_eq!(left, right); + left + } + } +} + +fn sql_or( + left: NullSubstitutionValue, + right: NullSubstitutionValue, +) -> NullSubstitutionValue { + if matches!(left, NullSubstitutionValue::Boolean(true)) + || matches!(right, NullSubstitutionValue::Boolean(true)) + { + return NullSubstitutionValue::Boolean(true); + } + + match (left, right) { + (NullSubstitutionValue::Boolean(false), value) + | (value, NullSubstitutionValue::Boolean(false)) => value, + (NullSubstitutionValue::Null, NullSubstitutionValue::Null) => { + NullSubstitutionValue::Null + } + (NullSubstitutionValue::Null, NullSubstitutionValue::Unknown) + | (NullSubstitutionValue::Unknown, NullSubstitutionValue::Null) + | (NullSubstitutionValue::Unknown, _) + | (_, NullSubstitutionValue::Unknown) + | (NullSubstitutionValue::NonNull, _) + | (_, NullSubstitutionValue::NonNull) => NullSubstitutionValue::Unknown, + (left, right) => { + debug_assert_eq!(left, right); + left + } + } +} + +fn syntactic_binary_value( + binary_expr: &datafusion_expr::BinaryExpr, + join_cols: &HashSet<&Column>, +) -> NullSubstitutionValue { + let left = syntactic_null_substitution_value(binary_expr.left.as_ref(), join_cols); + let right = syntactic_null_substitution_value(binary_expr.right.as_ref(), join_cols); + + match binary_expr.op { + 
datafusion_expr::Operator::And => sql_and(left, right), + datafusion_expr::Operator::Or => sql_or(left, right), + datafusion_expr::Operator::IsDistinctFrom => { + syntactic_is_distinct_from(left, right) + } + datafusion_expr::Operator::IsNotDistinctFrom => { + sql_not(syntactic_is_distinct_from(left, right)) + } + datafusion_expr::Operator::Eq + | datafusion_expr::Operator::NotEq + | datafusion_expr::Operator::Lt + | datafusion_expr::Operator::LtEq + | datafusion_expr::Operator::Gt + | datafusion_expr::Operator::GtEq + | datafusion_expr::Operator::Plus + | datafusion_expr::Operator::Minus + | datafusion_expr::Operator::Multiply + | datafusion_expr::Operator::Divide + | datafusion_expr::Operator::Modulo + | datafusion_expr::Operator::RegexMatch + | datafusion_expr::Operator::RegexIMatch + | datafusion_expr::Operator::RegexNotMatch + | datafusion_expr::Operator::RegexNotIMatch + | datafusion_expr::Operator::LikeMatch + | datafusion_expr::Operator::ILikeMatch + | datafusion_expr::Operator::NotLikeMatch + | datafusion_expr::Operator::NotILikeMatch + | datafusion_expr::Operator::BitwiseAnd + | datafusion_expr::Operator::BitwiseOr + | datafusion_expr::Operator::BitwiseXor + | datafusion_expr::Operator::BitwiseShiftRight + | datafusion_expr::Operator::BitwiseShiftLeft + | datafusion_expr::Operator::StringConcat + | datafusion_expr::Operator::AtArrow + | datafusion_expr::Operator::ArrowAt + | datafusion_expr::Operator::Arrow + | datafusion_expr::Operator::LongArrow + | datafusion_expr::Operator::HashArrow + | datafusion_expr::Operator::HashLongArrow + | datafusion_expr::Operator::AtAt + | datafusion_expr::Operator::IntegerDivide + | datafusion_expr::Operator::HashMinus + | datafusion_expr::Operator::AtQuestion + | datafusion_expr::Operator::Question + | datafusion_expr::Operator::QuestionAnd + | datafusion_expr::Operator::QuestionPipe + | datafusion_expr::Operator::Colon => { + if left.is_null() || right.is_null() { + NullSubstitutionValue::Null + } else { + 
NullSubstitutionValue::Unknown + } + } + } +} + +fn syntactic_is_distinct_from( + left: NullSubstitutionValue, + right: NullSubstitutionValue, +) -> NullSubstitutionValue { + match (left, right) { + (NullSubstitutionValue::Null, NullSubstitutionValue::Null) => { + NullSubstitutionValue::Boolean(false) + } + (NullSubstitutionValue::Null, value) | (value, NullSubstitutionValue::Null) => { + if value.is_definitely_non_null() { + NullSubstitutionValue::Boolean(true) + } else { + NullSubstitutionValue::Unknown + } + } + (NullSubstitutionValue::Boolean(left), NullSubstitutionValue::Boolean(right)) => { + NullSubstitutionValue::Boolean(left != right) + } + _ => NullSubstitutionValue::Unknown, + } +} + +fn syntactic_case_value( + case: &datafusion_expr::expr::Case, + join_cols: &HashSet<&Column>, +) -> NullSubstitutionValue { + if let Some(base_expr) = case.expr.as_deref() { + let base_value = syntactic_null_substitution_value(base_expr, join_cols); + let mut saw_indeterminate_comparison = false; + if base_value.is_null() { + return case + .else_expr + .as_deref() + .map_or(NullSubstitutionValue::Null, |expr| { + syntactic_null_substitution_value(expr, join_cols) + }); + } + + for (when_expr, then_expr) in &case.when_then_expr { + let when_value = syntactic_null_substitution_value(when_expr, join_cols); + match syntactic_equals(base_value, when_value) { + Some(true) => { + return syntactic_null_substitution_value( + then_expr.as_ref(), + join_cols, + ); + } + Some(false) => continue, + None => { + saw_indeterminate_comparison = true; + continue; + } + } + } + + if saw_indeterminate_comparison { + return NullSubstitutionValue::Unknown; + } + + return case + .else_expr + .as_deref() + .map_or(NullSubstitutionValue::Null, |expr| { + syntactic_null_substitution_value(expr, join_cols) + }); + } + + for (when_expr, then_expr) in &case.when_then_expr { + match syntactic_null_substitution_value(when_expr.as_ref(), join_cols) { + NullSubstitutionValue::Boolean(true) => { + return 
syntactic_null_substitution_value(then_expr.as_ref(), join_cols); + } + NullSubstitutionValue::Boolean(false) | NullSubstitutionValue::Null => {} + NullSubstitutionValue::Unknown | NullSubstitutionValue::NonNull => { + return NullSubstitutionValue::Unknown; + } + } + } + + case.else_expr + .as_deref() + .map_or(NullSubstitutionValue::Null, |expr| { + syntactic_null_substitution_value(expr, join_cols) + }) +} + +fn syntactic_equals( + left: NullSubstitutionValue, + right: NullSubstitutionValue, +) -> Option { + match (left, right) { + (NullSubstitutionValue::Null, _) | (_, NullSubstitutionValue::Null) => None, + (NullSubstitutionValue::Boolean(left), NullSubstitutionValue::Boolean(right)) => { + Some(left == right) + } + _ => None, + } +} + +fn syntactic_scalar_function_value( + name: &str, + args: &[Expr], + join_cols: &HashSet<&Column>, +) -> NullSubstitutionValue { + if name.eq_ignore_ascii_case("coalesce") { + return syntactic_coalesce_value(args, join_cols); + } + + let arg_values = args + .iter() + .map(|expr| syntactic_null_substitution_value(expr, join_cols)) + .collect::>(); + + if arg_values.iter().any(|value| value.is_null()) + && is_null_propagating_scalar_function(name) + { + NullSubstitutionValue::Null + } else { + NullSubstitutionValue::Unknown + } +} + +fn syntactic_coalesce_value( + args: &[Expr], + join_cols: &HashSet<&Column>, +) -> NullSubstitutionValue { + let mut saw_unknown = false; + + for expr in args { + match syntactic_null_substitution_value(expr, join_cols) { + NullSubstitutionValue::Null => continue, + NullSubstitutionValue::NonNull => return NullSubstitutionValue::NonNull, + NullSubstitutionValue::Boolean(value) => { + return NullSubstitutionValue::Boolean(value); + } + NullSubstitutionValue::Unknown => saw_unknown = true, + } + } + + if saw_unknown { + NullSubstitutionValue::Unknown + } else { + NullSubstitutionValue::Null + } +} + +fn is_null_propagating_scalar_function(name: &str) -> bool { + matches!( + name, + "btrim" + | 
"char_length" + | "length" + | "lower" + | "regexp_like" + | "regexp_replace" + | "to_timestamp" + | "trim" + | "upper" + ) +} + /// Determines if an expression will always evaluate to null. /// `c0 + 8` return true /// `c0 IS NULL` return false @@ -177,7 +674,9 @@ fn coerce(expr: Expr, schema: &DFSchema) -> Result { #[cfg(test)] mod tests { use super::*; - use datafusion_expr::{Operator, binary_expr, case, col, in_list, is_null, lit}; + use datafusion_expr::{ + Operator, binary_expr, case, col, in_list, is_null, lit, when, + }; #[test] fn expr_is_restrict_null_predicate() -> Result<()> { @@ -201,8 +700,6 @@ mod tests { (binary_expr(col("a"), Operator::Gt, lit(8i64)), true), // a <= 8 (binary_expr(col("a"), Operator::LtEq, lit(8i32)), true), - // a > b (b is outside join key set) - (binary_expr(col("a"), Operator::Gt, col("b")), false), // CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END ( case(col("a")) @@ -218,6 +715,27 @@ mod tests { .otherwise(lit(false))?, true, ), + // CASE 1 WHEN 1 THEN true ELSE false END + ( + case(lit(1i64)) + .when(lit(1i64), lit(true)) + .otherwise(lit(false))?, + false, + ), + // CASE 1 WHEN 1 THEN NULL ELSE false END + ( + case(lit(1i64)) + .when(lit(1i64), lit(ScalarValue::Null)) + .otherwise(lit(false))?, + true, + ), + // CASE true WHEN true THEN false ELSE true END + ( + case(lit(true)) + .when(lit(true), lit(false)) + .otherwise(lit(true))?, + true, + ), // CASE a WHEN 0 THEN false ELSE true END ( case(col("a")) @@ -271,16 +789,35 @@ mod tests { in_list(col("a"), vec![Expr::Literal(ScalarValue::Null, None)], true), true, ), + // CASE WHEN a IS NOT NULL THEN a ELSE b END > 2 + ( + binary_expr( + when(Expr::IsNotNull(Box::new(col("a"))), col("a")) + .otherwise(col("b"))?, + Operator::Gt, + lit(2i64), + ), + true, + ), ]; - let column_a = Column::from_name("a"); for (predicate, expected) in test_cases { - let join_cols_of_predicate = std::iter::once(&column_a); - let actual = - is_restrict_null_predicate(predicate.clone(), 
join_cols_of_predicate)?; + let join_cols_of_predicate = predicate.column_refs(); + let actual = is_restrict_null_predicate( + predicate.clone(), + join_cols_of_predicate.iter().copied(), + )?; assert_eq!(actual, expected, "{predicate}"); } + // Keep coverage for the fast path that rejects predicates referencing + // columns outside the provided join key set. + let predicate = binary_expr(col("a"), Operator::Gt, col("b")); + let column_a = Column::from_name("a"); + let actual = + is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a))?; + assert!(!actual, "{predicate}"); + Ok(()) } } From 515da96878498202e6809465c5ad4f4e5f26c1a3 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 19 Mar 2026 17:13:16 +0800 Subject: [PATCH 08/63] Add regression tests for SQL shape coverage Implement focused regression tests in push_down_filter_regressions.rs. These tests cover the following SQL shapes: window + scalar subquery, regr_* aggregate query, correlated IN subquery, and NATURAL JOIN with UNION ALL cases. This ensures accurate reporting and enhances SQL handling reliability. 
--- datafusion/core/tests/sql/mod.rs | 1 + .../tests/sql/push_down_filter_regressions.rs | 170 ++++++++++++++++++ 2 files changed, 171 insertions(+) create mode 100644 datafusion/core/tests/sql/push_down_filter_regressions.rs diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 9a1dc5502ee60..a245227382d83 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -67,6 +67,7 @@ pub mod create_drop; pub mod explain_analyze; pub mod joins; mod path_partition; +mod push_down_filter_regressions; mod runtime_config; pub mod select; mod sql_api; diff --git a/datafusion/core/tests/sql/push_down_filter_regressions.rs b/datafusion/core/tests/sql/push_down_filter_regressions.rs new file mode 100644 index 0000000000000..1069e75a26f37 --- /dev/null +++ b/datafusion/core/tests/sql/push_down_filter_regressions.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use super::*; +use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; + +#[tokio::test] +async fn window_scalar_subquery_regression() -> Result<()> { + let ctx = SessionContext::new(); + let sql = r#" + WITH suppliers AS ( + SELECT * + FROM (VALUES (1, 10.0), (1, 20.0)) AS t(nation, acctbal) + ) + SELECT + ROW_NUMBER() OVER (PARTITION BY nation ORDER BY acctbal DESC) AS rn + FROM suppliers AS s + WHERE acctbal > ( + SELECT AVG(acctbal) FROM suppliers + ) + "#; + + let results = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + &["+----+", "| rn |", "+----+", "| 1 |", "+----+",], + &results + ); + + Ok(()) +} + +#[tokio::test] +async fn aggregate_regr_functions_regression() -> Result<()> { + let ctx = SessionContext::new(); + let batch = RecordBatch::try_from_iter(vec![ + ( + "c11", + Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0])) as ArrayRef, + ), + ( + "c12", + Arc::new(Float64Array::from(vec![2.0, 4.0, 6.0])) as ArrayRef, + ), + ])?; + ctx.register_batch("aggregate_test_100", batch)?; + + let sql = r#" + select + regr_slope(c12, c11), + regr_intercept(c12, c11), + regr_count(c12, c11), + regr_r2(c12, c11), + regr_avgx(c12, c11), + regr_avgy(c12, c11), + regr_sxx(c12, c11), + regr_syy(c12, c11), + regr_sxy(c12, c11) + from aggregate_test_100 + "#; + + let rows = execute(&ctx, sql).await; + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].len(), 9); + assert!(rows[0].iter().all(|value| value != "NULL")); + + Ok(()) +} + +#[tokio::test] +async fn correlated_in_subquery_regression() -> Result<()> { + let ctx = SessionContext::new(); + let t1 = RecordBatch::try_from_iter(vec![ + ("t1_id", Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef), + ( + "t1_name", + Arc::new(StringArray::from(vec!["alpha", "beta"])) as ArrayRef, + ), + ("t1_int", Arc::new(Int32Array::from(vec![1, 0])) as ArrayRef), + ])?; + let t2 = RecordBatch::try_from_iter(vec![( + "t2_id", + Arc::new(Int32Array::from(vec![12, 99])) as ArrayRef, + )])?; + 
ctx.register_batch("t1", t1)?; + ctx.register_batch("t2", t2)?; + + let sql = r#" + select t1.t1_id, + t1.t1_name, + t1.t1_int + from t1 + where t1.t1_id + 12 in ( + select t2.t2_id + 1 from t2 where t1.t1_int > 0 + ) + "#; + + let results = ctx.sql(sql).await?.collect().await?; + + assert_batches_sorted_eq!( + &[ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 1 | alpha | 1 |", + "+-------+---------+--------+", + ], + &results + ); + + Ok(()) +} + +#[tokio::test] +async fn natural_join_union_regression() -> Result<()> { + let ctx = SessionContext::new(); + let t1 = RecordBatch::try_from_iter(vec![ + ("v0", Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef), + ( + "v2", + Arc::new(Int32Array::from(vec![None, Some(5)])) as ArrayRef, + ), + ])?; + // Keep `v2` only on the left side so the natural join key remains `v0`. + let t2 = RecordBatch::try_from_iter(vec![( + "v0", + Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef, + )])?; + ctx.register_batch("t1", t1)?; + ctx.register_batch("t2", t2)?; + + let sql = r#" + SELECT t1.v2, t1.v0 FROM t2 NATURAL JOIN t1 + UNION ALL + SELECT t1.v2, t1.v0 FROM t2 NATURAL JOIN t1 WHERE (t1.v2 IS NULL) + "#; + + let results = ctx.sql(sql).await?.collect().await?; + + assert_batches_sorted_eq!( + &[ + "+----+----+", + "| v2 | v0 |", + "+----+----+", + "| | 1 |", + "| | 1 |", + "| 5 | 2 |", + "+----+----+", + ], + &results + ); + + Ok(()) +} From 115b9f276f5c940ce84c04f38e463bc85f010c1e Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 19 Mar 2026 17:17:28 +0800 Subject: [PATCH 09/63] Refactor null predicate evaluation and add tests Extract authoritative_restrict_null_predicate() to directly access the old physical-evaluation path. Implement a debug check in is_restrict_null_predicate() to ensure fast-path results match the authoritative evaluator. Add a unit test to verify consistency between both paths, enhancing debug capabilities and test coverage. 
--- datafusion/optimizer/src/utils.rs | 119 ++++++++++++++++++++++++++---- 1 file changed, 103 insertions(+), 16 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 33d1f9d01d0ac..e10b7d2998d7e 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -95,27 +95,23 @@ pub fn is_restrict_null_predicate<'a>( if let Some(is_restricting) = syntactic_restrict_null_predicate(&predicate, &join_cols) { + #[cfg(debug_assertions)] + { + let authoritative = authoritative_restrict_null_predicate( + predicate.clone(), + join_cols.iter().copied(), + )?; + debug_assert_eq!( + is_restricting, authoritative, + "syntactic fast path disagrees with authoritative null-restriction evaluation for predicate: {predicate}" + ); + } return Ok(is_restricting); } // If result is single `true`, return false; // If result is single `NULL` or `false`, return true; - Ok( - match evaluate_expr_with_null_column(predicate, join_cols.into_iter())? { - ColumnarValue::Array(array) => { - if array.len() == 1 { - let boolean_array = as_boolean_array(&array)?; - boolean_array.is_null(0) || !boolean_array.value(0) - } else { - false - } - } - ColumnarValue::Scalar(scalar) => matches!( - scalar, - ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false)) - ), - }, - ) + authoritative_restrict_null_predicate(predicate, join_cols.into_iter()) } #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -666,6 +662,28 @@ fn evaluate_expr_with_null_column<'a>( .evaluate(&input_batch) } +fn authoritative_restrict_null_predicate<'a>( + predicate: Expr, + join_cols_of_predicate: impl IntoIterator, +) -> Result { + Ok( + match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? 
{ + ColumnarValue::Array(array) => { + if array.len() == 1 { + let boolean_array = as_boolean_array(&array)?; + boolean_array.is_null(0) || !boolean_array.value(0) + } else { + false + } + } + ColumnarValue::Scalar(scalar) => matches!( + scalar, + ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false)) + ), + }, + ) +} + fn coerce(expr: Expr, schema: &DFSchema) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; expr.rewrite(&mut expr_rewrite).data() @@ -820,4 +838,73 @@ mod tests { Ok(()) } + + #[test] + fn syntactic_fast_path_matches_authoritative_evaluator() -> Result<()> { + let test_cases = vec![ + is_null(col("a")), + Expr::IsNotNull(Box::new(col("a"))), + binary_expr(col("a"), Operator::Gt, lit(8i64)), + binary_expr(col("a"), Operator::Eq, lit(ScalarValue::Null)), + binary_expr(col("a"), Operator::And, lit(true)), + binary_expr(col("a"), Operator::Or, lit(false)), + Expr::Not(Box::new(col("a").is_true())), + col("a").is_true(), + col("a").is_false(), + col("a").is_unknown(), + col("a").is_not_true(), + col("a").is_not_false(), + col("a").is_not_unknown(), + col("a").between(lit(1i64), lit(10i64)), + binary_expr( + when(Expr::IsNotNull(Box::new(col("a"))), col("a")) + .otherwise(col("b"))?, + Operator::Gt, + lit(2i64), + ), + case(col("a")) + .when(lit(1i64), lit(true)) + .otherwise(lit(false))?, + case(col("a")) + .when(lit(0i64), lit(false)) + .otherwise(lit(true))?, + binary_expr( + case(col("a")) + .when(lit(0i64), lit(true)) + .otherwise(lit(false))?, + Operator::Or, + lit(false), + ), + binary_expr( + case(lit(1i64)) + .when(lit(1i64), lit(ScalarValue::Null)) + .otherwise(lit(false))?, + Operator::IsNotDistinctFrom, + lit(true), + ), + ]; + + for predicate in test_cases { + let join_cols = predicate.column_refs(); + if let Some(syntactic) = + syntactic_restrict_null_predicate(&predicate, &join_cols) + { + let authoritative = authoritative_restrict_null_predicate( + predicate.clone(), + join_cols.iter().copied(), + ) + 
.unwrap_or_else(|error| { + panic!( + "authoritative evaluator failed for predicate `{predicate}`: {error}" + ) + }); + assert_eq!( + syntactic, authoritative, + "syntactic fast path disagrees with authoritative evaluator for predicate: {predicate}", + ); + } + } + + Ok(()) + } } From 38680f5327b8e0d67622e643d200f2bbe7c5066e Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 19 Mar 2026 17:24:28 +0800 Subject: [PATCH 10/63] Refactor null handling in expression evaluation Remove early-return handling for complex constructs like CASE and scalar functions. Narrow fast path to a limited subset including aliases, literals, direct join-key column substitution, and specific operators. All other cases now utilize evaluate_expr_with_null_column(...) for authoritative evaluation. --- datafusion/optimizer/src/utils.rs | 407 +++++++----------------------- 1 file changed, 93 insertions(+), 314 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index e10b7d2998d7e..e46a3de2f2762 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -116,30 +116,21 @@ pub fn is_restrict_null_predicate<'a>( #[derive(Clone, Copy, Debug, PartialEq, Eq)] enum NullSubstitutionValue { - Unknown, Null, NonNull, Boolean(bool), } -impl NullSubstitutionValue { - fn is_null(self) -> bool { - matches!(self, Self::Null) - } - - fn is_definitely_non_null(self) -> bool { - matches!(self, Self::NonNull | Self::Boolean(_)) - } -} - fn syntactic_restrict_null_predicate( predicate: &Expr, join_cols: &HashSet<&Column>, ) -> Option { match syntactic_null_substitution_value(predicate, join_cols) { - NullSubstitutionValue::Boolean(true) => Some(false), - NullSubstitutionValue::Boolean(false) | NullSubstitutionValue::Null => Some(true), - NullSubstitutionValue::Unknown | NullSubstitutionValue::NonNull => None, + Some(NullSubstitutionValue::Boolean(true)) => Some(false), + Some(NullSubstitutionValue::Boolean(false) | 
NullSubstitutionValue::Null) => { + Some(true) + } + Some(NullSubstitutionValue::NonNull) | None => None, } } @@ -156,87 +147,43 @@ fn predicate_uses_only_columns( fn syntactic_null_substitution_value( expr: &Expr, join_cols: &HashSet<&Column>, -) -> NullSubstitutionValue { +) -> Option { match expr { Expr::Alias(alias) => { syntactic_null_substitution_value(alias.expr.as_ref(), join_cols) } Expr::Column(column) => { if join_cols.contains(column) { - NullSubstitutionValue::Null + Some(NullSubstitutionValue::Null) } else { - NullSubstitutionValue::Unknown + None } } - Expr::Literal(value, _) => scalar_to_null_substitution_value(value), + Expr::Literal(value, _) => Some(scalar_to_null_substitution_value(value)), Expr::BinaryExpr(binary_expr) => syntactic_binary_value(binary_expr, join_cols), Expr::Not(expr) => { sql_not(syntactic_null_substitution_value(expr.as_ref(), join_cols)) } Expr::IsNull(expr) => { match syntactic_null_substitution_value(expr.as_ref(), join_cols) { - NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(true), - NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_) => { - NullSubstitutionValue::Boolean(false) + Some(NullSubstitutionValue::Null) => { + Some(NullSubstitutionValue::Boolean(true)) } - NullSubstitutionValue::Unknown => NullSubstitutionValue::Unknown, + Some( + NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_), + ) => Some(NullSubstitutionValue::Boolean(false)), + None => None, } } Expr::IsNotNull(expr) => { match syntactic_null_substitution_value(expr.as_ref(), join_cols) { - NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(false), - NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_) => { - NullSubstitutionValue::Boolean(true) + Some(NullSubstitutionValue::Null) => { + Some(NullSubstitutionValue::Boolean(false)) } - NullSubstitutionValue::Unknown => NullSubstitutionValue::Unknown, - } - } - Expr::IsTrue(expr) => boolean_test_result( - 
syntactic_null_substitution_value(expr.as_ref(), join_cols), - false, - ), - Expr::IsFalse(expr) => { - match syntactic_null_substitution_value(expr.as_ref(), join_cols) { - NullSubstitutionValue::Boolean(value) => { - NullSubstitutionValue::Boolean(!value) - } - NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(false), - NullSubstitutionValue::Unknown | NullSubstitutionValue::NonNull => { - NullSubstitutionValue::Unknown - } - } - } - Expr::IsUnknown(expr) => { - match syntactic_null_substitution_value(expr.as_ref(), join_cols) { - NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(true), - NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_) => { - NullSubstitutionValue::Boolean(false) - } - NullSubstitutionValue::Unknown => NullSubstitutionValue::Unknown, - } - } - Expr::IsNotTrue(expr) => boolean_test_result( - syntactic_null_substitution_value(expr.as_ref(), join_cols), - true, - ), - Expr::IsNotFalse(expr) => { - match syntactic_null_substitution_value(expr.as_ref(), join_cols) { - NullSubstitutionValue::Boolean(value) => { - NullSubstitutionValue::Boolean(value) - } - NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(true), - NullSubstitutionValue::Unknown | NullSubstitutionValue::NonNull => { - NullSubstitutionValue::Unknown - } - } - } - Expr::IsNotUnknown(expr) => { - match syntactic_null_substitution_value(expr.as_ref(), join_cols) { - NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(false), - NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_) => { - NullSubstitutionValue::Boolean(true) - } - NullSubstitutionValue::Unknown => NullSubstitutionValue::Unknown, + Some( + NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_), + ) => Some(NullSubstitutionValue::Boolean(true)), + None => None, } } Expr::Between(between) => { @@ -245,13 +192,15 @@ fn syntactic_null_substitution_value( let low = syntactic_null_substitution_value(between.low.as_ref(), join_cols); let high = 
syntactic_null_substitution_value(between.high.as_ref(), join_cols); - if expr.is_null() || low.is_null() || high.is_null() { - NullSubstitutionValue::Null + if matches!(expr, Some(NullSubstitutionValue::Null)) + || matches!(low, Some(NullSubstitutionValue::Null)) + || matches!(high, Some(NullSubstitutionValue::Null)) + { + Some(NullSubstitutionValue::Null) } else { - NullSubstitutionValue::Unknown + None } } - Expr::Case(case) => syntactic_case_value(case, join_cols), Expr::Cast(cast) => strict_null_passthrough(cast.expr.as_ref(), join_cols), Expr::TryCast(try_cast) => { strict_null_passthrough(try_cast.expr.as_ref(), join_cols) @@ -261,15 +210,14 @@ fn syntactic_null_substitution_value( let value = syntactic_null_substitution_value(like.expr.as_ref(), join_cols); let pattern = syntactic_null_substitution_value(like.pattern.as_ref(), join_cols); - if value.is_null() || pattern.is_null() { - NullSubstitutionValue::Null + if matches!(value, Some(NullSubstitutionValue::Null)) + || matches!(pattern, Some(NullSubstitutionValue::Null)) + { + Some(NullSubstitutionValue::Null) } else { - NullSubstitutionValue::Unknown + None } } - Expr::ScalarFunction(function) => { - syntactic_scalar_function_value(function.name(), &function.args, join_cols) - } Expr::Exists { .. } | Expr::InList(_) | Expr::InSubquery(_) @@ -280,11 +228,19 @@ fn syntactic_null_substitution_value( | Expr::ScalarVariable(_, _) | Expr::Unnest(_) | Expr::GroupingSet(_) - | Expr::WindowFunction(_) => NullSubstitutionValue::Unknown, - Expr::AggregateFunction(_) => NullSubstitutionValue::Unknown, + | Expr::WindowFunction(_) + | Expr::ScalarFunction(_) + | Expr::Case(_) + | Expr::IsTrue(_) + | Expr::IsFalse(_) + | Expr::IsUnknown(_) + | Expr::IsNotTrue(_) + | Expr::IsNotFalse(_) + | Expr::IsNotUnknown(_) => None, + Expr::AggregateFunction(_) => None, // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] - Expr::Wildcard { .. 
} => NullSubstitutionValue::Unknown, + Expr::Wildcard { .. } => None, } } @@ -301,59 +257,47 @@ fn scalar_to_null_substitution_value(value: &ScalarValue) -> NullSubstitutionVal fn strict_null_passthrough( expr: &Expr, join_cols: &HashSet<&Column>, -) -> NullSubstitutionValue { - if syntactic_null_substitution_value(expr, join_cols).is_null() { - NullSubstitutionValue::Null +) -> Option { + if matches!( + syntactic_null_substitution_value(expr, join_cols), + Some(NullSubstitutionValue::Null) + ) { + Some(NullSubstitutionValue::Null) } else { - NullSubstitutionValue::Unknown - } -} - -fn boolean_test_result( - value: NullSubstitutionValue, - default_for_null: bool, -) -> NullSubstitutionValue { - match value { - NullSubstitutionValue::Boolean(value) => NullSubstitutionValue::Boolean(value), - NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(default_for_null), - NullSubstitutionValue::Unknown | NullSubstitutionValue::NonNull => { - NullSubstitutionValue::Unknown - } + None } } -fn sql_not(value: NullSubstitutionValue) -> NullSubstitutionValue { +fn sql_not(value: Option) -> Option { match value { - NullSubstitutionValue::Boolean(value) => NullSubstitutionValue::Boolean(!value), - NullSubstitutionValue::Null => NullSubstitutionValue::Null, - NullSubstitutionValue::Unknown | NullSubstitutionValue::NonNull => { - NullSubstitutionValue::Unknown + Some(NullSubstitutionValue::Boolean(value)) => { + Some(NullSubstitutionValue::Boolean(!value)) } + Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Null), + Some(NullSubstitutionValue::NonNull) | None => None, } } fn sql_and( - left: NullSubstitutionValue, - right: NullSubstitutionValue, -) -> NullSubstitutionValue { - if matches!(left, NullSubstitutionValue::Boolean(false)) - || matches!(right, NullSubstitutionValue::Boolean(false)) + left: Option, + right: Option, +) -> Option { + if matches!(left, Some(NullSubstitutionValue::Boolean(false))) + || matches!(right, 
Some(NullSubstitutionValue::Boolean(false))) { - return NullSubstitutionValue::Boolean(false); + return Some(NullSubstitutionValue::Boolean(false)); } match (left, right) { - (NullSubstitutionValue::Boolean(true), value) - | (value, NullSubstitutionValue::Boolean(true)) => value, - (NullSubstitutionValue::Null, NullSubstitutionValue::Null) => { - NullSubstitutionValue::Null - } - (NullSubstitutionValue::Null, NullSubstitutionValue::Unknown) - | (NullSubstitutionValue::Unknown, NullSubstitutionValue::Null) - | (NullSubstitutionValue::Unknown, _) - | (_, NullSubstitutionValue::Unknown) - | (NullSubstitutionValue::NonNull, _) - | (_, NullSubstitutionValue::NonNull) => NullSubstitutionValue::Unknown, + (Some(NullSubstitutionValue::Boolean(true)), value) + | (value, Some(NullSubstitutionValue::Boolean(true))) => value, + (Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { + Some(NullSubstitutionValue::Null) + } + (Some(NullSubstitutionValue::NonNull), _) + | (_, Some(NullSubstitutionValue::NonNull)) + | (None, _) + | (_, None) => None, (left, right) => { debug_assert_eq!(left, right); left @@ -362,27 +306,25 @@ fn sql_and( } fn sql_or( - left: NullSubstitutionValue, - right: NullSubstitutionValue, -) -> NullSubstitutionValue { - if matches!(left, NullSubstitutionValue::Boolean(true)) - || matches!(right, NullSubstitutionValue::Boolean(true)) + left: Option, + right: Option, +) -> Option { + if matches!(left, Some(NullSubstitutionValue::Boolean(true))) + || matches!(right, Some(NullSubstitutionValue::Boolean(true))) { - return NullSubstitutionValue::Boolean(true); + return Some(NullSubstitutionValue::Boolean(true)); } match (left, right) { - (NullSubstitutionValue::Boolean(false), value) - | (value, NullSubstitutionValue::Boolean(false)) => value, - (NullSubstitutionValue::Null, NullSubstitutionValue::Null) => { - NullSubstitutionValue::Null - } - (NullSubstitutionValue::Null, NullSubstitutionValue::Unknown) - | (NullSubstitutionValue::Unknown, 
NullSubstitutionValue::Null) - | (NullSubstitutionValue::Unknown, _) - | (_, NullSubstitutionValue::Unknown) - | (NullSubstitutionValue::NonNull, _) - | (_, NullSubstitutionValue::NonNull) => NullSubstitutionValue::Unknown, + (Some(NullSubstitutionValue::Boolean(false)), value) + | (value, Some(NullSubstitutionValue::Boolean(false))) => value, + (Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { + Some(NullSubstitutionValue::Null) + } + (Some(NullSubstitutionValue::NonNull), _) + | (_, Some(NullSubstitutionValue::NonNull)) + | (None, _) + | (_, None) => None, (left, right) => { debug_assert_eq!(left, right); left @@ -393,19 +335,13 @@ fn sql_or( fn syntactic_binary_value( binary_expr: &datafusion_expr::BinaryExpr, join_cols: &HashSet<&Column>, -) -> NullSubstitutionValue { +) -> Option { let left = syntactic_null_substitution_value(binary_expr.left.as_ref(), join_cols); let right = syntactic_null_substitution_value(binary_expr.right.as_ref(), join_cols); match binary_expr.op { datafusion_expr::Operator::And => sql_and(left, right), datafusion_expr::Operator::Or => sql_or(left, right), - datafusion_expr::Operator::IsDistinctFrom => { - syntactic_is_distinct_from(left, right) - } - datafusion_expr::Operator::IsNotDistinctFrom => { - sql_not(syntactic_is_distinct_from(left, right)) - } datafusion_expr::Operator::Eq | datafusion_expr::Operator::NotEq | datafusion_expr::Operator::Lt @@ -445,176 +381,19 @@ fn syntactic_binary_value( | datafusion_expr::Operator::QuestionAnd | datafusion_expr::Operator::QuestionPipe | datafusion_expr::Operator::Colon => { - if left.is_null() || right.is_null() { - NullSubstitutionValue::Null - } else { - NullSubstitutionValue::Unknown - } - } - } -} - -fn syntactic_is_distinct_from( - left: NullSubstitutionValue, - right: NullSubstitutionValue, -) -> NullSubstitutionValue { - match (left, right) { - (NullSubstitutionValue::Null, NullSubstitutionValue::Null) => { - NullSubstitutionValue::Boolean(false) - } - 
(NullSubstitutionValue::Null, value) | (value, NullSubstitutionValue::Null) => { - if value.is_definitely_non_null() { - NullSubstitutionValue::Boolean(true) + if matches!(left, Some(NullSubstitutionValue::Null)) + || matches!(right, Some(NullSubstitutionValue::Null)) + { + Some(NullSubstitutionValue::Null) } else { - NullSubstitutionValue::Unknown + None } } - (NullSubstitutionValue::Boolean(left), NullSubstitutionValue::Boolean(right)) => { - NullSubstitutionValue::Boolean(left != right) - } - _ => NullSubstitutionValue::Unknown, + datafusion_expr::Operator::IsDistinctFrom + | datafusion_expr::Operator::IsNotDistinctFrom => None, } } -fn syntactic_case_value( - case: &datafusion_expr::expr::Case, - join_cols: &HashSet<&Column>, -) -> NullSubstitutionValue { - if let Some(base_expr) = case.expr.as_deref() { - let base_value = syntactic_null_substitution_value(base_expr, join_cols); - let mut saw_indeterminate_comparison = false; - if base_value.is_null() { - return case - .else_expr - .as_deref() - .map_or(NullSubstitutionValue::Null, |expr| { - syntactic_null_substitution_value(expr, join_cols) - }); - } - - for (when_expr, then_expr) in &case.when_then_expr { - let when_value = syntactic_null_substitution_value(when_expr, join_cols); - match syntactic_equals(base_value, when_value) { - Some(true) => { - return syntactic_null_substitution_value( - then_expr.as_ref(), - join_cols, - ); - } - Some(false) => continue, - None => { - saw_indeterminate_comparison = true; - continue; - } - } - } - - if saw_indeterminate_comparison { - return NullSubstitutionValue::Unknown; - } - - return case - .else_expr - .as_deref() - .map_or(NullSubstitutionValue::Null, |expr| { - syntactic_null_substitution_value(expr, join_cols) - }); - } - - for (when_expr, then_expr) in &case.when_then_expr { - match syntactic_null_substitution_value(when_expr.as_ref(), join_cols) { - NullSubstitutionValue::Boolean(true) => { - return syntactic_null_substitution_value(then_expr.as_ref(), 
join_cols); - } - NullSubstitutionValue::Boolean(false) | NullSubstitutionValue::Null => {} - NullSubstitutionValue::Unknown | NullSubstitutionValue::NonNull => { - return NullSubstitutionValue::Unknown; - } - } - } - - case.else_expr - .as_deref() - .map_or(NullSubstitutionValue::Null, |expr| { - syntactic_null_substitution_value(expr, join_cols) - }) -} - -fn syntactic_equals( - left: NullSubstitutionValue, - right: NullSubstitutionValue, -) -> Option { - match (left, right) { - (NullSubstitutionValue::Null, _) | (_, NullSubstitutionValue::Null) => None, - (NullSubstitutionValue::Boolean(left), NullSubstitutionValue::Boolean(right)) => { - Some(left == right) - } - _ => None, - } -} - -fn syntactic_scalar_function_value( - name: &str, - args: &[Expr], - join_cols: &HashSet<&Column>, -) -> NullSubstitutionValue { - if name.eq_ignore_ascii_case("coalesce") { - return syntactic_coalesce_value(args, join_cols); - } - - let arg_values = args - .iter() - .map(|expr| syntactic_null_substitution_value(expr, join_cols)) - .collect::>(); - - if arg_values.iter().any(|value| value.is_null()) - && is_null_propagating_scalar_function(name) - { - NullSubstitutionValue::Null - } else { - NullSubstitutionValue::Unknown - } -} - -fn syntactic_coalesce_value( - args: &[Expr], - join_cols: &HashSet<&Column>, -) -> NullSubstitutionValue { - let mut saw_unknown = false; - - for expr in args { - match syntactic_null_substitution_value(expr, join_cols) { - NullSubstitutionValue::Null => continue, - NullSubstitutionValue::NonNull => return NullSubstitutionValue::NonNull, - NullSubstitutionValue::Boolean(value) => { - return NullSubstitutionValue::Boolean(value); - } - NullSubstitutionValue::Unknown => saw_unknown = true, - } - } - - if saw_unknown { - NullSubstitutionValue::Unknown - } else { - NullSubstitutionValue::Null - } -} - -fn is_null_propagating_scalar_function(name: &str) -> bool { - matches!( - name, - "btrim" - | "char_length" - | "length" - | "lower" - | "regexp_like" - | 
"regexp_replace" - | "to_timestamp" - | "trim" - | "upper" - ) -} - /// Determines if an expression will always evaluate to null. /// `c0 + 8` return true /// `c0 IS NULL` return false From ebd70ef6e714c30873be1dd2fb38b5f5f1159fed Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 19 Mar 2026 22:08:10 +0800 Subject: [PATCH 11/63] Test window scalar subquery optimizer delta Add window_scalar_subquery_optimizer_delta to compare query execution plans with different push_down_filter settings. Implement test hook for easier comparison without manual edits. Confirm that optimizer changes the plan from Filter + Cross Join to Inner Join, indicating a key delta for this case. --- .../tests/sql/push_down_filter_regressions.rs | 92 ++++++++++++++++--- datafusion/optimizer/src/utils.rs | 77 +++++++++++++--- 2 files changed, 141 insertions(+), 28 deletions(-) diff --git a/datafusion/core/tests/sql/push_down_filter_regressions.rs b/datafusion/core/tests/sql/push_down_filter_regressions.rs index 1069e75a26f37..cb04a4a102bc2 100644 --- a/datafusion/core/tests/sql/push_down_filter_regressions.rs +++ b/datafusion/core/tests/sql/push_down_filter_regressions.rs @@ -19,24 +19,52 @@ use std::sync::Arc; use super::*; use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; +use datafusion_optimizer::utils::{ + NullRestrictionEvalMode, set_null_restriction_eval_mode_for_test, +}; + +const WINDOW_SCALAR_SUBQUERY_SQL: &str = r#" + WITH suppliers AS ( + SELECT * + FROM (VALUES (1, 10.0), (1, 20.0)) AS t(nation, acctbal) + ) + SELECT + ROW_NUMBER() OVER (PARTITION BY nation ORDER BY acctbal DESC) AS rn + FROM suppliers AS s + WHERE acctbal > ( + SELECT AVG(acctbal) FROM suppliers + ) +"#; + +fn sqllogictest_style_ctx(push_down_filter_enabled: bool) -> SessionContext { + let ctx = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(4)); + if !push_down_filter_enabled { + assert!(ctx.remove_optimizer_rule("push_down_filter")); + } + ctx +} + +async fn 
capture_window_scalar_subquery_plans( + push_down_filter_enabled: bool, + null_restriction_eval_mode: NullRestrictionEvalMode, +) -> Result<(String, String)> { + let _mode_guard = set_null_restriction_eval_mode_for_test(null_restriction_eval_mode); + let ctx = sqllogictest_style_ctx(push_down_filter_enabled); + let df = ctx.sql(WINDOW_SCALAR_SUBQUERY_SQL).await?; + let optimized_plan = df.clone().into_optimized_plan()?; + let physical_plan = df.create_physical_plan().await?; + + Ok(( + optimized_plan.display_indent_schema().to_string(), + displayable(physical_plan.as_ref()).indent(true).to_string(), + )) +} #[tokio::test] async fn window_scalar_subquery_regression() -> Result<()> { let ctx = SessionContext::new(); - let sql = r#" - WITH suppliers AS ( - SELECT * - FROM (VALUES (1, 10.0), (1, 20.0)) AS t(nation, acctbal) - ) - SELECT - ROW_NUMBER() OVER (PARTITION BY nation ORDER BY acctbal DESC) AS rn - FROM suppliers AS s - WHERE acctbal > ( - SELECT AVG(acctbal) FROM suppliers - ) - "#; - - let results = ctx.sql(sql).await?.collect().await?; + let results = ctx.sql(WINDOW_SCALAR_SUBQUERY_SQL).await?.collect().await?; assert_batches_eq!( &["+----+", "| rn |", "+----+", "| 1 |", "+----+",], @@ -168,3 +196,39 @@ async fn natural_join_union_regression() -> Result<()> { Ok(()) } + +#[tokio::test(flavor = "current_thread")] +async fn window_scalar_subquery_optimizer_delta() -> Result<()> { + let (enabled_optimized, enabled_physical) = + capture_window_scalar_subquery_plans(true, NullRestrictionEvalMode::Auto).await?; + let (disabled_optimized, disabled_physical) = + capture_window_scalar_subquery_plans(false, NullRestrictionEvalMode::Auto) + .await?; + let (authoritative_optimized, authoritative_physical) = + capture_window_scalar_subquery_plans( + true, + NullRestrictionEvalMode::AuthoritativeOnly, + ) + .await?; + + assert!(enabled_optimized.contains( + "Inner Join: Filter: s.acctbal > __scalar_sq_1.avg(suppliers.acctbal)" + )); + assert!( + disabled_optimized + 
.contains("Filter: s.acctbal > __scalar_sq_1.avg(suppliers.acctbal)") + ); + assert!(disabled_optimized.contains("Cross Join:")); + + assert!(enabled_physical.contains("NestedLoopJoinExec: join_type=Inner")); + assert!(enabled_physical.contains("filter=acctbal@0 > avg(suppliers.acctbal)@1")); + assert!( + disabled_physical.contains("FilterExec: acctbal@1 > avg(suppliers.acctbal)@2") + ); + assert!(disabled_physical.contains("CrossJoinExec")); + + assert_eq!(authoritative_optimized, enabled_optimized); + assert_eq!(authoritative_physical, enabled_physical); + + Ok(()) +} diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index e46a3de2f2762..e02a5206c3e81 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -18,6 +18,8 @@ //! Utility functions leveraged by the query optimizer rules use std::collections::{BTreeSet, HashMap, HashSet}; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::{Arc, Mutex, MutexGuard, OnceLock}; use crate::analyzer::type_coercion::TypeCoercionRewriter; use arrow::array::{Array, RecordBatch, new_null_array}; @@ -30,7 +32,6 @@ use datafusion_expr::expr_rewriter::replace_col; use datafusion_expr::{ColumnarValue, Expr, logical_plan::LogicalPlan}; use datafusion_physical_expr::create_physical_expr; use log::{debug, trace}; -use std::sync::Arc; /// Re-export of `NamesPreserver` for backwards compatibility, /// as it was initially placed here and then moved elsewhere. 
@@ -68,6 +69,52 @@ pub fn log_plan(description: &str, plan: &LogicalPlan) { trace!("{description}::\n{}\n", plan.display_indent_schema()); } +#[doc(hidden)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum NullRestrictionEvalMode { + Auto = 0, + AuthoritativeOnly = 1, +} + +static NULL_RESTRICTION_EVAL_MODE: AtomicU8 = + AtomicU8::new(NullRestrictionEvalMode::Auto as u8); +static NULL_RESTRICTION_EVAL_MODE_LOCK: OnceLock> = OnceLock::new(); + +fn null_restriction_eval_mode() -> NullRestrictionEvalMode { + match NULL_RESTRICTION_EVAL_MODE.load(Ordering::Relaxed) { + 1 => NullRestrictionEvalMode::AuthoritativeOnly, + _ => NullRestrictionEvalMode::Auto, + } +} + +#[doc(hidden)] +pub struct NullRestrictionEvalModeGuard { + previous_mode: NullRestrictionEvalMode, + _lock: MutexGuard<'static, ()>, +} + +impl Drop for NullRestrictionEvalModeGuard { + fn drop(&mut self) { + NULL_RESTRICTION_EVAL_MODE.store(self.previous_mode as u8, Ordering::Relaxed); + } +} + +#[doc(hidden)] +pub fn set_null_restriction_eval_mode_for_test( + mode: NullRestrictionEvalMode, +) -> NullRestrictionEvalModeGuard { + let lock = NULL_RESTRICTION_EVAL_MODE_LOCK + .get_or_init(|| Mutex::new(())) + .lock() + .expect("null restriction mode lock poisoned"); + let previous_mode = null_restriction_eval_mode(); + NULL_RESTRICTION_EVAL_MODE.store(mode as u8, Ordering::Relaxed); + NullRestrictionEvalModeGuard { + previous_mode, + _lock: lock, + } +} + /// Determine whether a predicate can restrict NULLs. e.g. /// `c0 > 8` return true; /// `c0 IS NULL` return false. 
@@ -92,21 +139,23 @@ pub fn is_restrict_null_predicate<'a>( return Ok(false); } - if let Some(is_restricting) = - syntactic_restrict_null_predicate(&predicate, &join_cols) - { - #[cfg(debug_assertions)] + if null_restriction_eval_mode() == NullRestrictionEvalMode::Auto { + if let Some(is_restricting) = + syntactic_restrict_null_predicate(&predicate, &join_cols) { - let authoritative = authoritative_restrict_null_predicate( - predicate.clone(), - join_cols.iter().copied(), - )?; - debug_assert_eq!( - is_restricting, authoritative, - "syntactic fast path disagrees with authoritative null-restriction evaluation for predicate: {predicate}" - ); + #[cfg(debug_assertions)] + { + let authoritative = authoritative_restrict_null_predicate( + predicate.clone(), + join_cols.iter().copied(), + )?; + debug_assert_eq!( + is_restricting, authoritative, + "syntactic fast path disagrees with authoritative null-restriction evaluation for predicate: {predicate}" + ); + } + return Ok(is_restricting); } - return Ok(is_restricting); } // If result is single `true`, return false; From 9665c636d0799a81fb0b27f289f20b279fe0508c Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 19 Mar 2026 22:17:51 +0800 Subject: [PATCH 12/63] Avoid join filter rewrite for scalar subqueries Prevent conversion of post-join filters to join filters when the join lacks equijoin keys and one side is scalar. This change maintains the filter + cross join structure for failing scalar-subquery shapes. Added tests for optimizer-level guards and regression coverage. 
--- .../tests/sql/push_down_filter_regressions.rs | 29 +++++++++++--- datafusion/optimizer/src/push_down_filter.rs | 38 ++++++++++++++++++- 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/datafusion/core/tests/sql/push_down_filter_regressions.rs b/datafusion/core/tests/sql/push_down_filter_regressions.rs index cb04a4a102bc2..d683fd068c797 100644 --- a/datafusion/core/tests/sql/push_down_filter_regressions.rs +++ b/datafusion/core/tests/sql/push_down_filter_regressions.rs @@ -74,6 +74,19 @@ async fn window_scalar_subquery_regression() -> Result<()> { Ok(()) } +#[tokio::test] +async fn window_scalar_subquery_sqllogictest_style_regression() -> Result<()> { + let ctx = sqllogictest_style_ctx(true); + let results = ctx.sql(WINDOW_SCALAR_SUBQUERY_SQL).await?.collect().await?; + + assert_batches_eq!( + &["+----+", "| rn |", "+----+", "| 1 |", "+----+",], + &results + ); + + Ok(()) +} + #[tokio::test] async fn aggregate_regr_functions_regression() -> Result<()> { let ctx = SessionContext::new(); @@ -211,22 +224,28 @@ async fn window_scalar_subquery_optimizer_delta() -> Result<()> { ) .await?; - assert!(enabled_optimized.contains( - "Inner Join: Filter: s.acctbal > __scalar_sq_1.avg(suppliers.acctbal)" - )); + assert!( + enabled_optimized + .contains("Filter: s.acctbal > __scalar_sq_1.avg(suppliers.acctbal)") + ); + assert!(enabled_optimized.contains("Cross Join:")); assert!( disabled_optimized .contains("Filter: s.acctbal > __scalar_sq_1.avg(suppliers.acctbal)") ); assert!(disabled_optimized.contains("Cross Join:")); - assert!(enabled_physical.contains("NestedLoopJoinExec: join_type=Inner")); - assert!(enabled_physical.contains("filter=acctbal@0 > avg(suppliers.acctbal)@1")); + assert!( + enabled_physical.contains("FilterExec: acctbal@1 > avg(suppliers.acctbal)@2") + ); + assert!(enabled_physical.contains("CrossJoinExec")); assert!( disabled_physical.contains("FilterExec: acctbal@1 > avg(suppliers.acctbal)@2") ); 
assert!(disabled_physical.contains("CrossJoinExec")); + assert_eq!(enabled_optimized, disabled_optimized); + assert_eq!(enabled_physical, disabled_physical); assert_eq!(authoritative_optimized, enabled_optimized); assert_eq!(authoritative_physical, enabled_physical); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 9562ab818824f..6ec374261153e 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -412,6 +412,9 @@ fn push_down_all_join( on_filter: Vec, ) -> Result> { let is_inner_join = join.join_type == JoinType::Inner; + let allow_convert_filter_to_join_condition = !join.on.is_empty() + || !(matches!(join.left.max_rows(), Some(1)) + || matches!(join.right.max_rows(), Some(1))); // Get pushable predicates from current optimizer state let (left_preserved, right_preserved) = lr_is_preserved(join.join_type); @@ -431,7 +434,10 @@ fn push_down_all_join( left_push.push(predicate); } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); - } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? { + } else if is_inner_join + && allow_convert_filter_to_join_condition + && can_evaluate_as_join_condition(&predicate)? + { // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate // and convert to the join on condition join_conditions.push(predicate); @@ -1472,7 +1478,7 @@ mod tests { use crate::simplify_expressions::SimplifyExpressions; use crate::test::udfs::leaf_udf_expr; use crate::test::*; - use datafusion_expr::test::function_stub::sum; + use datafusion_expr::test::function_stub::{avg, sum}; use insta::assert_snapshot; use super::*; @@ -3574,6 +3580,34 @@ mod tests { ) } + #[test] + fn cross_join_with_scalar_side_keeps_post_join_filter() -> Result<()> { + let left = LogicalPlanBuilder::from(test_table_scan()?) + .project(vec![col("a"), col("b")])? 
+ .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) + .project(vec![col("a")])? + .aggregate(Vec::::new(), vec![avg(col("a"))])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .cross_join(right)? + .filter(col("test.b").gt(col("AVG(test1.a)")))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.b > AVG(test1.a) + Cross Join: + Projection: test.a, test.b + TableScan: test + Aggregate: groupBy=[[]], aggr=[[avg(test1.a)]] + Projection: test1.a + TableScan: test1 + " + ) + } + #[test] fn left_semi_join() -> Result<()> { let left = test_table_scan_with_name("test1")?; From 00dcba4a70a9301c9f64c479cd47e31e1cbba7bf Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 19 Mar 2026 22:29:00 +0800 Subject: [PATCH 13/63] new From 12de38255fdb3acf5502d269e7ec7e7df9733ca6 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 19 Mar 2026 23:09:25 +0800 Subject: [PATCH 14/63] Simplify null handling and filter push down logic Refactor null checks in utils.rs by collapsing fast-path gates and removing duplicated branches for various null handling scenarios. Extract reusable helpers for null checks to improve readability. In push_down_filter.rs, streamline the scalar-cross-join guard and eliminate inline boolean expressions from the push_down_all_join function for clearer logic. 
--- datafusion/optimizer/src/push_down_filter.rs | 11 +- datafusion/optimizer/src/utils.rs | 117 ++++++++++--------- 2 files changed, 67 insertions(+), 61 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 6ec374261153e..14ae63e1ad498 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -412,9 +412,8 @@ fn push_down_all_join( on_filter: Vec, ) -> Result> { let is_inner_join = join.join_type == JoinType::Inner; - let allow_convert_filter_to_join_condition = !join.on.is_empty() - || !(matches!(join.left.max_rows(), Some(1)) - || matches!(join.right.max_rows(), Some(1))); + let allow_convert_filter_to_join_condition = + allow_convert_filter_to_join_condition(&join); // Get pushable predicates from current optimizer state let (left_preserved, right_preserved) = lr_is_preserved(join.join_type); @@ -518,6 +517,12 @@ fn push_down_all_join( Ok(Transformed::yes(plan)) } +fn allow_convert_filter_to_join_condition(join: &Join) -> bool { + !join.on.is_empty() + || !(matches!(join.left.max_rows(), Some(1)) + || matches!(join.right.max_rows(), Some(1))) +} + fn push_down_join( join: Join, parent_predicate: Option<&Expr>, diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index e02a5206c3e81..519170f8b54f5 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -81,9 +81,12 @@ static NULL_RESTRICTION_EVAL_MODE: AtomicU8 = static NULL_RESTRICTION_EVAL_MODE_LOCK: OnceLock> = OnceLock::new(); fn null_restriction_eval_mode() -> NullRestrictionEvalMode { - match NULL_RESTRICTION_EVAL_MODE.load(Ordering::Relaxed) { - 1 => NullRestrictionEvalMode::AuthoritativeOnly, - _ => NullRestrictionEvalMode::Auto, + if NULL_RESTRICTION_EVAL_MODE.load(Ordering::Relaxed) + == NullRestrictionEvalMode::AuthoritativeOnly as u8 + { + NullRestrictionEvalMode::AuthoritativeOnly + } else { + 
NullRestrictionEvalMode::Auto } } @@ -139,28 +142,27 @@ pub fn is_restrict_null_predicate<'a>( return Ok(false); } - if null_restriction_eval_mode() == NullRestrictionEvalMode::Auto { - if let Some(is_restricting) = + if null_restriction_eval_mode() == NullRestrictionEvalMode::Auto + && let Some(is_restricting) = syntactic_restrict_null_predicate(&predicate, &join_cols) + { + #[cfg(debug_assertions)] { - #[cfg(debug_assertions)] - { - let authoritative = authoritative_restrict_null_predicate( - predicate.clone(), - join_cols.iter().copied(), - )?; - debug_assert_eq!( - is_restricting, authoritative, - "syntactic fast path disagrees with authoritative null-restriction evaluation for predicate: {predicate}" - ); - } - return Ok(is_restricting); + let authoritative = authoritative_restrict_null_predicate( + predicate.clone(), + join_cols.iter().copied(), + )?; + debug_assert_eq!( + is_restricting, authoritative, + "syntactic fast path disagrees with authoritative null-restriction evaluation for predicate: {predicate}" + ); } + return Ok(is_restricting); } // If result is single `true`, return false; // If result is single `NULL` or `false`, return true; - authoritative_restrict_null_predicate(predicate, join_cols.into_iter()) + authoritative_restrict_null_predicate(predicate, join_cols) } #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -193,6 +195,27 @@ fn predicate_uses_only_columns( .all(|column| allowed_columns.contains(*column)) } +fn contains_null( + values: impl IntoIterator>, +) -> bool { + values + .into_iter() + .any(|value| matches!(value, Some(NullSubstitutionValue::Null))) +} + +fn null_check_value( + value: Option, + when_non_null: bool, +) -> Option { + match value { + Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Boolean(false)), + Some(NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_)) => { + Some(NullSubstitutionValue::Boolean(when_non_null)) + } + None => None, + } +} + fn syntactic_null_substitution_value( expr: 
&Expr, join_cols: &HashSet<&Column>, @@ -213,38 +236,20 @@ fn syntactic_null_substitution_value( Expr::Not(expr) => { sql_not(syntactic_null_substitution_value(expr.as_ref(), join_cols)) } - Expr::IsNull(expr) => { - match syntactic_null_substitution_value(expr.as_ref(), join_cols) { - Some(NullSubstitutionValue::Null) => { - Some(NullSubstitutionValue::Boolean(true)) - } - Some( - NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_), - ) => Some(NullSubstitutionValue::Boolean(false)), - None => None, - } - } - Expr::IsNotNull(expr) => { - match syntactic_null_substitution_value(expr.as_ref(), join_cols) { - Some(NullSubstitutionValue::Null) => { - Some(NullSubstitutionValue::Boolean(false)) - } - Some( - NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_), - ) => Some(NullSubstitutionValue::Boolean(true)), - None => None, - } - } + Expr::IsNull(expr) => sql_not(null_check_value( + syntactic_null_substitution_value(expr.as_ref(), join_cols), + true, + )), + Expr::IsNotNull(expr) => null_check_value( + syntactic_null_substitution_value(expr.as_ref(), join_cols), + true, + ), Expr::Between(between) => { - let expr = - syntactic_null_substitution_value(between.expr.as_ref(), join_cols); - let low = syntactic_null_substitution_value(between.low.as_ref(), join_cols); - let high = - syntactic_null_substitution_value(between.high.as_ref(), join_cols); - if matches!(expr, Some(NullSubstitutionValue::Null)) - || matches!(low, Some(NullSubstitutionValue::Null)) - || matches!(high, Some(NullSubstitutionValue::Null)) - { + if contains_null([ + syntactic_null_substitution_value(between.expr.as_ref(), join_cols), + syntactic_null_substitution_value(between.low.as_ref(), join_cols), + syntactic_null_substitution_value(between.high.as_ref(), join_cols), + ]) { Some(NullSubstitutionValue::Null) } else { None @@ -256,12 +261,10 @@ fn syntactic_null_substitution_value( } Expr::Negative(expr) => strict_null_passthrough(expr.as_ref(), join_cols), 
Expr::Like(like) | Expr::SimilarTo(like) => { - let value = syntactic_null_substitution_value(like.expr.as_ref(), join_cols); - let pattern = - syntactic_null_substitution_value(like.pattern.as_ref(), join_cols); - if matches!(value, Some(NullSubstitutionValue::Null)) - || matches!(pattern, Some(NullSubstitutionValue::Null)) - { + if contains_null([ + syntactic_null_substitution_value(like.expr.as_ref(), join_cols), + syntactic_null_substitution_value(like.pattern.as_ref(), join_cols), + ]) { Some(NullSubstitutionValue::Null) } else { None @@ -430,9 +433,7 @@ fn syntactic_binary_value( | datafusion_expr::Operator::QuestionAnd | datafusion_expr::Operator::QuestionPipe | datafusion_expr::Operator::Colon => { - if matches!(left, Some(NullSubstitutionValue::Null)) - || matches!(right, Some(NullSubstitutionValue::Null)) - { + if contains_null([left, right]) { Some(NullSubstitutionValue::Null) } else { None From 4fdeb7c4eaeea241929de3048c9d82357376cad3 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 19 Mar 2026 23:10:46 +0800 Subject: [PATCH 15/63] clippy fix --- datafusion/optimizer/src/utils.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 519170f8b54f5..cc07321b3960a 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -162,7 +162,7 @@ pub fn is_restrict_null_predicate<'a>( // If result is single `true`, return false; // If result is single `NULL` or `false`, return true; - authoritative_restrict_null_predicate(predicate, join_cols) + authoritative_restrict_null_predicate(predicate, join_cols) } #[derive(Clone, Copy, Debug, PartialEq, Eq)] From 1d121724751f09b9c17358b5369f4cbba58e4902 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 20 Mar 2026 09:25:53 +0800 Subject: [PATCH 16/63] Amend benchmark --- .../core/benches/sql_planner_extended.rs | 211 +++++++++--------- 1 file changed, 110 insertions(+), 101 deletions(-) diff 
--git a/datafusion/core/benches/sql_planner_extended.rs b/datafusion/core/benches/sql_planner_extended.rs index d4955313c79c3..ccd7d3f3f031b 100644 --- a/datafusion/core/benches/sql_planner_extended.rs +++ b/datafusion/core/benches/sql_planner_extended.rs @@ -27,12 +27,17 @@ use datafusion_expr::{cast, col, lit, not, try_cast, when}; use datafusion_functions::expr_fn::{ btrim, length, regexp_like, regexp_replace, to_timestamp, upper, }; +use std::env; use std::fmt::Write; use std::hint::black_box; use std::ops::Rem; use std::sync::Arc; use tokio::runtime::Runtime; +const FULL_PREDICATE_SWEEP: [usize; 5] = [10, 20, 30, 40, 60]; +const FULL_DEPTH_SWEEP: [usize; 3] = [1, 2, 3]; +const DEFAULT_SWEEP_POINTS: [(usize, usize); 3] = [(10, 1), (30, 2), (60, 3)]; + // This benchmark suite is designed to test the performance of // logical planning with a large plan containing unions, many columns // with a variety of operations in it. @@ -324,6 +329,27 @@ fn build_non_case_left_join_df_with_push_down_filter( rt.block_on(async { ctx.sql(&query).await.unwrap() }) } +fn include_full_push_down_filter_sweep() -> bool { + env::var("DATAFUSION_PUSH_DOWN_FILTER_FULL_SWEEP") + .map(|value| value == "1" || value.eq_ignore_ascii_case("true")) + .unwrap_or(false) +} + +fn push_down_filter_sweep_points() -> Vec<(usize, usize)> { + if include_full_push_down_filter_sweep() { + FULL_DEPTH_SWEEP + .into_iter() + .flat_map(|depth| { + FULL_PREDICATE_SWEEP + .into_iter() + .map(move |predicate_count| (predicate_count, depth)) + }) + .collect() + } else { + DEFAULT_SWEEP_POINTS.to_vec() + } +} + fn criterion_benchmark(c: &mut Criterion) { let baseline_ctx = SessionContext::new(); let case_heavy_ctx = SessionContext::new(); @@ -349,115 +375,98 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - let predicate_sweep = [10, 20, 30, 40, 60]; - let case_depth_sweep = [1, 2, 3]; - + let sweep_points = push_down_filter_sweep_points(); let mut hotspot_group = 
c.benchmark_group("push_down_filter_hotspot_case_heavy_left_join_ab"); - for case_depth in case_depth_sweep { - for predicate_count in predicate_sweep { - let with_push_down_filter = - build_case_heavy_left_join_df_with_push_down_filter( - &rt, - predicate_count, - case_depth, - true, - ); - let without_push_down_filter = - build_case_heavy_left_join_df_with_push_down_filter( - &rt, - predicate_count, - case_depth, - false, - ); - - let input_label = - format!("predicates={predicate_count},case_depth={case_depth}"); - // A/B interpretation: - // - with_push_down_filter: default optimizer path (rule enabled) - // - without_push_down_filter: control path with the rule removed - // Compare both IDs at the same sweep point to isolate rule impact. - hotspot_group.bench_with_input( - BenchmarkId::new("with_push_down_filter", &input_label), - &with_push_down_filter, - |b, df| { - b.iter(|| { - let df_clone = df.clone(); - black_box( - rt.block_on(async { - df_clone.into_optimized_plan().unwrap() - }), - ); - }) - }, - ); - hotspot_group.bench_with_input( - BenchmarkId::new("without_push_down_filter", &input_label), - &without_push_down_filter, - |b, df| { - b.iter(|| { - let df_clone = df.clone(); - black_box( - rt.block_on(async { - df_clone.into_optimized_plan().unwrap() - }), - ); - }) - }, + for &(predicate_count, case_depth) in &sweep_points { + let with_push_down_filter = build_case_heavy_left_join_df_with_push_down_filter( + &rt, + predicate_count, + case_depth, + true, + ); + let without_push_down_filter = + build_case_heavy_left_join_df_with_push_down_filter( + &rt, + predicate_count, + case_depth, + false, ); - } + + let input_label = format!("predicates={predicate_count},case_depth={case_depth}"); + // A/B interpretation: + // - with_push_down_filter: default optimizer path (rule enabled) + // - without_push_down_filter: control path with the rule removed + // Compare both IDs at the same sweep point to isolate rule impact. 
+ hotspot_group.bench_with_input( + BenchmarkId::new("with_push_down_filter", &input_label), + &with_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { df_clone.into_optimized_plan().unwrap() }), + ); + }) + }, + ); + hotspot_group.bench_with_input( + BenchmarkId::new("without_push_down_filter", &input_label), + &without_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { df_clone.into_optimized_plan().unwrap() }), + ); + }) + }, + ); } hotspot_group.finish(); let mut control_group = c.benchmark_group("push_down_filter_control_non_case_left_join_ab"); - for nesting_depth in case_depth_sweep { - for predicate_count in predicate_sweep { - let with_push_down_filter = build_non_case_left_join_df_with_push_down_filter( - &rt, - predicate_count, - nesting_depth, - true, - ); - let without_push_down_filter = - build_non_case_left_join_df_with_push_down_filter( - &rt, - predicate_count, - nesting_depth, - false, - ); - - let input_label = - format!("predicates={predicate_count},nesting_depth={nesting_depth}"); - control_group.bench_with_input( - BenchmarkId::new("with_push_down_filter", &input_label), - &with_push_down_filter, - |b, df| { - b.iter(|| { - let df_clone = df.clone(); - black_box( - rt.block_on(async { - df_clone.into_optimized_plan().unwrap() - }), - ); - }) - }, - ); - control_group.bench_with_input( - BenchmarkId::new("without_push_down_filter", &input_label), - &without_push_down_filter, - |b, df| { - b.iter(|| { - let df_clone = df.clone(); - black_box( - rt.block_on(async { - df_clone.into_optimized_plan().unwrap() - }), - ); - }) - }, - ); - } + for &(predicate_count, nesting_depth) in &sweep_points { + let with_push_down_filter = build_non_case_left_join_df_with_push_down_filter( + &rt, + predicate_count, + nesting_depth, + true, + ); + let without_push_down_filter = build_non_case_left_join_df_with_push_down_filter( + &rt, + 
predicate_count, + nesting_depth, + false, + ); + + let input_label = + format!("predicates={predicate_count},nesting_depth={nesting_depth}"); + control_group.bench_with_input( + BenchmarkId::new("with_push_down_filter", &input_label), + &with_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { df_clone.into_optimized_plan().unwrap() }), + ); + }) + }, + ); + control_group.bench_with_input( + BenchmarkId::new("without_push_down_filter", &input_label), + &without_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { df_clone.into_optimized_plan().unwrap() }), + ); + }) + }, + ); } control_group.finish(); } From 2c5096dd7998213d94fbb6d61cef696dcbb4313d Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 20 Mar 2026 10:52:48 +0800 Subject: [PATCH 17/63] Fix alias for scalar aggregate in push_down_filter Assign stable alias avg_a to the scalar aggregate in push_down_filter.rs at line 3589. Update the filter and assertion to use this alias instead of the raw AVG(test1.a) string. This resolves the schema-name mismatch that caused FieldNotFound during plan construction, ensuring the test focuses on the core regression while keeping the filter positioned above the Cross Join when one side is scalar. --- datafusion/optimizer/src/push_down_filter.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 14ae63e1ad498..f3ced5e5bbf3f 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -3592,21 +3592,21 @@ mod tests { .build()?; let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) .project(vec![col("a")])? - .aggregate(Vec::::new(), vec![avg(col("a"))])? + .aggregate(Vec::::new(), vec![avg(col("a")).alias("avg_a")])? 
.build()?; let plan = LogicalPlanBuilder::from(left) .cross_join(right)? - .filter(col("test.b").gt(col("AVG(test1.a)")))? + .filter(col("test.b").gt(col("avg_a")))? .build()?; assert_optimized_plan_equal!( plan, @r" - Filter: test.b > AVG(test1.a) + Filter: test.b > avg_a Cross Join: Projection: test.a, test.b TableScan: test - Aggregate: groupBy=[[]], aggr=[[avg(test1.a)]] + Aggregate: groupBy=[[]], aggr=[[avg(test1.a) AS avg_a]] Projection: test1.a TableScan: test1 " From 718cdf6fdb371d912bbc9a18abcdf409946f1819 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 20 Mar 2026 11:04:13 +0800 Subject: [PATCH 18/63] Refactor optimizer: remove test-only controls Eliminate public test-only null-restriction from the API. Remove unnecessary test knobs from optimizer utilities and simplify related methods. Decouple core tests from optimizer internals to streamline the codebase and improve clarity. --- .../tests/sql/push_down_filter_regressions.rs | 18 +----- datafusion/optimizer/src/utils.rs | 57 +------------------ 2 files changed, 5 insertions(+), 70 deletions(-) diff --git a/datafusion/core/tests/sql/push_down_filter_regressions.rs b/datafusion/core/tests/sql/push_down_filter_regressions.rs index d683fd068c797..a1ff8293c97a1 100644 --- a/datafusion/core/tests/sql/push_down_filter_regressions.rs +++ b/datafusion/core/tests/sql/push_down_filter_regressions.rs @@ -19,9 +19,6 @@ use std::sync::Arc; use super::*; use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; -use datafusion_optimizer::utils::{ - NullRestrictionEvalMode, set_null_restriction_eval_mode_for_test, -}; const WINDOW_SCALAR_SUBQUERY_SQL: &str = r#" WITH suppliers AS ( @@ -47,9 +44,7 @@ fn sqllogictest_style_ctx(push_down_filter_enabled: bool) -> SessionContext { async fn capture_window_scalar_subquery_plans( push_down_filter_enabled: bool, - null_restriction_eval_mode: NullRestrictionEvalMode, ) -> Result<(String, String)> { - let _mode_guard = 
set_null_restriction_eval_mode_for_test(null_restriction_eval_mode); let ctx = sqllogictest_style_ctx(push_down_filter_enabled); let df = ctx.sql(WINDOW_SCALAR_SUBQUERY_SQL).await?; let optimized_plan = df.clone().into_optimized_plan()?; @@ -213,16 +208,9 @@ async fn natural_join_union_regression() -> Result<()> { #[tokio::test(flavor = "current_thread")] async fn window_scalar_subquery_optimizer_delta() -> Result<()> { let (enabled_optimized, enabled_physical) = - capture_window_scalar_subquery_plans(true, NullRestrictionEvalMode::Auto).await?; + capture_window_scalar_subquery_plans(true).await?; let (disabled_optimized, disabled_physical) = - capture_window_scalar_subquery_plans(false, NullRestrictionEvalMode::Auto) - .await?; - let (authoritative_optimized, authoritative_physical) = - capture_window_scalar_subquery_plans( - true, - NullRestrictionEvalMode::AuthoritativeOnly, - ) - .await?; + capture_window_scalar_subquery_plans(false).await?; assert!( enabled_optimized @@ -246,8 +234,6 @@ async fn window_scalar_subquery_optimizer_delta() -> Result<()> { assert_eq!(enabled_optimized, disabled_optimized); assert_eq!(enabled_physical, disabled_physical); - assert_eq!(authoritative_optimized, enabled_optimized); - assert_eq!(authoritative_physical, enabled_physical); Ok(()) } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index cc07321b3960a..d0c29669c117e 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -18,8 +18,7 @@ //! 
Utility functions leveraged by the query optimizer rules use std::collections::{BTreeSet, HashMap, HashSet}; -use std::sync::atomic::{AtomicU8, Ordering}; -use std::sync::{Arc, Mutex, MutexGuard, OnceLock}; +use std::sync::Arc; use crate::analyzer::type_coercion::TypeCoercionRewriter; use arrow::array::{Array, RecordBatch, new_null_array}; @@ -69,55 +68,6 @@ pub fn log_plan(description: &str, plan: &LogicalPlan) { trace!("{description}::\n{}\n", plan.display_indent_schema()); } -#[doc(hidden)] -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum NullRestrictionEvalMode { - Auto = 0, - AuthoritativeOnly = 1, -} - -static NULL_RESTRICTION_EVAL_MODE: AtomicU8 = - AtomicU8::new(NullRestrictionEvalMode::Auto as u8); -static NULL_RESTRICTION_EVAL_MODE_LOCK: OnceLock> = OnceLock::new(); - -fn null_restriction_eval_mode() -> NullRestrictionEvalMode { - if NULL_RESTRICTION_EVAL_MODE.load(Ordering::Relaxed) - == NullRestrictionEvalMode::AuthoritativeOnly as u8 - { - NullRestrictionEvalMode::AuthoritativeOnly - } else { - NullRestrictionEvalMode::Auto - } -} - -#[doc(hidden)] -pub struct NullRestrictionEvalModeGuard { - previous_mode: NullRestrictionEvalMode, - _lock: MutexGuard<'static, ()>, -} - -impl Drop for NullRestrictionEvalModeGuard { - fn drop(&mut self) { - NULL_RESTRICTION_EVAL_MODE.store(self.previous_mode as u8, Ordering::Relaxed); - } -} - -#[doc(hidden)] -pub fn set_null_restriction_eval_mode_for_test( - mode: NullRestrictionEvalMode, -) -> NullRestrictionEvalModeGuard { - let lock = NULL_RESTRICTION_EVAL_MODE_LOCK - .get_or_init(|| Mutex::new(())) - .lock() - .expect("null restriction mode lock poisoned"); - let previous_mode = null_restriction_eval_mode(); - NULL_RESTRICTION_EVAL_MODE.store(mode as u8, Ordering::Relaxed); - NullRestrictionEvalModeGuard { - previous_mode, - _lock: lock, - } -} - /// Determine whether a predicate can restrict NULLs. e.g. /// `c0 > 8` return true; /// `c0 IS NULL` return false. 
@@ -142,9 +92,8 @@ pub fn is_restrict_null_predicate<'a>( return Ok(false); } - if null_restriction_eval_mode() == NullRestrictionEvalMode::Auto - && let Some(is_restricting) = - syntactic_restrict_null_predicate(&predicate, &join_cols) + if let Some(is_restricting) = + syntactic_restrict_null_predicate(&predicate, &join_cols) { #[cfg(debug_assertions)] { From deb679910a3d3bb12cb1592c7c2a54b9e09c3a6f Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 20 Mar 2026 11:23:50 +0800 Subject: [PATCH 19/63] Add domain-specific evaluator for null restrictions Implement new module `null_restriction.rs` containing syntactic restriction logic and internal three-valued interpretation helpers. Streamline orchestration in `utils.rs` to delegate tasks to the new module, ensuring efficiency and clarity. Update evaluator-equivalence test to maintain assertion coverage with the new structure. --- datafusion/optimizer/src/utils.rs | 291 +---------------- .../optimizer/src/utils/null_restriction.rs | 301 ++++++++++++++++++ 2 files changed, 308 insertions(+), 284 deletions(-) create mode 100644 datafusion/optimizer/src/utils/null_restriction.rs diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index d0c29669c117e..661fe71b76eda 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -17,6 +17,8 @@ //! Utility functions leveraged by the query optimizer rules +mod null_restriction; + use std::collections::{BTreeSet, HashMap, HashSet}; use std::sync::Arc; @@ -88,12 +90,12 @@ pub fn is_restrict_null_predicate<'a>( // contains a placeholder for the join key columns. Callers treat such errors as // non-restricting (false) via `matches!(_, Ok(true))`, so we return false early // and avoid the expensive physical-expression compilation pipeline entirely. 
- if !predicate_uses_only_columns(&predicate, &join_cols) { + if !null_restriction::predicate_uses_only_columns(&predicate, &join_cols) { return Ok(false); } if let Some(is_restricting) = - syntactic_restrict_null_predicate(&predicate, &join_cols) + null_restriction::syntactic_restrict_null_predicate(&predicate, &join_cols) { #[cfg(debug_assertions)] { @@ -114,285 +116,6 @@ pub fn is_restrict_null_predicate<'a>( authoritative_restrict_null_predicate(predicate, join_cols) } -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum NullSubstitutionValue { - Null, - NonNull, - Boolean(bool), -} - -fn syntactic_restrict_null_predicate( - predicate: &Expr, - join_cols: &HashSet<&Column>, -) -> Option { - match syntactic_null_substitution_value(predicate, join_cols) { - Some(NullSubstitutionValue::Boolean(true)) => Some(false), - Some(NullSubstitutionValue::Boolean(false) | NullSubstitutionValue::Null) => { - Some(true) - } - Some(NullSubstitutionValue::NonNull) | None => None, - } -} - -fn predicate_uses_only_columns( - predicate: &Expr, - allowed_columns: &HashSet<&Column>, -) -> bool { - predicate - .column_refs() - .iter() - .all(|column| allowed_columns.contains(*column)) -} - -fn contains_null( - values: impl IntoIterator>, -) -> bool { - values - .into_iter() - .any(|value| matches!(value, Some(NullSubstitutionValue::Null))) -} - -fn null_check_value( - value: Option, - when_non_null: bool, -) -> Option { - match value { - Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Boolean(false)), - Some(NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_)) => { - Some(NullSubstitutionValue::Boolean(when_non_null)) - } - None => None, - } -} - -fn syntactic_null_substitution_value( - expr: &Expr, - join_cols: &HashSet<&Column>, -) -> Option { - match expr { - Expr::Alias(alias) => { - syntactic_null_substitution_value(alias.expr.as_ref(), join_cols) - } - Expr::Column(column) => { - if join_cols.contains(column) { - Some(NullSubstitutionValue::Null) 
- } else { - None - } - } - Expr::Literal(value, _) => Some(scalar_to_null_substitution_value(value)), - Expr::BinaryExpr(binary_expr) => syntactic_binary_value(binary_expr, join_cols), - Expr::Not(expr) => { - sql_not(syntactic_null_substitution_value(expr.as_ref(), join_cols)) - } - Expr::IsNull(expr) => sql_not(null_check_value( - syntactic_null_substitution_value(expr.as_ref(), join_cols), - true, - )), - Expr::IsNotNull(expr) => null_check_value( - syntactic_null_substitution_value(expr.as_ref(), join_cols), - true, - ), - Expr::Between(between) => { - if contains_null([ - syntactic_null_substitution_value(between.expr.as_ref(), join_cols), - syntactic_null_substitution_value(between.low.as_ref(), join_cols), - syntactic_null_substitution_value(between.high.as_ref(), join_cols), - ]) { - Some(NullSubstitutionValue::Null) - } else { - None - } - } - Expr::Cast(cast) => strict_null_passthrough(cast.expr.as_ref(), join_cols), - Expr::TryCast(try_cast) => { - strict_null_passthrough(try_cast.expr.as_ref(), join_cols) - } - Expr::Negative(expr) => strict_null_passthrough(expr.as_ref(), join_cols), - Expr::Like(like) | Expr::SimilarTo(like) => { - if contains_null([ - syntactic_null_substitution_value(like.expr.as_ref(), join_cols), - syntactic_null_substitution_value(like.pattern.as_ref(), join_cols), - ]) { - Some(NullSubstitutionValue::Null) - } else { - None - } - } - Expr::Exists { .. 
} - | Expr::InList(_) - | Expr::InSubquery(_) - | Expr::SetComparison(_) - | Expr::ScalarSubquery(_) - | Expr::OuterReferenceColumn(_, _) - | Expr::Placeholder(_) - | Expr::ScalarVariable(_, _) - | Expr::Unnest(_) - | Expr::GroupingSet(_) - | Expr::WindowFunction(_) - | Expr::ScalarFunction(_) - | Expr::Case(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) => None, - Expr::AggregateFunction(_) => None, - // TODO: remove the next line after `Expr::Wildcard` is removed - #[expect(deprecated)] - Expr::Wildcard { .. } => None, - } -} - -fn scalar_to_null_substitution_value(value: &ScalarValue) -> NullSubstitutionValue { - if value.is_null() { - NullSubstitutionValue::Null - } else if let ScalarValue::Boolean(Some(value)) = value { - NullSubstitutionValue::Boolean(*value) - } else { - NullSubstitutionValue::NonNull - } -} - -fn strict_null_passthrough( - expr: &Expr, - join_cols: &HashSet<&Column>, -) -> Option { - if matches!( - syntactic_null_substitution_value(expr, join_cols), - Some(NullSubstitutionValue::Null) - ) { - Some(NullSubstitutionValue::Null) - } else { - None - } -} - -fn sql_not(value: Option) -> Option { - match value { - Some(NullSubstitutionValue::Boolean(value)) => { - Some(NullSubstitutionValue::Boolean(!value)) - } - Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Null), - Some(NullSubstitutionValue::NonNull) | None => None, - } -} - -fn sql_and( - left: Option, - right: Option, -) -> Option { - if matches!(left, Some(NullSubstitutionValue::Boolean(false))) - || matches!(right, Some(NullSubstitutionValue::Boolean(false))) - { - return Some(NullSubstitutionValue::Boolean(false)); - } - - match (left, right) { - (Some(NullSubstitutionValue::Boolean(true)), value) - | (value, Some(NullSubstitutionValue::Boolean(true))) => value, - (Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { - Some(NullSubstitutionValue::Null) - } 
- (Some(NullSubstitutionValue::NonNull), _) - | (_, Some(NullSubstitutionValue::NonNull)) - | (None, _) - | (_, None) => None, - (left, right) => { - debug_assert_eq!(left, right); - left - } - } -} - -fn sql_or( - left: Option, - right: Option, -) -> Option { - if matches!(left, Some(NullSubstitutionValue::Boolean(true))) - || matches!(right, Some(NullSubstitutionValue::Boolean(true))) - { - return Some(NullSubstitutionValue::Boolean(true)); - } - - match (left, right) { - (Some(NullSubstitutionValue::Boolean(false)), value) - | (value, Some(NullSubstitutionValue::Boolean(false))) => value, - (Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { - Some(NullSubstitutionValue::Null) - } - (Some(NullSubstitutionValue::NonNull), _) - | (_, Some(NullSubstitutionValue::NonNull)) - | (None, _) - | (_, None) => None, - (left, right) => { - debug_assert_eq!(left, right); - left - } - } -} - -fn syntactic_binary_value( - binary_expr: &datafusion_expr::BinaryExpr, - join_cols: &HashSet<&Column>, -) -> Option { - let left = syntactic_null_substitution_value(binary_expr.left.as_ref(), join_cols); - let right = syntactic_null_substitution_value(binary_expr.right.as_ref(), join_cols); - - match binary_expr.op { - datafusion_expr::Operator::And => sql_and(left, right), - datafusion_expr::Operator::Or => sql_or(left, right), - datafusion_expr::Operator::Eq - | datafusion_expr::Operator::NotEq - | datafusion_expr::Operator::Lt - | datafusion_expr::Operator::LtEq - | datafusion_expr::Operator::Gt - | datafusion_expr::Operator::GtEq - | datafusion_expr::Operator::Plus - | datafusion_expr::Operator::Minus - | datafusion_expr::Operator::Multiply - | datafusion_expr::Operator::Divide - | datafusion_expr::Operator::Modulo - | datafusion_expr::Operator::RegexMatch - | datafusion_expr::Operator::RegexIMatch - | datafusion_expr::Operator::RegexNotMatch - | datafusion_expr::Operator::RegexNotIMatch - | datafusion_expr::Operator::LikeMatch - | 
datafusion_expr::Operator::ILikeMatch - | datafusion_expr::Operator::NotLikeMatch - | datafusion_expr::Operator::NotILikeMatch - | datafusion_expr::Operator::BitwiseAnd - | datafusion_expr::Operator::BitwiseOr - | datafusion_expr::Operator::BitwiseXor - | datafusion_expr::Operator::BitwiseShiftRight - | datafusion_expr::Operator::BitwiseShiftLeft - | datafusion_expr::Operator::StringConcat - | datafusion_expr::Operator::AtArrow - | datafusion_expr::Operator::ArrowAt - | datafusion_expr::Operator::Arrow - | datafusion_expr::Operator::LongArrow - | datafusion_expr::Operator::HashArrow - | datafusion_expr::Operator::HashLongArrow - | datafusion_expr::Operator::AtAt - | datafusion_expr::Operator::IntegerDivide - | datafusion_expr::Operator::HashMinus - | datafusion_expr::Operator::AtQuestion - | datafusion_expr::Operator::Question - | datafusion_expr::Operator::QuestionAnd - | datafusion_expr::Operator::QuestionPipe - | datafusion_expr::Operator::Colon => { - if contains_null([left, right]) { - Some(NullSubstitutionValue::Null) - } else { - None - } - } - datafusion_expr::Operator::IsDistinctFrom - | datafusion_expr::Operator::IsNotDistinctFrom => None, - } -} - /// Determines if an expression will always evaluate to null. 
/// `c0 + 8` return true /// `c0 IS NULL` return false @@ -664,9 +387,9 @@ mod tests { for predicate in test_cases { let join_cols = predicate.column_refs(); - if let Some(syntactic) = - syntactic_restrict_null_predicate(&predicate, &join_cols) - { + if let Some(syntactic) = null_restriction::syntactic_restrict_null_predicate( + &predicate, &join_cols, + ) { let authoritative = authoritative_restrict_null_predicate( predicate.clone(), join_cols.iter().copied(), diff --git a/datafusion/optimizer/src/utils/null_restriction.rs b/datafusion/optimizer/src/utils/null_restriction.rs new file mode 100644 index 0000000000000..e23a7286c5ec9 --- /dev/null +++ b/datafusion/optimizer/src/utils/null_restriction.rs @@ -0,0 +1,301 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Syntactic null-restriction evaluator used by optimizer fast paths. 
+ +use std::collections::HashSet; + +use datafusion_common::{Column, ScalarValue}; +use datafusion_expr::{BinaryExpr, Expr, Operator}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum NullSubstitutionValue { + Null, + NonNull, + Boolean(bool), +} + +pub(super) fn syntactic_restrict_null_predicate( + predicate: &Expr, + join_cols: &HashSet<&Column>, +) -> Option { + match syntactic_null_substitution_value(predicate, join_cols) { + Some(NullSubstitutionValue::Boolean(true)) => Some(false), + Some(NullSubstitutionValue::Boolean(false) | NullSubstitutionValue::Null) => { + Some(true) + } + Some(NullSubstitutionValue::NonNull) | None => None, + } +} + +pub(super) fn predicate_uses_only_columns( + predicate: &Expr, + allowed_columns: &HashSet<&Column>, +) -> bool { + predicate + .column_refs() + .iter() + .all(|column| allowed_columns.contains(*column)) +} + +fn contains_null( + values: impl IntoIterator>, +) -> bool { + values + .into_iter() + .any(|value| matches!(value, Some(NullSubstitutionValue::Null))) +} + +fn null_check_value( + value: Option, + when_non_null: bool, +) -> Option { + match value { + Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Boolean(false)), + Some(NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_)) => { + Some(NullSubstitutionValue::Boolean(when_non_null)) + } + None => None, + } +} + +fn syntactic_null_substitution_value( + expr: &Expr, + join_cols: &HashSet<&Column>, +) -> Option { + match expr { + Expr::Alias(alias) => { + syntactic_null_substitution_value(alias.expr.as_ref(), join_cols) + } + Expr::Column(column) => { + if join_cols.contains(column) { + Some(NullSubstitutionValue::Null) + } else { + None + } + } + Expr::Literal(value, _) => Some(scalar_to_null_substitution_value(value)), + Expr::BinaryExpr(binary_expr) => syntactic_binary_value(binary_expr, join_cols), + Expr::Not(expr) => { + sql_not(syntactic_null_substitution_value(expr.as_ref(), join_cols)) + } + Expr::IsNull(expr) => 
sql_not(null_check_value( + syntactic_null_substitution_value(expr.as_ref(), join_cols), + true, + )), + Expr::IsNotNull(expr) => null_check_value( + syntactic_null_substitution_value(expr.as_ref(), join_cols), + true, + ), + Expr::Between(between) => { + if contains_null([ + syntactic_null_substitution_value(between.expr.as_ref(), join_cols), + syntactic_null_substitution_value(between.low.as_ref(), join_cols), + syntactic_null_substitution_value(between.high.as_ref(), join_cols), + ]) { + Some(NullSubstitutionValue::Null) + } else { + None + } + } + Expr::Cast(cast) => strict_null_passthrough(cast.expr.as_ref(), join_cols), + Expr::TryCast(try_cast) => { + strict_null_passthrough(try_cast.expr.as_ref(), join_cols) + } + Expr::Negative(expr) => strict_null_passthrough(expr.as_ref(), join_cols), + Expr::Like(like) | Expr::SimilarTo(like) => { + if contains_null([ + syntactic_null_substitution_value(like.expr.as_ref(), join_cols), + syntactic_null_substitution_value(like.pattern.as_ref(), join_cols), + ]) { + Some(NullSubstitutionValue::Null) + } else { + None + } + } + Expr::Exists { .. } + | Expr::InList(_) + | Expr::InSubquery(_) + | Expr::SetComparison(_) + | Expr::ScalarSubquery(_) + | Expr::OuterReferenceColumn(_, _) + | Expr::Placeholder(_) + | Expr::ScalarVariable(_, _) + | Expr::Unnest(_) + | Expr::GroupingSet(_) + | Expr::WindowFunction(_) + | Expr::ScalarFunction(_) + | Expr::Case(_) + | Expr::IsTrue(_) + | Expr::IsFalse(_) + | Expr::IsUnknown(_) + | Expr::IsNotTrue(_) + | Expr::IsNotFalse(_) + | Expr::IsNotUnknown(_) => None, + Expr::AggregateFunction(_) => None, + // TODO: remove the next line after `Expr::Wildcard` is removed + #[expect(deprecated)] + Expr::Wildcard { .. 
} => None, + } +} + +fn scalar_to_null_substitution_value(value: &ScalarValue) -> NullSubstitutionValue { + if value.is_null() { + NullSubstitutionValue::Null + } else if let ScalarValue::Boolean(Some(value)) = value { + NullSubstitutionValue::Boolean(*value) + } else { + NullSubstitutionValue::NonNull + } +} + +fn strict_null_passthrough( + expr: &Expr, + join_cols: &HashSet<&Column>, +) -> Option { + if matches!( + syntactic_null_substitution_value(expr, join_cols), + Some(NullSubstitutionValue::Null) + ) { + Some(NullSubstitutionValue::Null) + } else { + None + } +} + +fn sql_not(value: Option) -> Option { + match value { + Some(NullSubstitutionValue::Boolean(value)) => { + Some(NullSubstitutionValue::Boolean(!value)) + } + Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Null), + Some(NullSubstitutionValue::NonNull) | None => None, + } +} + +fn sql_and( + left: Option, + right: Option, +) -> Option { + if matches!(left, Some(NullSubstitutionValue::Boolean(false))) + || matches!(right, Some(NullSubstitutionValue::Boolean(false))) + { + return Some(NullSubstitutionValue::Boolean(false)); + } + + match (left, right) { + (Some(NullSubstitutionValue::Boolean(true)), value) + | (value, Some(NullSubstitutionValue::Boolean(true))) => value, + (Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { + Some(NullSubstitutionValue::Null) + } + (Some(NullSubstitutionValue::NonNull), _) + | (_, Some(NullSubstitutionValue::NonNull)) + | (None, _) + | (_, None) => None, + (left, right) => { + debug_assert_eq!(left, right); + left + } + } +} + +fn sql_or( + left: Option, + right: Option, +) -> Option { + if matches!(left, Some(NullSubstitutionValue::Boolean(true))) + || matches!(right, Some(NullSubstitutionValue::Boolean(true))) + { + return Some(NullSubstitutionValue::Boolean(true)); + } + + match (left, right) { + (Some(NullSubstitutionValue::Boolean(false)), value) + | (value, Some(NullSubstitutionValue::Boolean(false))) => value, + 
(Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { + Some(NullSubstitutionValue::Null) + } + (Some(NullSubstitutionValue::NonNull), _) + | (_, Some(NullSubstitutionValue::NonNull)) + | (None, _) + | (_, None) => None, + (left, right) => { + debug_assert_eq!(left, right); + left + } + } +} + +fn syntactic_binary_value( + binary_expr: &BinaryExpr, + join_cols: &HashSet<&Column>, +) -> Option { + let left = syntactic_null_substitution_value(binary_expr.left.as_ref(), join_cols); + let right = syntactic_null_substitution_value(binary_expr.right.as_ref(), join_cols); + + match binary_expr.op { + Operator::And => sql_and(left, right), + Operator::Or => sql_or(left, right), + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq + | Operator::Plus + | Operator::Minus + | Operator::Multiply + | Operator::Divide + | Operator::Modulo + | Operator::RegexMatch + | Operator::RegexIMatch + | Operator::RegexNotMatch + | Operator::RegexNotIMatch + | Operator::LikeMatch + | Operator::ILikeMatch + | Operator::NotLikeMatch + | Operator::NotILikeMatch + | Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseXor + | Operator::BitwiseShiftRight + | Operator::BitwiseShiftLeft + | Operator::StringConcat + | Operator::AtArrow + | Operator::ArrowAt + | Operator::Arrow + | Operator::LongArrow + | Operator::HashArrow + | Operator::HashLongArrow + | Operator::AtAt + | Operator::IntegerDivide + | Operator::HashMinus + | Operator::AtQuestion + | Operator::Question + | Operator::QuestionAnd + | Operator::QuestionPipe + | Operator::Colon => { + if contains_null([left, right]) { + Some(NullSubstitutionValue::Null) + } else { + None + } + } + Operator::IsDistinctFrom | Operator::IsNotDistinctFrom => None, + } +} From aa030ad1d51b48b58ec6ae282de84e8649754b04 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 20 Mar 2026 11:27:04 +0800 Subject: [PATCH 20/63] Improve null_restriction with documentation and 
evaluator Add explicit lattice documentation for the null-substitution domain, clarifying the meaning of NonNull and its conservative nature. Introduce an evaluator struct with methods for logical operations, and route expression evaluation to these methods to enhance readability and maintainability while preserving existing behavior around NonNull semantics. --- .../optimizer/src/utils/null_restriction.rs | 159 ++++++++++-------- 1 file changed, 91 insertions(+), 68 deletions(-) diff --git a/datafusion/optimizer/src/utils/null_restriction.rs b/datafusion/optimizer/src/utils/null_restriction.rs index e23a7286c5ec9..dd92bbfea03ab 100644 --- a/datafusion/optimizer/src/utils/null_restriction.rs +++ b/datafusion/optimizer/src/utils/null_restriction.rs @@ -24,11 +24,95 @@ use datafusion_expr::{BinaryExpr, Expr, Operator}; #[derive(Clone, Copy, Debug, PartialEq, Eq)] enum NullSubstitutionValue { + /// SQL NULL after substituting join columns with NULL. Null, + /// Known to be non-null, but value is otherwise unknown. NonNull, + /// A known boolean outcome from SQL three-valued logic. Boolean(bool), } +/// Evaluates a subset of SQL three-valued logic over null substitution values. +/// +/// Lattice used by the syntactic fast path: +/// - `Boolean(true|false)`: exact logical value +/// - `Null`: exact SQL unknown/null value +/// - `NonNull`: known to be not null, but exact value is unknown +/// +/// `NonNull` is intentionally conservative for logical operators. For example, +/// `NonNull AND true` could be either `true` or `false`, so the result remains +/// unknown to the syntactic evaluator (`None`) rather than pretending to know. 
+#[derive(Default)] +struct SqlThreeValuedEvaluator; + +impl SqlThreeValuedEvaluator { + fn not(&self, value: Option) -> Option { + match value { + Some(NullSubstitutionValue::Boolean(value)) => { + Some(NullSubstitutionValue::Boolean(!value)) + } + Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Null), + Some(NullSubstitutionValue::NonNull) | None => None, + } + } + + fn and( + &self, + left: Option, + right: Option, + ) -> Option { + if matches!(left, Some(NullSubstitutionValue::Boolean(false))) + || matches!(right, Some(NullSubstitutionValue::Boolean(false))) + { + return Some(NullSubstitutionValue::Boolean(false)); + } + + match (left, right) { + (Some(NullSubstitutionValue::Boolean(true)), value) + | (value, Some(NullSubstitutionValue::Boolean(true))) => value, + (Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { + Some(NullSubstitutionValue::Null) + } + (Some(NullSubstitutionValue::NonNull), _) + | (_, Some(NullSubstitutionValue::NonNull)) + | (None, _) + | (_, None) => None, + (left, right) => { + debug_assert_eq!(left, right); + left + } + } + } + + fn or( + &self, + left: Option, + right: Option, + ) -> Option { + if matches!(left, Some(NullSubstitutionValue::Boolean(true))) + || matches!(right, Some(NullSubstitutionValue::Boolean(true))) + { + return Some(NullSubstitutionValue::Boolean(true)); + } + + match (left, right) { + (Some(NullSubstitutionValue::Boolean(false)), value) + | (value, Some(NullSubstitutionValue::Boolean(false))) => value, + (Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { + Some(NullSubstitutionValue::Null) + } + (Some(NullSubstitutionValue::NonNull), _) + | (_, Some(NullSubstitutionValue::NonNull)) + | (None, _) + | (_, None) => None, + (left, right) => { + debug_assert_eq!(left, right); + left + } + } + } +} + pub(super) fn syntactic_restrict_null_predicate( predicate: &Expr, join_cols: &HashSet<&Column>, @@ -77,6 +161,8 @@ fn syntactic_null_substitution_value( 
expr: &Expr, join_cols: &HashSet<&Column>, ) -> Option { + let evaluator = SqlThreeValuedEvaluator; + match expr { Expr::Alias(alias) => { syntactic_null_substitution_value(alias.expr.as_ref(), join_cols) @@ -91,9 +177,9 @@ fn syntactic_null_substitution_value( Expr::Literal(value, _) => Some(scalar_to_null_substitution_value(value)), Expr::BinaryExpr(binary_expr) => syntactic_binary_value(binary_expr, join_cols), Expr::Not(expr) => { - sql_not(syntactic_null_substitution_value(expr.as_ref(), join_cols)) + evaluator.not(syntactic_null_substitution_value(expr.as_ref(), join_cols)) } - Expr::IsNull(expr) => sql_not(null_check_value( + Expr::IsNull(expr) => evaluator.not(null_check_value( syntactic_null_substitution_value(expr.as_ref(), join_cols), true, )), @@ -177,80 +263,17 @@ fn strict_null_passthrough( } } -fn sql_not(value: Option) -> Option { - match value { - Some(NullSubstitutionValue::Boolean(value)) => { - Some(NullSubstitutionValue::Boolean(!value)) - } - Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Null), - Some(NullSubstitutionValue::NonNull) | None => None, - } -} - -fn sql_and( - left: Option, - right: Option, -) -> Option { - if matches!(left, Some(NullSubstitutionValue::Boolean(false))) - || matches!(right, Some(NullSubstitutionValue::Boolean(false))) - { - return Some(NullSubstitutionValue::Boolean(false)); - } - - match (left, right) { - (Some(NullSubstitutionValue::Boolean(true)), value) - | (value, Some(NullSubstitutionValue::Boolean(true))) => value, - (Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { - Some(NullSubstitutionValue::Null) - } - (Some(NullSubstitutionValue::NonNull), _) - | (_, Some(NullSubstitutionValue::NonNull)) - | (None, _) - | (_, None) => None, - (left, right) => { - debug_assert_eq!(left, right); - left - } - } -} - -fn sql_or( - left: Option, - right: Option, -) -> Option { - if matches!(left, Some(NullSubstitutionValue::Boolean(true))) - || matches!(right, 
Some(NullSubstitutionValue::Boolean(true))) - { - return Some(NullSubstitutionValue::Boolean(true)); - } - - match (left, right) { - (Some(NullSubstitutionValue::Boolean(false)), value) - | (value, Some(NullSubstitutionValue::Boolean(false))) => value, - (Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { - Some(NullSubstitutionValue::Null) - } - (Some(NullSubstitutionValue::NonNull), _) - | (_, Some(NullSubstitutionValue::NonNull)) - | (None, _) - | (_, None) => None, - (left, right) => { - debug_assert_eq!(left, right); - left - } - } -} - fn syntactic_binary_value( binary_expr: &BinaryExpr, join_cols: &HashSet<&Column>, ) -> Option { let left = syntactic_null_substitution_value(binary_expr.left.as_ref(), join_cols); let right = syntactic_null_substitution_value(binary_expr.right.as_ref(), join_cols); + let evaluator = SqlThreeValuedEvaluator; match binary_expr.op { - Operator::And => sql_and(left, right), - Operator::Or => sql_or(left, right), + Operator::And => evaluator.and(left, right), + Operator::Or => evaluator.or(left, right), Operator::Eq | Operator::NotEq | Operator::Lt From 9cae4030129efb0ec6c84f2169cc3ba66986d66b Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 20 Mar 2026 11:29:19 +0800 Subject: [PATCH 21/63] Rename and clarify filter promotion function Update function name to `can_promote_post_join_filter_to_join_condition` for better intent clarity. Add rich documentation explaining the purpose and heuristics used to protect scalar-side joins and cross joins from unsafe promotion of filters to join conditions. Maintain previous semantic guard to disallow promotion when one side is scalar, while allowing it in other scenarios. Update the callsite in `push_down_all_join` to use the new function name. 
--- datafusion/optimizer/src/push_down_filter.rs | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index f3ced5e5bbf3f..8b04cf45b2a47 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -412,8 +412,8 @@ fn push_down_all_join( on_filter: Vec, ) -> Result> { let is_inner_join = join.join_type == JoinType::Inner; - let allow_convert_filter_to_join_condition = - allow_convert_filter_to_join_condition(&join); + let can_promote_post_join_filter_to_join_condition = + can_promote_post_join_filter_to_join_condition(&join); // Get pushable predicates from current optimizer state let (left_preserved, right_preserved) = lr_is_preserved(join.join_type); @@ -434,7 +434,7 @@ fn push_down_all_join( } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); } else if is_inner_join - && allow_convert_filter_to_join_condition + && can_promote_post_join_filter_to_join_condition && can_evaluate_as_join_condition(&predicate)? { // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate @@ -517,7 +517,15 @@ fn push_down_all_join( Ok(Transformed::yes(plan)) } -fn allow_convert_filter_to_join_condition(join: &Join) -> bool { +/// Returns true when post-join filters are allowed to be promoted to join conditions. +/// +/// Protection is necessary for scalar-side joins and cross joins to avoid incorrectly +/// rewriting a post-join filter into the join condition when one side is empty or +/// limited to at most one row (`max_rows() == Some(1)`). +/// +/// - `join.on` non-empty means existing join predicates already exist; promotion is safe. +/// - if neither side is scalar (`max_rows() == Some(1)`), promotion is safe. 
+fn can_promote_post_join_filter_to_join_condition(join: &Join) -> bool { !join.on.is_empty() || !(matches!(join.left.max_rows(), Some(1)) || matches!(join.right.max_rows(), Some(1))) From 11098c13fad4c35cffb729b8a3fc3ad8c433cfbb Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 20 Mar 2026 11:31:15 +0800 Subject: [PATCH 22/63] Refactor pushdown filter logic for clarity Make the null restriction policy explicit in InferredPredicates::try_build_predicate by using a new helper function. The change clarifies that only predicates which restrict nulls allow pushdown, while maintaining the previous behavior. This avoids potential issues with the previous implicit logic that used unwrap_or(false). --- datafusion/optimizer/src/push_down_filter.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 8b04cf45b2a47..59d8e80a4f76a 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -642,8 +642,10 @@ impl InferredPredicates { replace_map: &HashMap<&Column, &Column>, ) -> Result<()> { if self.is_inner_join - || is_restrict_null_predicate(predicate.clone(), replace_map.keys().cloned()) - .unwrap_or(false) + || is_restrict_null_predicate_allows_pushdown( + predicate.clone(), + replace_map.keys().cloned(), + ) { self.predicates.push(replace_col(predicate, replace_map)?); } @@ -652,6 +654,13 @@ impl InferredPredicates { } } +fn is_restrict_null_predicate_allows_pushdown<'a>( + predicate: Expr, + join_cols: impl IntoIterator, +) -> bool { + matches!(is_restrict_null_predicate(predicate, join_cols), Ok(true)) +} + /// Infer predicates from the pushed down predicates. 
/// /// Parameters From 84ca1b9faf7b34b144399876e5b0daa417fe9ef6 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 20 Mar 2026 11:39:04 +0800 Subject: [PATCH 23/63] Add null restriction eval mode and test controls Introduce NullRestrictionEvalMode with Auto and AuthoritativeOnly options. Add a test-only mode control backed by a static Mutex. Implement a runtime mode switch helper. Update is_restrict_null_predicate to branch on the mode. In production, always default to Auto mode. --- datafusion/optimizer/src/utils.rs | 99 +++++++++++++++++++++++++------ 1 file changed, 81 insertions(+), 18 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 661fe71b76eda..009a42a8d9b78 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -22,7 +22,37 @@ mod null_restriction; use std::collections::{BTreeSet, HashMap, HashSet}; use std::sync::Arc; +#[cfg(test)] +use std::sync::Mutex; + use crate::analyzer::type_coercion::TypeCoercionRewriter; + +/// Null restriction evaluation mode for optimizer tests.
+#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub(crate) enum NullRestrictionEvalMode { + Auto, + AuthoritativeOnly, +} + +#[cfg(test)] +static NULL_RESTRICTION_EVAL_MODE: Mutex = + Mutex::new(NullRestrictionEvalMode::Auto); + +#[cfg(test)] +pub(crate) fn set_null_restriction_eval_mode_for_test(mode: NullRestrictionEvalMode) { + *NULL_RESTRICTION_EVAL_MODE.lock().unwrap() = mode; +} + +fn null_restriction_eval_mode() -> NullRestrictionEvalMode { + #[cfg(test)] + { + *NULL_RESTRICTION_EVAL_MODE.lock().unwrap() + } + #[cfg(not(test))] + { + NullRestrictionEvalMode::Auto + } +} use arrow::array::{Array, RecordBatch, new_null_array}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::cast::as_boolean_array; @@ -94,26 +124,35 @@ pub fn is_restrict_null_predicate<'a>( return Ok(false); } - if let Some(is_restricting) = - null_restriction::syntactic_restrict_null_predicate(&predicate, &join_cols) - { - #[cfg(debug_assertions)] - { - let authoritative = authoritative_restrict_null_predicate( - predicate.clone(), - join_cols.iter().copied(), - )?; - debug_assert_eq!( - is_restricting, authoritative, - "syntactic fast path disagrees with authoritative null-restriction evaluation for predicate: {predicate}" - ); + let mode = null_restriction_eval_mode(); + + match mode { + NullRestrictionEvalMode::AuthoritativeOnly => { + authoritative_restrict_null_predicate(predicate, join_cols) + } + NullRestrictionEvalMode::Auto => { + if let Some(is_restricting) = + null_restriction::syntactic_restrict_null_predicate( + &predicate, &join_cols, + ) + { + #[cfg(debug_assertions)] + { + let authoritative = authoritative_restrict_null_predicate( + predicate.clone(), + join_cols.iter().copied(), + )?; + debug_assert_eq!( + is_restricting, authoritative, + "syntactic fast path disagrees with authoritative null-restriction evaluation for predicate: {predicate}" + ); + } + Ok(is_restricting) + } else { + authoritative_restrict_null_predicate(predicate, join_cols) + } } - 
return Ok(is_restricting); } - - // If result is single `true`, return false; - // If result is single `NULL` or `false`, return true; - authoritative_restrict_null_predicate(predicate, join_cols) } /// Determines if an expression will always evaluate to null. @@ -408,4 +447,28 @@ mod tests { Ok(()) } + + #[test] + fn null_restriction_eval_mode_auto_vs_authoritative_only() -> Result<()> { + let predicate = binary_expr(col("a"), Operator::Gt, lit(8i64)); + let join_cols_of_predicate = predicate.column_refs(); + + set_null_restriction_eval_mode_for_test(NullRestrictionEvalMode::Auto); + let auto_result = is_restrict_null_predicate( + predicate.clone(), + join_cols_of_predicate.iter().copied(), + )?; + + set_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::AuthoritativeOnly, + ); + let authoritative_result = is_restrict_null_predicate( + predicate.clone(), + join_cols_of_predicate.iter().copied(), + )?; + + assert_eq!(auto_result, authoritative_result); + + Ok(()) + } } From fae294dce720c9a87e3f83c3fbc0137edbcef85a Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 20 Mar 2026 11:44:43 +0800 Subject: [PATCH 24/63] Refactor SQL planner for left join with filters Introduce a generic DataFrame builder for left joins that handles session context creation, table registration, and optional push-down filter removal. Rewire existing functions to leverage the new generic builder. Add an abstract benchmarking helper to replace duplicated blocks in the criterion_benchmark for case-heavy and non-case groups. 
--- .../core/benches/sql_planner_extended.rs | 193 ++++++++++-------- 1 file changed, 105 insertions(+), 88 deletions(-) diff --git a/datafusion/core/benches/sql_planner_extended.rs b/datafusion/core/benches/sql_planner_extended.rs index ccd7d3f3f031b..012587b064700 100644 --- a/datafusion/core/benches/sql_planner_extended.rs +++ b/datafusion/core/benches/sql_planner_extended.rs @@ -18,7 +18,10 @@ use arrow::array::{ArrayRef, RecordBatch}; use arrow_schema::DataType; use arrow_schema::TimeUnit::Nanosecond; -use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use criterion::{ + BenchmarkGroup, BenchmarkId, Criterion, criterion_group, criterion_main, + measurement::WallTime, +}; use datafusion::prelude::{DataFrame, SessionContext}; use datafusion_catalog::MemTable; use datafusion_common::ScalarValue; @@ -309,10 +312,11 @@ fn build_non_case_left_join_query( query } -fn build_non_case_left_join_df_with_push_down_filter( +fn build_left_join_df_with_push_down_filter( rt: &Runtime, + query_builder: impl Fn(usize, usize) -> String, predicate_count: usize, - nesting_depth: usize, + depth: usize, push_down_filter_enabled: bool, ) -> DataFrame { let ctx = SessionContext::new(); @@ -325,10 +329,40 @@ fn build_non_case_left_join_df_with_push_down_filter( ); } - let query = build_non_case_left_join_query(predicate_count, nesting_depth); + let query = query_builder(predicate_count, depth); rt.block_on(async { ctx.sql(&query).await.unwrap() }) } +fn build_case_heavy_left_join_df_with_push_down_filter( + rt: &Runtime, + predicate_count: usize, + case_depth: usize, + push_down_filter_enabled: bool, +) -> DataFrame { + build_left_join_df_with_push_down_filter( + rt, + build_case_heavy_left_join_query, + predicate_count, + case_depth, + push_down_filter_enabled, + ) +} + +fn build_non_case_left_join_df_with_push_down_filter( + rt: &Runtime, + predicate_count: usize, + nesting_depth: usize, + push_down_filter_enabled: bool, +) -> DataFrame { + 
build_left_join_df_with_push_down_filter( + rt, + build_non_case_left_join_query, + predicate_count, + nesting_depth, + push_down_filter_enabled, + ) +} + fn include_full_push_down_filter_sweep() -> bool { env::var("DATAFUSION_PUSH_DOWN_FILTER_FULL_SWEEP") .map(|value| value == "1" || value.eq_ignore_ascii_case("true")) @@ -350,6 +384,48 @@ fn push_down_filter_sweep_points() -> Vec<(usize, usize)> { } } +fn bench_push_down_filter_ab( + group: &mut BenchmarkGroup<'_, WallTime>, + rt: &Runtime, + sweep_points: &[(usize, usize)], + build_df: BuildFn, +) where + BuildFn: Fn(&Runtime, usize, usize, bool) -> DataFrame, +{ + for &(predicate_count, depth) in sweep_points { + let with_push_down_filter = build_df(rt, predicate_count, depth, true); + let without_push_down_filter = build_df(rt, predicate_count, depth, false); + + let input_label = format!("predicates={predicate_count},nesting_depth={depth}"); + + group.bench_with_input( + BenchmarkId::new("with_push_down_filter", &input_label), + &with_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { df_clone.into_optimized_plan().unwrap() }), + ); + }) + }, + ); + + group.bench_with_input( + BenchmarkId::new("without_push_down_filter", &input_label), + &without_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { df_clone.into_optimized_plan().unwrap() }), + ); + }) + }, + ); + } +} + fn criterion_benchmark(c: &mut Criterion) { let baseline_ctx = SessionContext::new(); let case_heavy_ctx = SessionContext::new(); @@ -376,98 +452,39 @@ fn criterion_benchmark(c: &mut Criterion) { }); let sweep_points = push_down_filter_sweep_points(); + let mut hotspot_group = c.benchmark_group("push_down_filter_hotspot_case_heavy_left_join_ab"); - for &(predicate_count, case_depth) in &sweep_points { - let with_push_down_filter = build_case_heavy_left_join_df_with_push_down_filter( - &rt, - predicate_count, - case_depth, - true, - 
); - let without_push_down_filter = + bench_push_down_filter_ab( + &mut hotspot_group, + &rt, + &sweep_points, + |rt, predicate_count, depth, enable| { build_case_heavy_left_join_df_with_push_down_filter( - &rt, + rt, predicate_count, - case_depth, - false, - ); - - let input_label = format!("predicates={predicate_count},case_depth={case_depth}"); - // A/B interpretation: - // - with_push_down_filter: default optimizer path (rule enabled) - // - without_push_down_filter: control path with the rule removed - // Compare both IDs at the same sweep point to isolate rule impact. - hotspot_group.bench_with_input( - BenchmarkId::new("with_push_down_filter", &input_label), - &with_push_down_filter, - |b, df| { - b.iter(|| { - let df_clone = df.clone(); - black_box( - rt.block_on(async { df_clone.into_optimized_plan().unwrap() }), - ); - }) - }, - ); - hotspot_group.bench_with_input( - BenchmarkId::new("without_push_down_filter", &input_label), - &without_push_down_filter, - |b, df| { - b.iter(|| { - let df_clone = df.clone(); - black_box( - rt.block_on(async { df_clone.into_optimized_plan().unwrap() }), - ); - }) - }, - ); - } + depth, + enable, + ) + }, + ); hotspot_group.finish(); let mut control_group = c.benchmark_group("push_down_filter_control_non_case_left_join_ab"); - for &(predicate_count, nesting_depth) in &sweep_points { - let with_push_down_filter = build_non_case_left_join_df_with_push_down_filter( - &rt, - predicate_count, - nesting_depth, - true, - ); - let without_push_down_filter = build_non_case_left_join_df_with_push_down_filter( - &rt, - predicate_count, - nesting_depth, - false, - ); - - let input_label = - format!("predicates={predicate_count},nesting_depth={nesting_depth}"); - control_group.bench_with_input( - BenchmarkId::new("with_push_down_filter", &input_label), - &with_push_down_filter, - |b, df| { - b.iter(|| { - let df_clone = df.clone(); - black_box( - rt.block_on(async { df_clone.into_optimized_plan().unwrap() }), - ); - }) - }, - ); - 
control_group.bench_with_input( - BenchmarkId::new("without_push_down_filter", &input_label), - &without_push_down_filter, - |b, df| { - b.iter(|| { - let df_clone = df.clone(); - black_box( - rt.block_on(async { df_clone.into_optimized_plan().unwrap() }), - ); - }) - }, - ); - } + bench_push_down_filter_ab( + &mut control_group, + &rt, + &sweep_points, + |rt, predicate_count, depth, enable| { + build_non_case_left_join_df_with_push_down_filter( + rt, + predicate_count, + depth, + enable, + ) + }, + ); control_group.finish(); } From 06242040649b85d4dfbd487980d49f79542dccb9 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 20 Mar 2026 11:55:31 +0800 Subject: [PATCH 25/63] Refactor null restriction logic in null_restriction.rs Simplify null propagation with small helper functions. Improve IS NULL/IS NOT NULL handling and streamline control flow by using more idiomatic constructs. Flatten is_restrict_null_predicate in utils.rs and eliminate an unnecessary wrapper in push_down_filter.rs for clearer logic. 
--- datafusion/optimizer/src/push_down_filter.rs | 20 +- datafusion/optimizer/src/utils.rs | 48 ++-- .../optimizer/src/utils/null_restriction.rs | 216 ++++++------------ 3 files changed, 105 insertions(+), 179 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 59d8e80a4f76a..14286a4480835 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -412,8 +412,6 @@ fn push_down_all_join( on_filter: Vec, ) -> Result> { let is_inner_join = join.join_type == JoinType::Inner; - let can_promote_post_join_filter_to_join_condition = - can_promote_post_join_filter_to_join_condition(&join); // Get pushable predicates from current optimizer state let (left_preserved, right_preserved) = lr_is_preserved(join.join_type); @@ -434,7 +432,7 @@ fn push_down_all_join( } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); } else if is_inner_join - && can_promote_post_join_filter_to_join_condition + && can_promote_post_join_filter_to_join_condition(&join) && can_evaluate_as_join_condition(&predicate)? { // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate @@ -642,9 +640,12 @@ impl InferredPredicates { replace_map: &HashMap<&Column, &Column>, ) -> Result<()> { if self.is_inner_join - || is_restrict_null_predicate_allows_pushdown( - predicate.clone(), - replace_map.keys().cloned(), + || matches!( + is_restrict_null_predicate( + predicate.clone(), + replace_map.keys().cloned() + ), + Ok(true) ) { self.predicates.push(replace_col(predicate, replace_map)?); @@ -654,13 +655,6 @@ impl InferredPredicates { } } -fn is_restrict_null_predicate_allows_pushdown<'a>( - predicate: Expr, - join_cols: impl IntoIterator, -) -> bool { - matches!(is_restrict_null_predicate(predicate, join_cols), Ok(true)) -} - /// Infer predicates from the pushed down predicates. 
/// /// Parameters diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 009a42a8d9b78..34fbec179a7fe 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -124,35 +124,31 @@ pub fn is_restrict_null_predicate<'a>( return Ok(false); } - let mode = null_restriction_eval_mode(); + if matches!( + null_restriction_eval_mode(), + NullRestrictionEvalMode::AuthoritativeOnly + ) { + return authoritative_restrict_null_predicate(predicate, join_cols); + } - match mode { - NullRestrictionEvalMode::AuthoritativeOnly => { - authoritative_restrict_null_predicate(predicate, join_cols) - } - NullRestrictionEvalMode::Auto => { - if let Some(is_restricting) = - null_restriction::syntactic_restrict_null_predicate( - &predicate, &join_cols, - ) - { - #[cfg(debug_assertions)] - { - let authoritative = authoritative_restrict_null_predicate( - predicate.clone(), - join_cols.iter().copied(), - )?; - debug_assert_eq!( - is_restricting, authoritative, - "syntactic fast path disagrees with authoritative null-restriction evaluation for predicate: {predicate}" - ); - } - Ok(is_restricting) - } else { - authoritative_restrict_null_predicate(predicate, join_cols) - } + if let Some(is_restricting) = + null_restriction::syntactic_restrict_null_predicate(&predicate, &join_cols) + { + #[cfg(debug_assertions)] + { + let authoritative = authoritative_restrict_null_predicate( + predicate.clone(), + join_cols.iter().copied(), + )?; + debug_assert_eq!( + is_restricting, authoritative, + "syntactic fast path disagrees with authoritative null-restriction evaluation for predicate: {predicate}" + ); } + return Ok(is_restricting); } + + authoritative_restrict_null_predicate(predicate, join_cols) } /// Determines if an expression will always evaluate to null. 
diff --git a/datafusion/optimizer/src/utils/null_restriction.rs b/datafusion/optimizer/src/utils/null_restriction.rs index dd92bbfea03ab..28e32a26ed1da 100644 --- a/datafusion/optimizer/src/utils/null_restriction.rs +++ b/datafusion/optimizer/src/utils/null_restriction.rs @@ -32,87 +32,6 @@ enum NullSubstitutionValue { Boolean(bool), } -/// Evaluates a subset of SQL three-valued logic over null substitution values. -/// -/// Lattice used by the syntactic fast path: -/// - `Boolean(true|false)`: exact logical value -/// - `Null`: exact SQL unknown/null value -/// - `NonNull`: known to be not null, but exact value is unknown -/// -/// `NonNull` is intentionally conservative for logical operators. For example, -/// `NonNull AND true` could be either `true` or `false`, so the result remains -/// unknown to the syntactic evaluator (`None`) rather than pretending to know. -#[derive(Default)] -struct SqlThreeValuedEvaluator; - -impl SqlThreeValuedEvaluator { - fn not(&self, value: Option) -> Option { - match value { - Some(NullSubstitutionValue::Boolean(value)) => { - Some(NullSubstitutionValue::Boolean(!value)) - } - Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Null), - Some(NullSubstitutionValue::NonNull) | None => None, - } - } - - fn and( - &self, - left: Option, - right: Option, - ) -> Option { - if matches!(left, Some(NullSubstitutionValue::Boolean(false))) - || matches!(right, Some(NullSubstitutionValue::Boolean(false))) - { - return Some(NullSubstitutionValue::Boolean(false)); - } - - match (left, right) { - (Some(NullSubstitutionValue::Boolean(true)), value) - | (value, Some(NullSubstitutionValue::Boolean(true))) => value, - (Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { - Some(NullSubstitutionValue::Null) - } - (Some(NullSubstitutionValue::NonNull), _) - | (_, Some(NullSubstitutionValue::NonNull)) - | (None, _) - | (_, None) => None, - (left, right) => { - debug_assert_eq!(left, right); - left - } - } - } - - fn 
or( - &self, - left: Option, - right: Option, - ) -> Option { - if matches!(left, Some(NullSubstitutionValue::Boolean(true))) - || matches!(right, Some(NullSubstitutionValue::Boolean(true))) - { - return Some(NullSubstitutionValue::Boolean(true)); - } - - match (left, right) { - (Some(NullSubstitutionValue::Boolean(false)), value) - | (value, Some(NullSubstitutionValue::Boolean(false))) => value, - (Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { - Some(NullSubstitutionValue::Null) - } - (Some(NullSubstitutionValue::NonNull), _) - | (_, Some(NullSubstitutionValue::NonNull)) - | (None, _) - | (_, None) => None, - (left, right) => { - debug_assert_eq!(left, right); - left - } - } - } -} - pub(super) fn syntactic_restrict_null_predicate( predicate: &Expr, join_cols: &HashSet<&Column>, @@ -144,75 +63,104 @@ fn contains_null( .any(|value| matches!(value, Some(NullSubstitutionValue::Null))) } +fn not(value: Option) -> Option { + match value { + Some(NullSubstitutionValue::Boolean(value)) => { + Some(NullSubstitutionValue::Boolean(!value)) + } + Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Null), + Some(NullSubstitutionValue::NonNull) | None => None, + } +} + +fn binary_boolean_value( + left: Option, + right: Option, + when_short_circuit: bool, +) -> Option { + let short_circuit = Some(NullSubstitutionValue::Boolean(when_short_circuit)); + let identity = Some(NullSubstitutionValue::Boolean(!when_short_circuit)); + + if left == short_circuit || right == short_circuit { + return short_circuit; + } + + match (left, right) { + (value, other) if value == identity => other, + (other, value) if value == identity => other, + (Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { + Some(NullSubstitutionValue::Null) + } + (Some(NullSubstitutionValue::NonNull), _) + | (_, Some(NullSubstitutionValue::NonNull)) + | (None, _) + | (_, None) => None, + (left, right) => { + debug_assert_eq!(left, right); + left + } + } +} 
+ fn null_check_value( value: Option, - when_non_null: bool, + is_not_null: bool, ) -> Option { match value { - Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Boolean(false)), + Some(NullSubstitutionValue::Null) => { + Some(NullSubstitutionValue::Boolean(!is_not_null)) + } Some(NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_)) => { - Some(NullSubstitutionValue::Boolean(when_non_null)) + Some(NullSubstitutionValue::Boolean(is_not_null)) } None => None, } } +fn null_if_contains_null( + values: impl IntoIterator>, +) -> Option { + contains_null(values).then_some(NullSubstitutionValue::Null) +} + fn syntactic_null_substitution_value( expr: &Expr, join_cols: &HashSet<&Column>, ) -> Option { - let evaluator = SqlThreeValuedEvaluator; - match expr { Expr::Alias(alias) => { syntactic_null_substitution_value(alias.expr.as_ref(), join_cols) } - Expr::Column(column) => { - if join_cols.contains(column) { - Some(NullSubstitutionValue::Null) - } else { - None - } - } + Expr::Column(column) => join_cols + .contains(column) + .then_some(NullSubstitutionValue::Null), Expr::Literal(value, _) => Some(scalar_to_null_substitution_value(value)), Expr::BinaryExpr(binary_expr) => syntactic_binary_value(binary_expr, join_cols), Expr::Not(expr) => { - evaluator.not(syntactic_null_substitution_value(expr.as_ref(), join_cols)) + not(syntactic_null_substitution_value(expr.as_ref(), join_cols)) } - Expr::IsNull(expr) => evaluator.not(null_check_value( + Expr::IsNull(expr) => null_check_value( syntactic_null_substitution_value(expr.as_ref(), join_cols), - true, - )), + false, + ), Expr::IsNotNull(expr) => null_check_value( syntactic_null_substitution_value(expr.as_ref(), join_cols), true, ), - Expr::Between(between) => { - if contains_null([ - syntactic_null_substitution_value(between.expr.as_ref(), join_cols), - syntactic_null_substitution_value(between.low.as_ref(), join_cols), - syntactic_null_substitution_value(between.high.as_ref(), join_cols), - ]) { - 
Some(NullSubstitutionValue::Null) - } else { - None - } - } + Expr::Between(between) => null_if_contains_null([ + syntactic_null_substitution_value(between.expr.as_ref(), join_cols), + syntactic_null_substitution_value(between.low.as_ref(), join_cols), + syntactic_null_substitution_value(between.high.as_ref(), join_cols), + ]), Expr::Cast(cast) => strict_null_passthrough(cast.expr.as_ref(), join_cols), Expr::TryCast(try_cast) => { strict_null_passthrough(try_cast.expr.as_ref(), join_cols) } Expr::Negative(expr) => strict_null_passthrough(expr.as_ref(), join_cols), - Expr::Like(like) | Expr::SimilarTo(like) => { - if contains_null([ - syntactic_null_substitution_value(like.expr.as_ref(), join_cols), - syntactic_null_substitution_value(like.pattern.as_ref(), join_cols), - ]) { - Some(NullSubstitutionValue::Null) - } else { - None - } - } + Expr::Like(like) | Expr::SimilarTo(like) => null_if_contains_null([ + syntactic_null_substitution_value(like.expr.as_ref(), join_cols), + syntactic_null_substitution_value(like.pattern.as_ref(), join_cols), + ]), Expr::Exists { .. 
} | Expr::InList(_) | Expr::InSubquery(_) @@ -240,12 +188,10 @@ fn syntactic_null_substitution_value( } fn scalar_to_null_substitution_value(value: &ScalarValue) -> NullSubstitutionValue { - if value.is_null() { - NullSubstitutionValue::Null - } else if let ScalarValue::Boolean(Some(value)) = value { - NullSubstitutionValue::Boolean(*value) - } else { - NullSubstitutionValue::NonNull + match value { + _ if value.is_null() => NullSubstitutionValue::Null, + ScalarValue::Boolean(Some(value)) => NullSubstitutionValue::Boolean(*value), + _ => NullSubstitutionValue::NonNull, } } @@ -253,14 +199,11 @@ fn strict_null_passthrough( expr: &Expr, join_cols: &HashSet<&Column>, ) -> Option { - if matches!( + matches!( syntactic_null_substitution_value(expr, join_cols), Some(NullSubstitutionValue::Null) - ) { - Some(NullSubstitutionValue::Null) - } else { - None - } + ) + .then_some(NullSubstitutionValue::Null) } fn syntactic_binary_value( @@ -269,11 +212,10 @@ fn syntactic_binary_value( ) -> Option { let left = syntactic_null_substitution_value(binary_expr.left.as_ref(), join_cols); let right = syntactic_null_substitution_value(binary_expr.right.as_ref(), join_cols); - let evaluator = SqlThreeValuedEvaluator; match binary_expr.op { - Operator::And => evaluator.and(left, right), - Operator::Or => evaluator.or(left, right), + Operator::And => binary_boolean_value(left, right, false), + Operator::Or => binary_boolean_value(left, right, true), Operator::Eq | Operator::NotEq | Operator::Lt @@ -312,13 +254,7 @@ fn syntactic_binary_value( | Operator::Question | Operator::QuestionAnd | Operator::QuestionPipe - | Operator::Colon => { - if contains_null([left, right]) { - Some(NullSubstitutionValue::Null) - } else { - None - } - } + | Operator::Colon => null_if_contains_null([left, right]), Operator::IsDistinctFrom | Operator::IsNotDistinctFrom => None, } } From 91acb3f54772ccd7fe575f1a6f0796782cc01394 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 20 Mar 2026 11:59:02 +0800 
Subject: [PATCH 26/63] Make NullRestrictionEvalMode test-only Add #[cfg(test)] to NullRestrictionEvalMode, to its accessor null_restriction_eval_mode, and to the AuthoritativeOnly check in is_restrict_null_predicate. Ensure that non-test builds keep the Auto behavior as before. The test path retains the authoritative override and supports the existing tests for null restriction evaluation. --- datafusion/optimizer/src/utils.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 34fbec179a7fe..79dcefb312c3c 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -28,6 +28,7 @@ use std::sync::Mutex; use crate::analyzer::type_coercion::TypeCoercionRewriter; /// Null restriction evaluation mode for optimizer tests. +#[cfg(test)] #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub(crate) enum NullRestrictionEvalMode { Auto, @@ -43,15 +44,9 @@ pub(crate) fn set_null_restriction_eval_mode_for_test(mode: NullRestrictionEvalM *NULL_RESTRICTION_EVAL_MODE.lock().unwrap() = mode; } +#[cfg(test)] fn null_restriction_eval_mode() -> NullRestrictionEvalMode { - #[cfg(test)] - { - *NULL_RESTRICTION_EVAL_MODE.lock().unwrap() - } - #[cfg(not(test))] - { - NullRestrictionEvalMode::Auto - } + *NULL_RESTRICTION_EVAL_MODE.lock().unwrap() } use arrow::array::{Array, RecordBatch, new_null_array}; use arrow::datatypes::{DataType, Field, Schema}; @@ -124,6 +119,7 @@ pub fn is_restrict_null_predicate<'a>( return Ok(false); } + #[cfg(test)] if matches!( null_restriction_eval_mode(), NullRestrictionEvalMode::AuthoritativeOnly From 6cce56aae8b37cdc6a25be019ea969e71df52f4d Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 20 Mar 2026 13:01:44 +0800 Subject: [PATCH 27/63] amend benchmark: remove duplicate build_case_heavy_left_join_df_with_push_down_filter --- .../core/benches/sql_planner_extended.rs | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/datafusion/core/benches/sql_planner_extended.rs b/datafusion/core/benches/sql_planner_extended.rs index 012587b064700..73141bd8668e4 100644 ---
a/datafusion/core/benches/sql_planner_extended.rs +++ b/datafusion/core/benches/sql_planner_extended.rs @@ -260,26 +260,6 @@ fn build_case_heavy_left_join_query(predicate_count: usize, case_depth: usize) - query } -fn build_case_heavy_left_join_df_with_push_down_filter( - rt: &Runtime, - predicate_count: usize, - case_depth: usize, - push_down_filter_enabled: bool, -) -> DataFrame { - let ctx = SessionContext::new(); - register_string_table(&ctx, 100, 1000); - if !push_down_filter_enabled { - let removed = ctx.remove_optimizer_rule("push_down_filter"); - assert!( - removed, - "push_down_filter rule should be present in the default optimizer" - ); - } - - let query = build_case_heavy_left_join_query(predicate_count, case_depth); - rt.block_on(async { ctx.sql(&query).await.unwrap() }) -} - fn build_non_case_left_join_query( predicate_count: usize, nesting_depth: usize, From 2832040b8abff81db642e476d35005c5618d57c5 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 15:22:17 +0800 Subject: [PATCH 28/63] Optimize join predicate handling and add regression test Refactor `infer_join_predicates_impl` to compute `column_refs()` once per predicate and reuse the set for building the replace map. For non-inner joins, skip null-restriction evaluation when predicates mix replaceable join-key references with other columns. Add a regression test in `push_down_filter.rs` to verify that left-join filters referencing both `test2.a` and `test.b` do not get incorrectly inferred onto the preserved side. 
--- datafusion/optimizer/src/push_down_filter.rs | 47 +++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 14286a4480835..c1a0201cc4ffa 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -740,9 +740,10 @@ fn infer_join_predicates_impl< inferred_predicates: &mut InferredPredicates, ) -> Result<()> { for predicate in input_predicates { + let column_refs = predicate.column_refs(); let mut join_cols_to_replace = HashMap::new(); - for &col in &predicate.column_refs() { + for &col in &column_refs { for (l, r) in join_col_keys.iter() { if ENABLE_LEFT_TO_RIGHT && col == *l { join_cols_to_replace.insert(col, *r); @@ -758,6 +759,15 @@ fn infer_join_predicates_impl< continue; } + // For non-inner joins, predicates that reference any non-replaceable + // columns cannot be inferred on the other side. Skip the null-restriction + // helper entirely in that common mixed-reference case. + if !inferred_predicates.is_inner_join + && join_cols_to_replace.len() != column_refs.len() + { + continue; + } + inferred_predicates .try_build_predicate(predicate.clone(), &join_cols_to_replace)?; } @@ -2808,6 +2818,41 @@ mod tests { ) } + /// mixed post-left-join predicates that reference a join key plus a + /// non-join column should not be inferred to the preserved side + #[test] + fn filter_using_left_join_with_mixed_join_key_and_non_join_refs() -> Result<()> { + let table_scan = test_table_scan()?; + let left = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("c")])? 
+ .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(add(col("test2.a"), col("test.b")).gt(lit(1i64)))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Filter: test2.a + test.b > Int64(1) + Left Join: test.a = test2.a + Projection: test.a, test.b + TableScan: test + Projection: test2.a, test2.c + TableScan: test2 + " + ) + } + /// post-right-join predicate on a column common to both sides is pushed to both sides #[test] fn filter_using_right_join_on_common() -> Result<()> { From e09533fb7e1eea6a0c83bce2db68e59a2ceb40c0 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 15:25:23 +0800 Subject: [PATCH 29/63] Optimize candidate-building flow in join predicates Conduct cheap operations first in infer_join_predicates_impl. Scan column_refs() once, record potential join-key replacements, and check for non-replaceable column references early. Skip replacement HashMap setup and null-restriction evaluation for non-inner joins to reduce cost for mixed-reference predicates. 
--- datafusion/optimizer/src/push_down_filter.rs | 22 +++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index c1a0201cc4ffa..bd714927c28a4 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -741,33 +741,41 @@ fn infer_join_predicates_impl< ) -> Result<()> { for predicate in input_predicates { let column_refs = predicate.column_refs(); - let mut join_cols_to_replace = HashMap::new(); + let mut join_col_replacements = Vec::new(); + let mut has_non_replaceable_refs = false; for &col in &column_refs { + let mut replacement = None; + for (l, r) in join_col_keys.iter() { if ENABLE_LEFT_TO_RIGHT && col == *l { - join_cols_to_replace.insert(col, *r); + replacement = Some((col, *r)); break; } if ENABLE_RIGHT_TO_LEFT && col == *r { - join_cols_to_replace.insert(col, *l); + replacement = Some((col, *l)); break; } } + + if let Some(replacement) = replacement { + join_col_replacements.push(replacement); + } else { + has_non_replaceable_refs = true; + } } - if join_cols_to_replace.is_empty() { + if join_col_replacements.is_empty() { continue; } // For non-inner joins, predicates that reference any non-replaceable // columns cannot be inferred on the other side. Skip the null-restriction // helper entirely in that common mixed-reference case. 
- if !inferred_predicates.is_inner_join - && join_cols_to_replace.len() != column_refs.len() - { + if !inferred_predicates.is_inner_join && has_non_replaceable_refs { continue; } + let join_cols_to_replace = join_col_replacements.into_iter().collect(); inferred_predicates .try_build_predicate(predicate.clone(), &join_cols_to_replace)?; } From 69864d6f94e2450fd328951098912832064abc20 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 15:31:00 +0800 Subject: [PATCH 30/63] Enhance join query to generate complex CASE predicates Update the build_case_heavy_left_join_query to produce predicates that consistently reference join keys with l.c0 or r.c0. Include additional non-join payload columns from both sides, while maintaining a CASE-heavy cross-side structure. The new pattern starts with a conditional check on l.c0 and constructs deeper CASE layers to integrate other relevant columns from both sides. --- datafusion/core/benches/sql_planner_extended.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/datafusion/core/benches/sql_planner_extended.rs b/datafusion/core/benches/sql_planner_extended.rs index 73141bd8668e4..2e83def0e99ba 100644 --- a/datafusion/core/benches/sql_planner_extended.rs +++ b/datafusion/core/benches/sql_planner_extended.rs @@ -245,12 +245,17 @@ fn build_case_heavy_left_join_query(predicate_count: usize, case_depth: usize) - query.push_str(" AND "); } - let mut expr = format!("length(l.c{})", i % 20); + let left_payload_col = (i % 19) + 1; + let right_payload_col = ((i + 7) % 19) + 1; + let mut expr = format!( + "CASE WHEN l.c0 IS NOT NULL THEN length(l.c{left_payload_col}) ELSE length(r.c{right_payload_col}) END" + ); for depth in 0..case_depth { - let left_col = (i + depth + 1) % 20; - let right_col = (i + depth + 2) % 20; + let left_col = ((i + depth + 3) % 19) + 1; + let right_col = ((i + depth + 11) % 19) + 1; + let join_key_ref = if (i + depth) % 2 == 0 { "l.c0" } else { "r.c0" }; expr = format!( - 
"CASE WHEN l.c{left_col} IS NOT NULL THEN {expr} ELSE length(r.c{right_col}) END" + "CASE WHEN {join_key_ref} IS NOT NULL THEN {expr} ELSE CASE WHEN l.c{left_col} IS NOT NULL THEN length(l.c{left_col}) ELSE length(r.c{right_col}) END END" ); } From 23e38a19e5476f829eed1926cf8ae660fc435656 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 15:38:12 +0800 Subject: [PATCH 31/63] Add guard helper for CASE-heavy left join inference Implement find_filter_predicates to extract top-level WHERE conjuncts from the benchmark plan. Introduce assert_case_heavy_left_join_inference_candidates to ensure generated predicates meet criteria: include l.c0 or r.c0, contain at least one non-join column, and match expected predicate counts. This enhancement checks the inference-candidate shape during CASE-heavy benchmark dataframe setup. --- .../core/benches/sql_planner_extended.rs | 47 +++++++++++++++++-- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/datafusion/core/benches/sql_planner_extended.rs b/datafusion/core/benches/sql_planner_extended.rs index 2e83def0e99ba..767134bb5bafd 100644 --- a/datafusion/core/benches/sql_planner_extended.rs +++ b/datafusion/core/benches/sql_planner_extended.rs @@ -24,8 +24,10 @@ use criterion::{ }; use datafusion::prelude::{DataFrame, SessionContext}; use datafusion_catalog::MemTable; -use datafusion_common::ScalarValue; +use datafusion_common::{Column, ScalarValue}; use datafusion_expr::Expr::Literal; +use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::utils::split_conjunction_owned; use datafusion_expr::{cast, col, lit, not, try_cast, when}; use datafusion_functions::expr_fn::{ btrim, length, regexp_like, regexp_replace, to_timestamp, upper, @@ -226,7 +228,9 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame { fn build_case_heavy_left_join_df(ctx: &SessionContext, rt: &Runtime) -> DataFrame { register_string_table(ctx, 100, 1000); let query = 
build_case_heavy_left_join_query(30, 1); - rt.block_on(async { ctx.sql(&query).await.unwrap() }) + let df = rt.block_on(async { ctx.sql(&query).await.unwrap() }); + assert_case_heavy_left_join_inference_candidates(&df, 30); + df } fn build_case_heavy_left_join_query(predicate_count: usize, case_depth: usize) -> String { @@ -324,13 +328,15 @@ fn build_case_heavy_left_join_df_with_push_down_filter( case_depth: usize, push_down_filter_enabled: bool, ) -> DataFrame { - build_left_join_df_with_push_down_filter( + let df = build_left_join_df_with_push_down_filter( rt, build_case_heavy_left_join_query, predicate_count, case_depth, push_down_filter_enabled, - ) + ); + assert_case_heavy_left_join_inference_candidates(&df, predicate_count); + df } fn build_non_case_left_join_df_with_push_down_filter( @@ -348,6 +354,39 @@ fn build_non_case_left_join_df_with_push_down_filter( ) } +fn find_filter_predicates(plan: &LogicalPlan) -> Vec { + match plan { + LogicalPlan::Filter(filter) => split_conjunction_owned(filter.predicate.clone()), + LogicalPlan::Projection(projection) => find_filter_predicates(projection.input.as_ref()), + other => panic!("expected benchmark query plan to contain a Filter, found {other:?}"), + } +} + +fn assert_case_heavy_left_join_inference_candidates( + df: &DataFrame, + expected_predicate_count: usize, +) { + let predicates = find_filter_predicates(df.logical_plan()); + assert_eq!(predicates.len(), expected_predicate_count); + + let left_join_key = Column::from_qualified_name("l.c0"); + let right_join_key = Column::from_qualified_name("r.c0"); + + for predicate in predicates { + let column_refs = predicate.column_refs(); + assert!( + column_refs.contains(&&left_join_key) || column_refs.contains(&&right_join_key), + "benchmark predicate should reference a join key: {predicate}" + ); + assert!( + column_refs + .iter() + .any(|col| **col != left_join_key && **col != right_join_key), + "benchmark predicate should reference a non-join column: {predicate}" + ); 
+ } +} + fn include_full_push_down_filter_sweep() -> bool { env::var("DATAFUSION_PUSH_DOWN_FILTER_FULL_SWEEP") .map(|value| value == "1" || value.eq_ignore_ascii_case("true")) From 3e65e225a3b58721566dad8d31ce0715283cb203 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 15:59:52 +0800 Subject: [PATCH 32/63] Refactor promotion filter for join conditions Narrow down the can_promote_post_join_filter_to_join_condition to allow promotions when join.on is non-empty. Block promotion for potentially disappearing scalar sides while permitting a narrow exception for scalar-subquery-shaped sides identified via SubqueryAlias with a guaranteed single row. This restores subquery_filter_with_cast shape without reintroducing the explicit Cross Join regression coverage. --- datafusion/optimizer/src/push_down_filter.rs | 58 ++++++++++++++++++-- 1 file changed, 52 insertions(+), 6 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index bd714927c28a4..ee88840a7c2cd 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -35,7 +35,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::WindowFunction; use datafusion_expr::expr_rewriter::replace_col; -use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union}; +use datafusion_expr::logical_plan::{Aggregate, Join, JoinType, LogicalPlan, TableScan, Union}; use datafusion_expr::utils::{ conjunction, expr_to_columns, split_conjunction, split_conjunction_owned, }; @@ -518,15 +518,61 @@ fn push_down_all_join( /// Returns true when post-join filters are allowed to be promoted to join conditions. /// /// Protection is necessary for scalar-side joins and cross joins to avoid incorrectly -/// rewriting a post-join filter into the join condition when one side is empty or -/// limited to at most one row (`max_rows() == Some(1)`). 
+/// rewriting a post-join filter into the join condition when one side may disappear +/// entirely, even though `max_rows() == Some(1)`. /// /// - `join.on` non-empty means existing join predicates already exist; promotion is safe. -/// - if neither side is scalar (`max_rows() == Some(1)`), promotion is safe. +/// - if a scalar side is a scalar-subquery-shaped input that is provably exactly one +/// row, promotion is safe. +/// - otherwise, keep the filter above the join. fn can_promote_post_join_filter_to_join_condition(join: &Join) -> bool { !join.on.is_empty() - || !(matches!(join.left.max_rows(), Some(1)) - || matches!(join.right.max_rows(), Some(1))) + || !join_side_may_disappear(join.left.as_ref()) + && !join_side_may_disappear(join.right.as_ref()) +} + +/// Returns true when a plan can produce at most one row but is not guaranteed +/// to produce exactly one row. +fn join_side_may_disappear(plan: &LogicalPlan) -> bool { + matches!(plan.max_rows(), Some(1)) && !is_safe_scalar_subquery_side(plan) +} + +/// Returns true for the scalar-subquery-shaped inputs where post-join filter +/// promotion should remain legal. +fn is_safe_scalar_subquery_side(plan: &LogicalPlan) -> bool { + match plan { + LogicalPlan::Projection(projection) => { + is_safe_scalar_subquery_side(projection.input.as_ref()) + } + LogicalPlan::Repartition(repartition) => { + is_safe_scalar_subquery_side(repartition.input.as_ref()) + } + LogicalPlan::Sort(sort) => is_safe_scalar_subquery_side(sort.input.as_ref()), + LogicalPlan::SubqueryAlias(subquery_alias) => { + returns_exactly_one_row(subquery_alias.input.as_ref()) + } + _ => false, + } +} + +/// Returns true when the plan is guaranteed to produce exactly one row. 
+fn returns_exactly_one_row(plan: &LogicalPlan) -> bool { + match plan { + LogicalPlan::Projection(projection) => returns_exactly_one_row(projection.input.as_ref()), + LogicalPlan::SubqueryAlias(subquery_alias) => { + returns_exactly_one_row(subquery_alias.input.as_ref()) + } + LogicalPlan::Repartition(repartition) => { + returns_exactly_one_row(repartition.input.as_ref()) + } + LogicalPlan::Sort(sort) => returns_exactly_one_row(sort.input.as_ref()), + LogicalPlan::Aggregate(Aggregate { group_expr, .. }) => { + group_expr + .iter() + .all(|expr| matches!(expr, Expr::Literal(_, _))) + } + _ => false, + } +} + fn push_down_join( From 1b8940913b936c9798d825a0bfe76fe8cbfed113 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 16:01:57 +0800 Subject: [PATCH 33/63] Refine returns_exactly_one_row helper behavior Update the returns_exactly_one_row helper to take precedence over max_rows() in specific code paths. Ensure that global aggregates without grouping keys return true, filters over aggregates return false, and limits with fetch=1 also return false. This maintains the distinction between "exactly one row" and "at most one row" to address the underlying regression effectively. --- datafusion/optimizer/src/push_down_filter.rs | 31 ++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index ee88840a7c2cd..ec66fab0a3799 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -3723,6 +3723,37 @@ mod tests { ) } + #[test] + fn returns_exactly_one_row_for_global_aggregate() -> Result<()> { + let plan = LogicalPlanBuilder::from(test_table_scan()?) + .aggregate(Vec::<Expr>::new(), vec![avg(col("a"))])?
+ .build()?; + + assert!(returns_exactly_one_row(&plan)); + Ok(()) + } + + #[test] + fn returns_exactly_one_row_is_false_for_filtered_global_aggregate() -> Result<()> { + let plan = LogicalPlanBuilder::from(test_table_scan()?) + .aggregate(Vec::<Expr>::new(), vec![avg(col("a"))])? + .filter(col("avg(test.a)").gt(lit(0i64)))? + .build()?; + + assert!(!returns_exactly_one_row(&plan)); + Ok(()) + } + + #[test] + fn returns_exactly_one_row_is_false_for_limit_fetch_one() -> Result<()> { + let plan = LogicalPlanBuilder::from(test_table_scan()?) + .limit(0, Some(1))? + .build()?; + + assert!(!returns_exactly_one_row(&plan)); + Ok(()) + } + #[test] fn left_semi_join() -> Result<()> { let left = test_table_scan_with_name("test1")?; From 0a1feaac98ee24fcf1e9f2b69869c90cd60fc1f3 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 16:05:19 +0800 Subject: [PATCH 34/63] Allow post-join filter promotion based on conditions Enhance the logic for promoting post-join filters in the join condition. Now, promotion is allowed if the join.on clause is non-empty. For empty joins, each scalar side must explicitly pass the scalar_side_can_promote_post_join_filter check. Non-scalar sides remain unaffected, while scalar sides can only use the existing safe exact-one-row scalar-subquery path for promotion. --- datafusion/optimizer/src/push_down_filter.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index ec66fab0a3799..80b1b0ecd6e11 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -527,14 +527,18 @@ fn push_down_all_join( /// - otherwise, keep the filter above the join.
fn can_promote_post_join_filter_to_join_condition(join: &Join) -> bool { !join.on.is_empty() - || !join_side_may_disappear(join.left.as_ref()) - && !join_side_may_disappear(join.right.as_ref()) + || scalar_side_can_promote_post_join_filter(join.left.as_ref()) + && scalar_side_can_promote_post_join_filter(join.right.as_ref()) } -/// Returns true when a plan can produce at most one row but is not guaranteed -/// to produce exactly one row. -fn join_side_may_disappear(plan: &LogicalPlan) -> bool { - matches!(plan.max_rows(), Some(1)) && !is_safe_scalar_subquery_side(plan) +/// Returns true when a non-scalar side is unrestricted, or when a scalar side is +/// a safe exact-one-row scalar-subquery shape. +fn scalar_side_can_promote_post_join_filter(plan: &LogicalPlan) -> bool { + !is_scalar_side(plan) || is_safe_scalar_subquery_side(plan) +} + +fn is_scalar_side(plan: &LogicalPlan) -> bool { + matches!(plan.max_rows(), Some(1)) } /// Returns true for the scalar-subquery-shaped inputs where post-join filter From 52713fa343873db453a781f76c451963e2b3a543 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 16:14:21 +0800 Subject: [PATCH 35/63] Add support for limited row cross joins Implement cross_join_with_at_most_one_row_side_keeps_post_join_filter, which utilizes a Limit(fetch=1) subquery alias on the right side. This new plan allows for at most one row, maintaining the post-join filter above the Cross Join. This change complements existing aggregate-based cross-join regression and enhances safety against regression in unsafe cases. 
--- datafusion/optimizer/src/push_down_filter.rs | 30 ++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 80b1b0ecd6e11..8059ba7e29f3e 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -3727,6 +3727,36 @@ mod tests { ) } + #[test] + fn cross_join_with_at_most_one_row_side_keeps_post_join_filter() -> Result<()> { + let left = LogicalPlanBuilder::from(test_table_scan()?) + .project(vec![col("a"), col("b")])? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) + .project(vec![col("a")])? + .limit(0, Some(1))? + .alias("sq")? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .cross_join(right)? + .filter(col("test.b").gt(col("sq.a")))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.b > sq.a + Cross Join: + Projection: test.a, test.b + TableScan: test + SubqueryAlias: sq + Limit: skip=0, fetch=1 + Projection: test1.a + TableScan: test1 + " + ) + } + #[test] fn returns_exactly_one_row_for_global_aggregate() -> Result<()> { let plan = LogicalPlanBuilder::from(test_table_scan()?) From 4f90984e6294e89187a932d92bfd6d69eda47f59 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 16:15:40 +0800 Subject: [PATCH 36/63] Add test for exact one-row subquery promotion Introduce a focused safe-case test to ensure that a scalar subquery producing exactly one row is correctly promoted into an Inner Join. This change provides balanced coverage alongside existing unsafe-case tests, confirming that the post-join filter behaves as intended in both safe and unsafe scenarios. 
--- datafusion/optimizer/src/push_down_filter.rs | 29 ++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 8059ba7e29f3e..cf28506105114 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -3727,6 +3727,35 @@ mod tests { ) } + #[test] + fn cross_join_with_exact_one_row_subquery_promotes_post_join_filter() -> Result<()> { + let left = LogicalPlanBuilder::from(test_table_scan()?) + .project(vec![col("a"), col("b")])? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) + .project(vec![col("a")])? + .aggregate(Vec::<Expr>::new(), vec![avg(col("a")).alias("avg_a")])? + .alias("sq")? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .cross_join(right)? + .filter(col("test.b").gt(col("sq.avg_a")))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: Filter: test.b > sq.avg_a + Projection: test.a, test.b + TableScan: test + SubqueryAlias: sq + Aggregate: groupBy=[[]], aggr=[[avg(test1.a) AS avg_a]] + Projection: test1.a + TableScan: test1 + " + ) + } + #[test] fn cross_join_with_at_most_one_row_side_keeps_post_join_filter() -> Result<()> { let left = LogicalPlanBuilder::from(test_table_scan()?) From 33fcb56594f1255a16689013b5ddcd825aef9915 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 16:28:22 +0800 Subject: [PATCH 37/63] Remove unrelated join-promotion behavior from PushDownFilter Eliminate the can_promote_post_join_filter_to_join_condition gate and its helper functions. Restore the original inner-join promotion condition in the hot loop and remove guard-specific tests that were added solely for the previous behavior.
--- datafusion/optimizer/src/push_down_filter.rs | 191 +------------------ 1 file changed, 3 insertions(+), 188 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index cf28506105114..aaaaa5ca48222 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -35,7 +35,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::WindowFunction; use datafusion_expr::expr_rewriter::replace_col; -use datafusion_expr::logical_plan::{Aggregate, Join, JoinType, LogicalPlan, TableScan, Union}; +use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union}; use datafusion_expr::utils::{ conjunction, expr_to_columns, split_conjunction, split_conjunction_owned, }; @@ -431,10 +431,7 @@ fn push_down_all_join( left_push.push(predicate); } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); - } else if is_inner_join - && can_promote_post_join_filter_to_join_condition(&join) - && can_evaluate_as_join_condition(&predicate)? - { + } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? { // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate // and convert to the join on condition join_conditions.push(predicate); @@ -515,70 +512,6 @@ fn push_down_all_join( Ok(Transformed::yes(plan)) } -/// Returns true when post-join filters are allowed to be promoted to join conditions. -/// -/// Protection is necessary for scalar-side joins and cross joins to avoid incorrectly -/// rewriting a post-join filter into the join condition when one side may disappear -/// entirely, even though `max_rows() == Some(1)`. -/// -/// - `join.on` non-empty means existing join predicates already exist; promotion is safe. -/// - if a scalar side is a scalar-subquery-shaped input that is provably exactly one -/// row, promotion is safe. 
-/// - otherwise, keep the filter above the join. -fn can_promote_post_join_filter_to_join_condition(join: &Join) -> bool { - !join.on.is_empty() - || scalar_side_can_promote_post_join_filter(join.left.as_ref()) - && scalar_side_can_promote_post_join_filter(join.right.as_ref()) -} - -/// Returns true when a non-scalar side is unrestricted, or when a scalar side is -/// a safe exact-one-row scalar-subquery shape. -fn scalar_side_can_promote_post_join_filter(plan: &LogicalPlan) -> bool { - !is_scalar_side(plan) || is_safe_scalar_subquery_side(plan) -} - -fn is_scalar_side(plan: &LogicalPlan) -> bool { - matches!(plan.max_rows(), Some(1)) -} - -/// Returns true for the scalar-subquery-shaped inputs where post-join filter -/// promotion should remain legal. -fn is_safe_scalar_subquery_side(plan: &LogicalPlan) -> bool { - match plan { - LogicalPlan::Projection(projection) => { - is_safe_scalar_subquery_side(projection.input.as_ref()) - } - LogicalPlan::Repartition(repartition) => { - is_safe_scalar_subquery_side(repartition.input.as_ref()) - } - LogicalPlan::Sort(sort) => is_safe_scalar_subquery_side(sort.input.as_ref()), - LogicalPlan::SubqueryAlias(subquery_alias) => { - returns_exactly_one_row(subquery_alias.input.as_ref()) - } - _ => false, - } -} - -/// Returns true when the plan is guaranteed to produce exactly one row. -fn returns_exactly_one_row(plan: &LogicalPlan) -> bool { - match plan { - LogicalPlan::Projection(projection) => returns_exactly_one_row(projection.input.as_ref()), - LogicalPlan::SubqueryAlias(subquery_alias) => { - returns_exactly_one_row(subquery_alias.input.as_ref()) - } - LogicalPlan::Repartition(repartition) => { - returns_exactly_one_row(repartition.input.as_ref()) - } - LogicalPlan::Sort(sort) => returns_exactly_one_row(sort.input.as_ref()), - LogicalPlan::Aggregate(Aggregate { group_expr, .. 
}) => { - group_expr - .iter() - .all(|expr| matches!(expr, Expr::Literal(_, _))) - } - _ => false, - } -} - fn push_down_join( join: Join, parent_predicate: Option<&Expr>, @@ -1562,7 +1495,7 @@ mod tests { use crate::simplify_expressions::SimplifyExpressions; use crate::test::udfs::leaf_udf_expr; use crate::test::*; - use datafusion_expr::test::function_stub::{avg, sum}; + use datafusion_expr::test::function_stub::sum; use insta::assert_snapshot; use super::*; @@ -3699,124 +3632,6 @@ mod tests { ) } - #[test] - fn cross_join_with_scalar_side_keeps_post_join_filter() -> Result<()> { - let left = LogicalPlanBuilder::from(test_table_scan()?) - .project(vec![col("a"), col("b")])? - .build()?; - let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) - .project(vec![col("a")])? - .aggregate(Vec::<Expr>::new(), vec![avg(col("a")).alias("avg_a")])? - .build()?; - let plan = LogicalPlanBuilder::from(left) - .cross_join(right)? - .filter(col("test.b").gt(col("avg_a")))? - .build()?; - - assert_optimized_plan_equal!( - plan, - @r" - Filter: test.b > avg_a - Cross Join: - Projection: test.a, test.b - TableScan: test - Aggregate: groupBy=[[]], aggr=[[avg(test1.a) AS avg_a]] - Projection: test1.a - TableScan: test1 - " - ) - } - - #[test] - fn cross_join_with_exact_one_row_subquery_promotes_post_join_filter() -> Result<()> { - let left = LogicalPlanBuilder::from(test_table_scan()?) - .project(vec![col("a"), col("b")])? - .build()?; - let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) - .project(vec![col("a")])? - .aggregate(Vec::<Expr>::new(), vec![avg(col("a")).alias("avg_a")])? - .alias("sq")? - .build()?; - let plan = LogicalPlanBuilder::from(left) - .cross_join(right)? - .filter(col("test.b").gt(col("sq.avg_a")))?
- .build()?; - - assert_optimized_plan_equal!( - plan, - @r" - Inner Join: Filter: test.b > sq.avg_a - Projection: test.a, test.b - TableScan: test - SubqueryAlias: sq - Aggregate: groupBy=[[]], aggr=[[avg(test1.a) AS avg_a]] - Projection: test1.a - TableScan: test1 - " - ) - } - - #[test] - fn cross_join_with_at_most_one_row_side_keeps_post_join_filter() -> Result<()> { - let left = LogicalPlanBuilder::from(test_table_scan()?) - .project(vec![col("a"), col("b")])? - .build()?; - let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) - .project(vec![col("a")])? - .limit(0, Some(1))? - .alias("sq")? - .build()?; - let plan = LogicalPlanBuilder::from(left) - .cross_join(right)? - .filter(col("test.b").gt(col("sq.a")))? - .build()?; - - assert_optimized_plan_equal!( - plan, - @r" - Filter: test.b > sq.a - Cross Join: - Projection: test.a, test.b - TableScan: test - SubqueryAlias: sq - Limit: skip=0, fetch=1 - Projection: test1.a - TableScan: test1 - " - ) - } - - #[test] - fn returns_exactly_one_row_for_global_aggregate() -> Result<()> { - let plan = LogicalPlanBuilder::from(test_table_scan()?) - .aggregate(Vec::<Expr>::new(), vec![avg(col("a"))])? - .build()?; - - assert!(returns_exactly_one_row(&plan)); - Ok(()) - } - - #[test] - fn returns_exactly_one_row_is_false_for_filtered_global_aggregate() -> Result<()> { - let plan = LogicalPlanBuilder::from(test_table_scan()?) - .aggregate(Vec::<Expr>::new(), vec![avg(col("a"))])? - .filter(col("avg(test.a)").gt(lit(0i64)))? - .build()?; - - assert!(!returns_exactly_one_row(&plan)); - Ok(()) - } - - #[test] - fn returns_exactly_one_row_is_false_for_limit_fetch_one() -> Result<()> { - let plan = LogicalPlanBuilder::from(test_table_scan()?) - .limit(0, Some(1))?
- .build()?; - - assert!(!returns_exactly_one_row(&plan)); - Ok(()) - } - #[test] fn left_semi_join() -> Result<()> { let left = test_table_scan_with_name("test1")?; From 4215b127480753b185a3c22bd9913ffbd51c3f50 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 17:25:24 +0800 Subject: [PATCH 38/63] Add optimizer-level test for scalar subquery filter Introduce a focused optimizer-level reproducer for the scalar-subquery Cross Join + Filter scenario. This test case highlights the bug in PushDownFilter, showing that it incorrectly optimizes to Inner Join instead of retaining the Filter above the Cross Join. This change helps ensure future stability and correctness. --- datafusion/optimizer/src/push_down_filter.rs | 54 +++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index aaaaa5ca48222..cb9910aac3106 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1495,7 +1495,7 @@ mod tests { use crate::simplify_expressions::SimplifyExpressions; use crate::test::udfs::leaf_udf_expr; use crate::test::*; - use datafusion_expr::test::function_stub::sum; + use datafusion_expr::test::function_stub::{avg, sum}; use insta::assert_snapshot; use super::*; @@ -2419,6 +2419,58 @@ mod tests { ) } + #[test] + #[ignore = "FIX_06 step(1): reproduces current scalar-subquery cross-join promotion regression"] + fn window_over_scalar_subquery_cross_join_keeps_filter_above_join() -> Result<()> { + let left = LogicalPlanBuilder::from(test_table_scan()?) + .project(vec![col("a").alias("nation"), col("b").alias("acctbal")])? + .alias("s")? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) + .project(vec![col("a").alias("acctbal")])? + .aggregate( + Vec::<Expr>::new(), + vec![avg(col("acctbal")).alias("avg_acctbal")], + )? + .alias("__scalar_sq_1")?
+ .build()?; + + let window = Expr::from(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::row_number::row_number_udwf(), + ), + vec![], + )) + .partition_by(vec![col("s.nation")]) + .order_by(vec![col("s.acctbal").sort(false, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(left) + .cross_join(right)? + .filter(col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")))? + .project(vec![col("s.nation"), col("s.acctbal")])? + .window(vec![window])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + WindowAggr: windowExpr=[[row_number() PARTITION BY [s.nation] ORDER BY [s.acctbal DESC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + Projection: s.nation, s.acctbal + Filter: s.acctbal > __scalar_sq_1.avg_acctbal + Cross Join: + SubqueryAlias: s + Projection: test.a AS nation, test.b AS acctbal + TableScan: test + SubqueryAlias: __scalar_sq_1 + Aggregate: groupBy=[[]], aggr=[[avg(test1.a) AS avg_acctbal]] + Projection: test1.a AS acctbal + TableScan: test1 + " + ) + } + /// verifies that filters with the same columns are correctly placed #[test] fn filter_2_breaks_limits() -> Result<()> { From 526d34664cbd21f7664d09d9dbf46e481dd76dbf Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 17:27:44 +0800 Subject: [PATCH 39/63] Add focused tests for cross join and scalar subquery Introduce two tests to identify the promotion path behind the regression. The tests validate that the cross join is represented as Inner Join with no join filters and that a scalar subquery condition is considered a candidate for join evaluation. These changes help pinpoint the branch in push_down_all_join affecting the transformation of the Cross Join and Filter into Inner Join. 
--- datafusion/optimizer/src/push_down_filter.rs | 27 ++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index cb9910aac3106..f221c07291c24 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -2471,6 +2471,33 @@ mod tests { ) } + #[test] + fn cross_join_builder_uses_inner_join_with_no_join_keys() -> Result<()> { + let plan = LogicalPlanBuilder::from(test_table_scan()?) + .cross_join(test_table_scan_with_name("test1")?)? + .build()?; + + let LogicalPlan::Join(join) = plan else { + panic!("expected join plan"); + }; + + assert_eq!(join.join_type, JoinType::Inner); + assert!(join.on.is_empty()); + assert!(join.filter.is_none()); + + Ok(()) + } + + #[test] + fn scalar_subquery_cross_join_filter_is_treated_as_join_condition_candidate( + ) -> Result<()> { + let predicate = col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")); + + assert!(can_evaluate_as_join_condition(&predicate)?); + + Ok(()) + } + /// verifies that filters with the same columns are correctly placed #[test] fn filter_2_breaks_limits() -> Result<()> { From ac59a56520c14afb9220a9c36a81c76092191a50 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 17:32:54 +0800 Subject: [PATCH 40/63] Prevent join-condition promotion in specific cases Add a promotion guard to block join-condition promotion for empty-on cross-join shapes where one side is a scalar aggregate subquery and the other a derived relation, with predicates referencing both sides. This addresses the window/scalar-subquery regression while maintaining existing behavior for plain scalar subqueries. New optimizer coverage tests remain in place. 
--- datafusion/optimizer/src/push_down_filter.rs | 47 ++++++++++++++++++-- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index f221c07291c24..a15e8f68b1863 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -285,6 +285,45 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> { Ok(is_evaluate) } +fn strip_aliases_and_projections(plan: &LogicalPlan) -> &LogicalPlan { + match plan { + LogicalPlan::SubqueryAlias(subquery_alias) => { + strip_aliases_and_projections(subquery_alias.input.as_ref()) + } + LogicalPlan::Projection(projection) => { + strip_aliases_and_projections(projection.input.as_ref()) + } + _ => plan, + } +} + +fn is_scalar_aggregate_subquery(plan: &LogicalPlan) -> bool { + matches!( + strip_aliases_and_projections(plan), + LogicalPlan::Aggregate(aggregate) if aggregate.group_expr.is_empty() + ) +} + +fn is_derived_relation(plan: &LogicalPlan) -> bool { + matches!(plan, LogicalPlan::SubqueryAlias(_)) +} + +fn should_keep_filter_above_cross_join(join: &Join, predicate: &Expr) -> bool { + if !join.on.is_empty() || join.filter.is_some() { + return false; + } + + let mut checker = ColumnChecker::new(join.left.schema(), join.right.schema()); + let references_both_sides = + !checker.is_left_only(predicate) && !checker.is_right_only(predicate); + + references_both_sides + && ((is_scalar_aggregate_subquery(join.left.as_ref()) + && is_derived_relation(join.right.as_ref())) + || (is_scalar_aggregate_subquery(join.right.as_ref()) + && is_derived_relation(join.left.as_ref()))) +} + /// examine OR clause to see if any useful clauses can be extracted and push down. /// extract at least one qual from each sub clauses of OR clause, then form the quals /// to new OR clause as predicate.
@@ -431,7 +470,10 @@ fn push_down_all_join( left_push.push(predicate); } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); - } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? { + } else if is_inner_join + && !should_keep_filter_above_cross_join(&join, &predicate) + && can_evaluate_as_join_condition(&predicate)? + { // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate // and convert to the join on condition join_conditions.push(predicate); @@ -2420,7 +2462,6 @@ mod tests { } #[test] - #[ignore = "FIX_06 step(1): reproduces current scalar-subquery cross-join promotion regression"] fn window_over_scalar_subquery_cross_join_keeps_filter_above_join() -> Result<()> { let left = LogicalPlanBuilder::from(test_table_scan()?) .project(vec![col("a").alias("nation"), col("b").alias("acctbal")])? @@ -2464,7 +2505,7 @@ mod tests { Projection: test.a AS nation, test.b AS acctbal TableScan: test SubqueryAlias: __scalar_sq_1 - Aggregate: groupBy=[[]], aggr=[[avg(test1.a) AS avg_acctbal]] + Aggregate: groupBy=[[]], aggr=[[avg(acctbal) AS avg_acctbal]] Projection: test1.a AS acctbal TableScan: test1 " From c09154fe37357213d3486e7ca17f73ce40129cf1 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 17:34:37 +0800 Subject: [PATCH 41/63] Add fast path for mixed-reference predicates in authoritative mode Ensure that mixed-reference predicates with only one join key lock in the desired behavior. The is_restrict_null_predicate(...) function returns false in both Auto and AuthoritativeOnly modes, thereby maintaining the fast path functionality in utils.rs even after applying the push_down_filter fix. 
--- datafusion/optimizer/src/utils.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 79dcefb312c3c..d8d45534a67f6 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -463,4 +463,28 @@ mod tests { Ok(()) } + + #[test] + fn mixed_reference_predicate_remains_fast_pathed_in_authoritative_mode( + ) -> Result<()> { + let predicate = binary_expr(col("a"), Operator::Gt, col("b")); + let column_a = Column::from_name("a"); + + set_null_restriction_eval_mode_for_test(NullRestrictionEvalMode::Auto); + let auto_result = + is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a))?; + + set_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::AuthoritativeOnly, + ); + let authoritative_only_result = + is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a))?; + + set_null_restriction_eval_mode_for_test(NullRestrictionEvalMode::Auto); + + assert!(!auto_result, "{predicate}"); + assert!(!authoritative_only_result, "{predicate}"); + + Ok(()) + } } From b1d6cba6164f49b6917a72424435471f8a9fdd39 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 18:25:33 +0800 Subject: [PATCH 42/63] Enhance null-restriction eval mode for safety Make null-restriction eval mode panic-safe and reduce parallel test leakage. Introduce a scoped test helper with an RAII reset-on-drop guard to ensure proper restoration after test execution. Improve isolation by replacing global Mutex with thread-local Cell state to prevent cross-thread interference in concurrent tests. Update existing tests to utilize the new scoped helper. 
--- datafusion/optimizer/src/utils.rs | 81 ++++++++++++++++++++++--------- 1 file changed, 58 insertions(+), 23 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index d8d45534a67f6..1cb2f986760e0 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -23,7 +23,7 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use std::sync::Arc; #[cfg(test)] -use std::sync::Mutex; +use std::cell::Cell; use crate::analyzer::type_coercion::TypeCoercionRewriter; @@ -36,17 +36,38 @@ pub(crate) enum NullRestrictionEvalMode { } #[cfg(test)] -static NULL_RESTRICTION_EVAL_MODE: Mutex = - Mutex::new(NullRestrictionEvalMode::Auto); +thread_local! { + static NULL_RESTRICTION_EVAL_MODE: Cell = + const { Cell::new(NullRestrictionEvalMode::Auto) }; +} #[cfg(test)] pub(crate) fn set_null_restriction_eval_mode_for_test(mode: NullRestrictionEvalMode) { - *NULL_RESTRICTION_EVAL_MODE.lock().unwrap() = mode; + NULL_RESTRICTION_EVAL_MODE.with(|eval_mode| eval_mode.set(mode)); } #[cfg(test)] fn null_restriction_eval_mode() -> NullRestrictionEvalMode { - *NULL_RESTRICTION_EVAL_MODE.lock().unwrap() + NULL_RESTRICTION_EVAL_MODE.with(Cell::get) +} + +#[cfg(test)] +pub(crate) fn with_null_restriction_eval_mode_for_test( + mode: NullRestrictionEvalMode, + f: impl FnOnce() -> T, +) -> T { + struct NullRestrictionEvalModeReset(NullRestrictionEvalMode); + + impl Drop for NullRestrictionEvalModeReset { + fn drop(&mut self) { + set_null_restriction_eval_mode_for_test(self.0); + } + } + + let previous_mode = null_restriction_eval_mode(); + set_null_restriction_eval_mode_for_test(mode); + let _reset = NullRestrictionEvalModeReset(previous_mode); + f() } use arrow::array::{Array, RecordBatch, new_null_array}; use arrow::datatypes::{DataType, Field, Schema}; @@ -445,18 +466,24 @@ mod tests { let predicate = binary_expr(col("a"), Operator::Gt, lit(8i64)); let join_cols_of_predicate = predicate.column_refs(); - 
set_null_restriction_eval_mode_for_test(NullRestrictionEvalMode::Auto); - let auto_result = is_restrict_null_predicate( - predicate.clone(), - join_cols_of_predicate.iter().copied(), + let auto_result = with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::Auto, + || { + is_restrict_null_predicate( + predicate.clone(), + join_cols_of_predicate.iter().copied(), + ) + }, )?; - set_null_restriction_eval_mode_for_test( + let authoritative_result = with_null_restriction_eval_mode_for_test( NullRestrictionEvalMode::AuthoritativeOnly, - ); - let authoritative_result = is_restrict_null_predicate( - predicate.clone(), - join_cols_of_predicate.iter().copied(), + || { + is_restrict_null_predicate( + predicate.clone(), + join_cols_of_predicate.iter().copied(), + ) + }, )?; assert_eq!(auto_result, authoritative_result); @@ -470,17 +497,25 @@ mod tests { let predicate = binary_expr(col("a"), Operator::Gt, col("b")); let column_a = Column::from_name("a"); - set_null_restriction_eval_mode_for_test(NullRestrictionEvalMode::Auto); - let auto_result = - is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a))?; + let auto_result = with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::Auto, + || { + is_restrict_null_predicate( + predicate.clone(), + std::iter::once(&column_a), + ) + }, + )?; - set_null_restriction_eval_mode_for_test( + let authoritative_only_result = with_null_restriction_eval_mode_for_test( NullRestrictionEvalMode::AuthoritativeOnly, - ); - let authoritative_only_result = - is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a))?; - - set_null_restriction_eval_mode_for_test(NullRestrictionEvalMode::Auto); + || { + is_restrict_null_predicate( + predicate.clone(), + std::iter::once(&column_a), + ) + }, + )?; assert!(!auto_result, "{predicate}"); assert!(!authoritative_only_result, "{predicate}"); From ba03e05693aaee77173159db027113e82ef49495 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 
2026 18:28:07 +0800 Subject: [PATCH 43/63] Rename helper for scalar subquery clarity Update the function name to specify its relevance to scalar subquery cross-joins. Add an intent comment for better understanding of its purpose. Replace the old function call in join predicate handling for improved readability. --- datafusion/optimizer/src/push_down_filter.rs | 9 +++++++-- datafusion/optimizer/src/utils.rs | 18 ++++-------------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index a15e8f68b1863..45ff1d2603a2e 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -308,7 +308,12 @@ fn is_derived_relation(plan: &LogicalPlan) -> bool { matches!(plan, LogicalPlan::SubqueryAlias(_)) } -fn should_keep_filter_above_cross_join(join: &Join, predicate: &Expr) -> bool { +// Keep post-join filters above certain scalar-subquery cross joins to preserve +// behavior for the window-over-scalar-subquery regression shape. +fn should_keep_filter_above_scalar_subquery_cross_join( + join: &Join, + predicate: &Expr, +) -> bool { if !join.on.is_empty() || join.filter.is_some() { return false; } @@ -471,7 +476,7 @@ fn push_down_all_join( } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); } else if is_inner_join - && !should_keep_filter_above_cross_join(&join, &predicate) + && !should_keep_filter_above_scalar_subquery_cross_join(&join, &predicate) && can_evaluate_as_join_condition(&predicate)? 
{ // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 1cb2f986760e0..a09538d9c2484 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -492,29 +492,19 @@ mod tests { } #[test] - fn mixed_reference_predicate_remains_fast_pathed_in_authoritative_mode( - ) -> Result<()> { + fn mixed_reference_predicate_remains_fast_pathed_in_authoritative_mode() -> Result<()> + { let predicate = binary_expr(col("a"), Operator::Gt, col("b")); let column_a = Column::from_name("a"); let auto_result = with_null_restriction_eval_mode_for_test( NullRestrictionEvalMode::Auto, - || { - is_restrict_null_predicate( - predicate.clone(), - std::iter::once(&column_a), - ) - }, + || is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a)), )?; let authoritative_only_result = with_null_restriction_eval_mode_for_test( NullRestrictionEvalMode::AuthoritativeOnly, - || { - is_restrict_null_predicate( - predicate.clone(), - std::iter::once(&column_a), - ) - }, + || is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a)), )?; assert!(!auto_result, "{predicate}"); From 2f658b490c0e311a1292de5a5d4ea232a5a657b1 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 18:31:32 +0800 Subject: [PATCH 44/63] Add maintainer comment in null_restriction.rs Clarify supported expression/operator families in the syntactic evaluator. Emphasize that returning None indicates deferral to authoritative evaluation, rather than "non-restricting." Ensure unsupported variants also return None for consistency. 
--- datafusion/optimizer/src/push_down_filter.rs | 4 ++-- datafusion/optimizer/src/utils/null_restriction.rs | 9 +++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 45ff1d2603a2e..316f2c6396381 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -2535,8 +2535,8 @@ mod tests { } #[test] - fn scalar_subquery_cross_join_filter_is_treated_as_join_condition_candidate( - ) -> Result<()> { + fn scalar_subquery_cross_join_filter_is_treated_as_join_condition_candidate() + -> Result<()> { let predicate = col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")); assert!(can_evaluate_as_join_condition(&predicate)?); diff --git a/datafusion/optimizer/src/utils/null_restriction.rs b/datafusion/optimizer/src/utils/null_restriction.rs index 28e32a26ed1da..6c3394a2ff11c 100644 --- a/datafusion/optimizer/src/utils/null_restriction.rs +++ b/datafusion/optimizer/src/utils/null_restriction.rs @@ -127,6 +127,15 @@ fn syntactic_null_substitution_value( expr: &Expr, join_cols: &HashSet<&Column>, ) -> Option { + // This evaluator intentionally supports a strict subset of expressions: + // aliases/columns/literals, boolean combinators (NOT/AND/OR), null checks + // (IS [NOT] NULL), BETWEEN, strict-null-preserving unary operators + // (CAST/TRY_CAST/NEGATIVE), LIKE/SIMILAR TO, and binary operators handled in + // `syntactic_binary_value`. + // + // Returning `None` means "defer to the authoritative evaluator" rather than + // "not null-restricting". Any unsupported expression variant must return + // `None` so callers can safely fall back to full expression evaluation. 
match expr { Expr::Alias(alias) => { syntactic_null_substitution_value(alias.expr.as_ref(), join_cols) From e4ffcd15ff5bbd3b7e77ef7ea6eb5569382d8638 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 18:39:03 +0800 Subject: [PATCH 45/63] Reduce plan-shape sensitivity for scalar-subquery guard Broaden derived-relation detection to include projection wrappers over derived relations. Add regression tests to cover alias/projection shape changes and ensure mixed-side filters are preserved. Implement a panic-path robustness test to confirm that eval mode resets properly, even on closure panic using catch_unwind. --- datafusion/optimizer/src/push_down_filter.rs | 62 +++++++++++++++++++- datafusion/optimizer/src/utils.rs | 20 +++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 316f2c6396381..ced53d6ea6ace 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -305,7 +305,13 @@ fn is_scalar_aggregate_subquery(plan: &LogicalPlan) -> bool { } fn is_derived_relation(plan: &LogicalPlan) -> bool { - matches!(plan, LogicalPlan::SubqueryAlias(_)) + match plan { + LogicalPlan::SubqueryAlias(_) => true, + LogicalPlan::Projection(projection) => { + is_derived_relation(projection.input.as_ref()) + } + _ => false, + } } // Keep post-join filters above certain scalar-subquery cross joins to preserve @@ -2517,6 +2523,60 @@ mod tests { ) } + #[test] + fn window_over_scalar_subquery_cross_join_with_project_wrapper_keeps_filter_above_join( + ) -> Result<()> { + let left = LogicalPlanBuilder::from(test_table_scan()?) + .project(vec![col("a").alias("nation"), col("b").alias("acctbal")])? + .alias("s")? + .project(vec![col("s.nation"), col("s.acctbal")])? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) + .project(vec![col("a").alias("acctbal")])? 
+ .aggregate( + Vec::::new(), + vec![avg(col("acctbal")).alias("avg_acctbal")], + )? + .alias("__scalar_sq_1")? + .build()?; + + let window = Expr::from(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::row_number::row_number_udwf(), + ), + vec![], + )) + .partition_by(vec![col("s.nation")]) + .order_by(vec![col("s.acctbal").sort(false, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(left) + .cross_join(right)? + .filter(col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")))? + .project(vec![col("s.nation"), col("s.acctbal")])? + .window(vec![window])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + WindowAggr: windowExpr=[[row_number() PARTITION BY [s.nation] ORDER BY [s.acctbal DESC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + Projection: s.nation, s.acctbal + Filter: s.acctbal > __scalar_sq_1.avg_acctbal + Cross Join: + Projection: s.nation, s.acctbal + SubqueryAlias: s + Projection: test.a AS nation, test.b AS acctbal + TableScan: test + SubqueryAlias: __scalar_sq_1 + Aggregate: groupBy=[[]], aggr=[[avg(acctbal) AS avg_acctbal]] + Projection: test1.a AS acctbal + TableScan: test1 + " + ) + } + #[test] fn cross_join_builder_uses_inner_join_with_no_join_keys() -> Result<()> { let plan = LogicalPlanBuilder::from(test_table_scan()?) 
diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index a09538d9c2484..5fb426e290156 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -245,6 +245,8 @@ fn coerce(expr: Expr, schema: &DFSchema) -> Result { #[cfg(test)] mod tests { use super::*; + use std::panic::{AssertUnwindSafe, catch_unwind}; + use datafusion_expr::{ Operator, binary_expr, case, col, in_list, is_null, lit, when, }; @@ -512,4 +514,22 @@ mod tests { Ok(()) } + + #[test] + fn null_restriction_eval_mode_guard_restores_on_panic() { + set_null_restriction_eval_mode_for_test(NullRestrictionEvalMode::Auto); + + let result = catch_unwind(AssertUnwindSafe(|| { + with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::AuthoritativeOnly, + || panic!("intentional panic to verify test mode reset"), + ) + })); + + assert!(result.is_err()); + assert_eq!( + null_restriction_eval_mode(), + NullRestrictionEvalMode::Auto + ); + } } From a817b7dd604f740b650cf926dece0285727de31f Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 18:46:27 +0800 Subject: [PATCH 46/63] Consolidate plan-wrapper traversal and optimizations Combine plan-wrapper traversal and cross-join shape detection. Shorten join-column replacement scan and share authoritative null-result decoding. Remove unused helpers and reorganize strict-null operator list behind a classifier helper. Public interfaces remain unchanged. 
--- .../core/benches/sql_planner_extended.rs | 11 +- datafusion/optimizer/src/push_down_filter.rs | 81 ++++++------ datafusion/optimizer/src/utils.rs | 62 +++++---- .../optimizer/src/utils/null_restriction.rs | 123 +++++++++--------- 4 files changed, 146 insertions(+), 131 deletions(-) diff --git a/datafusion/core/benches/sql_planner_extended.rs b/datafusion/core/benches/sql_planner_extended.rs index 767134bb5bafd..7db50b19bf566 100644 --- a/datafusion/core/benches/sql_planner_extended.rs +++ b/datafusion/core/benches/sql_planner_extended.rs @@ -357,8 +357,12 @@ fn build_non_case_left_join_df_with_push_down_filter( fn find_filter_predicates(plan: &LogicalPlan) -> Vec { match plan { LogicalPlan::Filter(filter) => split_conjunction_owned(filter.predicate.clone()), - LogicalPlan::Projection(projection) => find_filter_predicates(projection.input.as_ref()), - other => panic!("expected benchmark query plan to contain a Filter, found {other:?}"), + LogicalPlan::Projection(projection) => { + find_filter_predicates(projection.input.as_ref()) + } + other => { + panic!("expected benchmark query plan to contain a Filter, found {other:?}") + } } } @@ -375,7 +379,8 @@ fn assert_case_heavy_left_join_inference_candidates( for predicate in predicates { let column_refs = predicate.column_refs(); assert!( - column_refs.contains(&&left_join_key) || column_refs.contains(&&right_join_key), + column_refs.contains(&&left_join_key) + || column_refs.contains(&&right_join_key), "benchmark predicate should reference a join key: {predicate}" ); assert!( diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index ced53d6ea6ace..c74fe5ad0458b 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -285,33 +285,39 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Ok(is_evaluate) } -fn strip_aliases_and_projections(plan: &LogicalPlan) -> &LogicalPlan { +fn 
strip_plan_wrappers(plan: &LogicalPlan) -> (&LogicalPlan, bool) { match plan { LogicalPlan::SubqueryAlias(subquery_alias) => { - strip_aliases_and_projections(subquery_alias.input.as_ref()) + let (plan, _) = strip_plan_wrappers(subquery_alias.input.as_ref()); + (plan, true) } LogicalPlan::Projection(projection) => { - strip_aliases_and_projections(projection.input.as_ref()) + let (plan, is_derived_relation) = + strip_plan_wrappers(projection.input.as_ref()); + (plan, is_derived_relation) } - _ => plan, + _ => (plan, false), } } fn is_scalar_aggregate_subquery(plan: &LogicalPlan) -> bool { matches!( - strip_aliases_and_projections(plan), + strip_plan_wrappers(plan).0, LogicalPlan::Aggregate(aggregate) if aggregate.group_expr.is_empty() ) } fn is_derived_relation(plan: &LogicalPlan) -> bool { - match plan { - LogicalPlan::SubqueryAlias(_) => true, - LogicalPlan::Projection(projection) => { - is_derived_relation(projection.input.as_ref()) - } - _ => false, - } + strip_plan_wrappers(plan).1 +} + +fn is_scalar_subquery_cross_join(join: &Join) -> bool { + join.on.is_empty() + && join.filter.is_none() + && ((is_scalar_aggregate_subquery(join.left.as_ref()) + && is_derived_relation(join.right.as_ref())) + || (is_scalar_aggregate_subquery(join.right.as_ref()) + && is_derived_relation(join.left.as_ref()))) } // Keep post-join filters above certain scalar-subquery cross joins to preserve @@ -320,19 +326,12 @@ fn should_keep_filter_above_scalar_subquery_cross_join( join: &Join, predicate: &Expr, ) -> bool { - if !join.on.is_empty() || join.filter.is_some() { + if !is_scalar_subquery_cross_join(join) { return false; } let mut checker = ColumnChecker::new(join.left.schema(), join.right.schema()); - let references_both_sides = - !checker.is_left_only(predicate) && !checker.is_right_only(predicate); - - references_both_sides - && ((is_scalar_aggregate_subquery(join.left.as_ref()) - && is_derived_relation(join.right.as_ref())) - || (is_scalar_aggregate_subquery(join.right.as_ref()) 
- && is_derived_relation(join.left.as_ref()))) + !checker.is_left_only(predicate) && !checker.is_right_only(predicate) } /// examine OR clause to see if any useful clauses can be extracted and push down. @@ -777,29 +776,21 @@ fn infer_join_predicates_impl< ) -> Result<()> { for predicate in input_predicates { let column_refs = predicate.column_refs(); - let mut join_col_replacements = Vec::new(); - let mut has_non_replaceable_refs = false; - - for &col in &column_refs { - let mut replacement = None; - - for (l, r) in join_col_keys.iter() { - if ENABLE_LEFT_TO_RIGHT && col == *l { - replacement = Some((col, *r)); - break; - } - if ENABLE_RIGHT_TO_LEFT && col == *r { - replacement = Some((col, *l)); - break; - } - } + let join_col_replacements: Vec<_> = column_refs + .iter() + .filter_map(|&col| { + join_col_keys.iter().find_map(|(l, r)| { + if ENABLE_LEFT_TO_RIGHT && col == *l { + Some((col, *r)) + } else if ENABLE_RIGHT_TO_LEFT && col == *r { + Some((col, *l)) + } else { + None + } + }) + }) + .collect(); - if let Some(replacement) = replacement { - join_col_replacements.push(replacement); - } else { - has_non_replaceable_refs = true; - } - } if join_col_replacements.is_empty() { continue; } @@ -807,7 +798,9 @@ fn infer_join_predicates_impl< // For non-inner joins, predicates that reference any non-replaceable // columns cannot be inferred on the other side. Skip the null-restriction // helper entirely in that common mixed-reference case. 
- if !inferred_predicates.is_inner_join && has_non_replaceable_refs { + if !inferred_predicates.is_inner_join + && join_col_replacements.len() != column_refs.len() + { continue; } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 5fb426e290156..2ef25882a2c0d 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -130,13 +130,14 @@ pub fn is_restrict_null_predicate<'a>( // Collect join columns so they can be used in both the fast-path check and the // fallback evaluation path below. let join_cols: HashSet<&Column> = join_cols_of_predicate.into_iter().collect(); + let column_refs = predicate.column_refs(); // Fast path: if the predicate references columns outside the join key set, // `evaluate_expr_with_null_column` would fail because the null schema only // contains a placeholder for the join key columns. Callers treat such errors as // non-restricting (false) via `matches!(_, Ok(true))`, so we return false early // and avoid the expensive physical-expression compilation pipeline entirely. - if !null_restriction::predicate_uses_only_columns(&predicate, &join_cols) { + if !null_restriction::all_columns_allowed(&column_refs, &join_cols) { return Ok(false); } @@ -180,12 +181,10 @@ pub fn evaluates_to_null<'a>( return Ok(true); } - Ok( - match evaluate_expr_with_null_column(predicate, null_columns)? { - ColumnarValue::Array(_) => false, - ColumnarValue::Scalar(scalar) => scalar.is_null(), - }, - ) + Ok(authoritative_null_result(evaluate_expr_with_null_column( + predicate, + null_columns, + )?)? == AuthoritativeNullResult::AlwaysNull) } fn evaluate_expr_with_null_column<'a>( @@ -219,22 +218,41 @@ fn authoritative_restrict_null_predicate<'a>( predicate: Expr, join_cols_of_predicate: impl IntoIterator, ) -> Result { - Ok( - match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? 
{ - ColumnarValue::Array(array) => { - if array.len() == 1 { - let boolean_array = as_boolean_array(&array)?; - boolean_array.is_null(0) || !boolean_array.value(0) - } else { - false - } + Ok(authoritative_null_result(evaluate_expr_with_null_column( + predicate, + join_cols_of_predicate, + )?)? == AuthoritativeNullResult::NullRestricting) +} + +#[derive(Debug, PartialEq, Eq)] +enum AuthoritativeNullResult { + AlwaysNull, + NullRestricting, + Other, +} + +fn authoritative_null_result(value: ColumnarValue) -> Result { + Ok(match value { + ColumnarValue::Array(array) => { + if array.len() != 1 { + return Ok(AuthoritativeNullResult::Other); } - ColumnarValue::Scalar(scalar) => matches!( - scalar, - ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false)) - ), - }, - ) + + let boolean_array = as_boolean_array(&array)?; + if boolean_array.is_null(0) || !boolean_array.value(0) { + AuthoritativeNullResult::NullRestricting + } else { + AuthoritativeNullResult::Other + } + } + ColumnarValue::Scalar(scalar) if scalar.is_null() => { + AuthoritativeNullResult::AlwaysNull + } + ColumnarValue::Scalar( + ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false)), + ) => AuthoritativeNullResult::NullRestricting, + ColumnarValue::Scalar(_) => AuthoritativeNullResult::Other, + }) } fn coerce(expr: Expr, schema: &DFSchema) -> Result { diff --git a/datafusion/optimizer/src/utils/null_restriction.rs b/datafusion/optimizer/src/utils/null_restriction.rs index 6c3394a2ff11c..3fc2dbab3f8d0 100644 --- a/datafusion/optimizer/src/utils/null_restriction.rs +++ b/datafusion/optimizer/src/utils/null_restriction.rs @@ -45,24 +45,15 @@ pub(super) fn syntactic_restrict_null_predicate( } } -pub(super) fn predicate_uses_only_columns( - predicate: &Expr, +pub(super) fn all_columns_allowed( + column_refs: &HashSet<&Column>, allowed_columns: &HashSet<&Column>, ) -> bool { - predicate - .column_refs() + column_refs .iter() .all(|column| allowed_columns.contains(*column)) } -fn contains_null( - 
values: impl IntoIterator>, -) -> bool { - values - .into_iter() - .any(|value| matches!(value, Some(NullSubstitutionValue::Null))) -} - fn not(value: Option) -> Option { match value { Some(NullSubstitutionValue::Boolean(value)) => { @@ -120,7 +111,10 @@ fn null_check_value( fn null_if_contains_null( values: impl IntoIterator>, ) -> Option { - contains_null(values).then_some(NullSubstitutionValue::Null) + values + .into_iter() + .any(|value| matches!(value, Some(NullSubstitutionValue::Null))) + .then_some(NullSubstitutionValue::Null) } fn syntactic_null_substitution_value( @@ -161,11 +155,18 @@ fn syntactic_null_substitution_value( syntactic_null_substitution_value(between.low.as_ref(), join_cols), syntactic_null_substitution_value(between.high.as_ref(), join_cols), ]), - Expr::Cast(cast) => strict_null_passthrough(cast.expr.as_ref(), join_cols), + Expr::Cast(cast) => { + syntactic_null_substitution_value(cast.expr.as_ref(), join_cols) + .filter(|value| matches!(value, NullSubstitutionValue::Null)) + } Expr::TryCast(try_cast) => { - strict_null_passthrough(try_cast.expr.as_ref(), join_cols) + syntactic_null_substitution_value(try_cast.expr.as_ref(), join_cols) + .filter(|value| matches!(value, NullSubstitutionValue::Null)) + } + Expr::Negative(expr) => { + syntactic_null_substitution_value(expr.as_ref(), join_cols) + .filter(|value| matches!(value, NullSubstitutionValue::Null)) } - Expr::Negative(expr) => strict_null_passthrough(expr.as_ref(), join_cols), Expr::Like(like) | Expr::SimilarTo(like) => null_if_contains_null([ syntactic_null_substitution_value(like.expr.as_ref(), join_cols), syntactic_null_substitution_value(like.pattern.as_ref(), join_cols), @@ -204,15 +205,49 @@ fn scalar_to_null_substitution_value(value: &ScalarValue) -> NullSubstitutionVal } } -fn strict_null_passthrough( - expr: &Expr, - join_cols: &HashSet<&Column>, -) -> Option { +fn is_strict_null_binary_op(op: Operator) -> bool { matches!( - syntactic_null_substitution_value(expr, join_cols), - 
Some(NullSubstitutionValue::Null) + op, + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq + | Operator::Plus + | Operator::Minus + | Operator::Multiply + | Operator::Divide + | Operator::Modulo + | Operator::RegexMatch + | Operator::RegexIMatch + | Operator::RegexNotMatch + | Operator::RegexNotIMatch + | Operator::LikeMatch + | Operator::ILikeMatch + | Operator::NotLikeMatch + | Operator::NotILikeMatch + | Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseXor + | Operator::BitwiseShiftRight + | Operator::BitwiseShiftLeft + | Operator::StringConcat + | Operator::AtArrow + | Operator::ArrowAt + | Operator::Arrow + | Operator::LongArrow + | Operator::HashArrow + | Operator::HashLongArrow + | Operator::AtAt + | Operator::IntegerDivide + | Operator::HashMinus + | Operator::AtQuestion + | Operator::Question + | Operator::QuestionAnd + | Operator::QuestionPipe + | Operator::Colon ) - .then_some(NullSubstitutionValue::Null) } fn syntactic_binary_value( @@ -225,45 +260,9 @@ fn syntactic_binary_value( match binary_expr.op { Operator::And => binary_boolean_value(left, right, false), Operator::Or => binary_boolean_value(left, right, true), - Operator::Eq - | Operator::NotEq - | Operator::Lt - | Operator::LtEq - | Operator::Gt - | Operator::GtEq - | Operator::Plus - | Operator::Minus - | Operator::Multiply - | Operator::Divide - | Operator::Modulo - | Operator::RegexMatch - | Operator::RegexIMatch - | Operator::RegexNotMatch - | Operator::RegexNotIMatch - | Operator::LikeMatch - | Operator::ILikeMatch - | Operator::NotLikeMatch - | Operator::NotILikeMatch - | Operator::BitwiseAnd - | Operator::BitwiseOr - | Operator::BitwiseXor - | Operator::BitwiseShiftRight - | Operator::BitwiseShiftLeft - | Operator::StringConcat - | Operator::AtArrow - | Operator::ArrowAt - | Operator::Arrow - | Operator::LongArrow - | Operator::HashArrow - | Operator::HashLongArrow - | Operator::AtAt - | Operator::IntegerDivide - | 
Operator::HashMinus - | Operator::AtQuestion - | Operator::Question - | Operator::QuestionAnd - | Operator::QuestionPipe - | Operator::Colon => null_if_contains_null([left, right]), Operator::IsDistinctFrom | Operator::IsNotDistinctFrom => None, + op => is_strict_null_binary_op(op) + .then(|| null_if_contains_null([left, right])) + .flatten(), } } From 4c143adfa1f3617608f9e76b148bad37aa1654a2 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 18:57:51 +0800 Subject: [PATCH 47/63] Separate semantics for evaluates_to_null and authoritative_restrict_null_predicate Restore evaluates_to_null behavior for general expressions without boolean-downcasting arrays. Fix scalar_subquery_with_non_strong_project regression. Update authoritative_restrict_null_predicate to handle predicate results directly, treating scalar NULL as null-restricting, resolving the CASE ... ELSE NULL test failure. Maintain non-join-column fast path functionality. --- datafusion/optimizer/src/utils.rs | 61 ++++++++++++------------------- 1 file changed, 23 insertions(+), 38 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 2ef25882a2c0d..dd9b5034f457d 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -181,10 +181,12 @@ pub fn evaluates_to_null<'a>( return Ok(true); } - Ok(authoritative_null_result(evaluate_expr_with_null_column( - predicate, - null_columns, - )?)? == AuthoritativeNullResult::AlwaysNull) + Ok( + match evaluate_expr_with_null_column(predicate, null_columns)? { + ColumnarValue::Array(_) => false, + ColumnarValue::Scalar(scalar) => scalar.is_null(), + }, + ) } fn evaluate_expr_with_null_column<'a>( @@ -218,41 +220,24 @@ fn authoritative_restrict_null_predicate<'a>( predicate: Expr, join_cols_of_predicate: impl IntoIterator, ) -> Result { - Ok(authoritative_null_result(evaluate_expr_with_null_column( - predicate, - join_cols_of_predicate, - )?)? 
== AuthoritativeNullResult::NullRestricting) -} - -#[derive(Debug, PartialEq, Eq)] -enum AuthoritativeNullResult { - AlwaysNull, - NullRestricting, - Other, -} - -fn authoritative_null_result(value: ColumnarValue) -> Result { - Ok(match value { - ColumnarValue::Array(array) => { - if array.len() != 1 { - return Ok(AuthoritativeNullResult::Other); + Ok( + match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? { + ColumnarValue::Array(array) => { + if array.len() == 1 { + let boolean_array = as_boolean_array(&array)?; + boolean_array.is_null(0) || !boolean_array.value(0) + } else { + false + } } - - let boolean_array = as_boolean_array(&array)?; - if boolean_array.is_null(0) || !boolean_array.value(0) { - AuthoritativeNullResult::NullRestricting - } else { - AuthoritativeNullResult::Other - } - } - ColumnarValue::Scalar(scalar) if scalar.is_null() => { - AuthoritativeNullResult::AlwaysNull - } - ColumnarValue::Scalar( - ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false)), - ) => AuthoritativeNullResult::NullRestricting, - ColumnarValue::Scalar(_) => AuthoritativeNullResult::Other, - }) + ColumnarValue::Scalar(scalar) => matches!( + scalar, + ScalarValue::Boolean(None) + | ScalarValue::Boolean(Some(false)) + | ScalarValue::Null + ), + }, + ) } fn coerce(expr: Expr, schema: &DFSchema) -> Result { From 18b754174381ad775945a3f18624cf72b5bc339f Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 19:17:01 +0800 Subject: [PATCH 48/63] cargo fmt --- datafusion/optimizer/src/push_down_filter.rs | 4 ++-- datafusion/optimizer/src/utils.rs | 5 +---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index c74fe5ad0458b..a9053e78e7751 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -2517,8 +2517,8 @@ mod tests { } #[test] - fn 
window_over_scalar_subquery_cross_join_with_project_wrapper_keeps_filter_above_join( - ) -> Result<()> { + fn window_over_scalar_subquery_cross_join_with_project_wrapper_keeps_filter_above_join() + -> Result<()> { let left = LogicalPlanBuilder::from(test_table_scan()?) .project(vec![col("a").alias("nation"), col("b").alias("acctbal")])? .alias("s")? diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index dd9b5034f457d..329271a067ee8 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -530,9 +530,6 @@ mod tests { })); assert!(result.is_err()); - assert_eq!( - null_restriction_eval_mode(), - NullRestrictionEvalMode::Auto - ); + assert_eq!(null_restriction_eval_mode(), NullRestrictionEvalMode::Auto); } } From 106e96346fa0bdfeac1f401794b1c462198f2abc Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 19:33:20 +0800 Subject: [PATCH 49/63] Refactor join input handling and null evaluation Consolidate repeated join-input wrapper inspections into a single JoinInputShape classifier. Hoist scalar-subquery cross-join shape check out of the predicate loop and unify repeated left/right predicate bucketing. Remove temporary Vec in join-column replacement inference and narrow test-only null-restriction mode support into its own helper module. Share column-subset check path and extract helper for authoritative null-evaluation results. Reduce repetition in syntactic null-restriction evaluator by factoring strict-null-preserving unary cases. 
--- datafusion/optimizer/src/push_down_filter.rs | 161 +++++++++++------- datafusion/optimizer/src/utils.rs | 148 ++++++++-------- .../optimizer/src/utils/null_restriction.rs | 40 ++--- 3 files changed, 198 insertions(+), 151 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index a9053e78e7751..df9d02ff52619 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -285,55 +285,83 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Ok(is_evaluate) } -fn strip_plan_wrappers(plan: &LogicalPlan) -> (&LogicalPlan, bool) { +#[derive(Clone, Copy)] +struct JoinInputShape<'a> { + base_plan: &'a LogicalPlan, + is_derived_relation: bool, +} + +fn classify_join_input(plan: &LogicalPlan) -> JoinInputShape<'_> { match plan { LogicalPlan::SubqueryAlias(subquery_alias) => { - let (plan, _) = strip_plan_wrappers(subquery_alias.input.as_ref()); - (plan, true) + let JoinInputShape { base_plan, .. 
} = + classify_join_input(subquery_alias.input.as_ref()); + JoinInputShape { + base_plan, + is_derived_relation: true, + } } LogicalPlan::Projection(projection) => { - let (plan, is_derived_relation) = - strip_plan_wrappers(projection.input.as_ref()); - (plan, is_derived_relation) + let shape = classify_join_input(projection.input.as_ref()); + JoinInputShape { + is_derived_relation: shape.is_derived_relation, + ..shape + } } - _ => (plan, false), + _ => JoinInputShape { + base_plan: plan, + is_derived_relation: false, + }, } } -fn is_scalar_aggregate_subquery(plan: &LogicalPlan) -> bool { +fn is_scalar_aggregate_subquery(shape: JoinInputShape<'_>) -> bool { matches!( - strip_plan_wrappers(plan).0, + shape.base_plan, LogicalPlan::Aggregate(aggregate) if aggregate.group_expr.is_empty() ) } -fn is_derived_relation(plan: &LogicalPlan) -> bool { - strip_plan_wrappers(plan).1 -} - fn is_scalar_subquery_cross_join(join: &Join) -> bool { + let left_shape = classify_join_input(join.left.as_ref()); + let right_shape = classify_join_input(join.right.as_ref()); join.on.is_empty() && join.filter.is_none() - && ((is_scalar_aggregate_subquery(join.left.as_ref()) - && is_derived_relation(join.right.as_ref())) - || (is_scalar_aggregate_subquery(join.right.as_ref()) - && is_derived_relation(join.left.as_ref()))) + && ((is_scalar_aggregate_subquery(left_shape) && right_shape.is_derived_relation) + || (is_scalar_aggregate_subquery(right_shape) + && left_shape.is_derived_relation)) } // Keep post-join filters above certain scalar-subquery cross joins to preserve // behavior for the window-over-scalar-subquery regression shape. 
fn should_keep_filter_above_scalar_subquery_cross_join( - join: &Join, + mut checker: ColumnChecker<'_>, predicate: &Expr, ) -> bool { - if !is_scalar_subquery_cross_join(join) { - return false; - } - - let mut checker = ColumnChecker::new(join.left.schema(), join.right.schema()); !checker.is_left_only(predicate) && !checker.is_right_only(predicate) } +enum PredicateDestination { + Left, + Right, + Keep, +} + +fn classify_predicate_destination( + checker: &mut ColumnChecker<'_>, + predicate: &Expr, + allow_left: bool, + allow_right: bool, +) -> PredicateDestination { + if allow_left && checker.is_left_only(predicate) { + PredicateDestination::Left + } else if allow_right && checker.is_right_only(predicate) { + PredicateDestination::Right + } else { + PredicateDestination::Keep + } +} + /// examine OR clause to see if any useful clauses can be extracted and push down. /// extract at least one qual from each sub clauses of OR clause, then form the quals /// to new OR clause as predicate. @@ -475,29 +503,44 @@ fn push_down_all_join( let mut keep_predicates = vec![]; let mut join_conditions = vec![]; let mut checker = ColumnChecker::new(left_schema, right_schema); + let keep_mixed_scalar_subquery_filters = + is_inner_join && is_scalar_subquery_cross_join(&join); for predicate in predicates { - if left_preserved && checker.is_left_only(&predicate) { - left_push.push(predicate); - } else if right_preserved && checker.is_right_only(&predicate) { - right_push.push(predicate); - } else if is_inner_join - && !should_keep_filter_above_scalar_subquery_cross_join(&join, &predicate) - && can_evaluate_as_join_condition(&predicate)? 
- { - // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate - // and convert to the join on condition - join_conditions.push(predicate); - } else { - keep_predicates.push(predicate); + match classify_predicate_destination( + &mut checker, + &predicate, + left_preserved, + right_preserved, + ) { + PredicateDestination::Left => left_push.push(predicate), + PredicateDestination::Right => right_push.push(predicate), + PredicateDestination::Keep => { + let should_keep_above_join = keep_mixed_scalar_subquery_filters + && should_keep_filter_above_scalar_subquery_cross_join( + ColumnChecker::new(left_schema, right_schema), + &predicate, + ); + + if is_inner_join + && !should_keep_above_join + && can_evaluate_as_join_condition(&predicate)? + { + // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate + // and convert to the join on condition + join_conditions.push(predicate); + } else { + keep_predicates.push(predicate); + } + } } } // Push predicates inferred from the join expression for predicate in inferred_join_predicates { - if checker.is_left_only(&predicate) { - left_push.push(predicate); - } else if checker.is_right_only(&predicate) { - right_push.push(predicate); + match classify_predicate_destination(&mut checker, &predicate, true, true) { + PredicateDestination::Left => left_push.push(predicate), + PredicateDestination::Right => right_push.push(predicate), + PredicateDestination::Keep => {} } } @@ -506,12 +549,15 @@ fn push_down_all_join( if !on_filter.is_empty() { for on in on_filter { - if on_left_preserved && checker.is_left_only(&on) { - left_push.push(on) - } else if on_right_preserved && checker.is_right_only(&on) { - right_push.push(on) - } else { - on_filter_join_conditions.push(on) + match classify_predicate_destination( + &mut checker, + &on, + on_left_preserved, + on_right_preserved, + ) { + PredicateDestination::Left => left_push.push(on), + 
PredicateDestination::Right => right_push.push(on), + PredicateDestination::Keep => on_filter_join_conditions.push(on), } } } @@ -776,10 +822,11 @@ fn infer_join_predicates_impl< ) -> Result<()> { for predicate in input_predicates { let column_refs = predicate.column_refs(); - let join_col_replacements: Vec<_> = column_refs + let mut saw_non_replaceable_ref = false; + let join_cols_to_replace = column_refs .iter() .filter_map(|&col| { - join_col_keys.iter().find_map(|(l, r)| { + let replacement = join_col_keys.iter().find_map(|(l, r)| { if ENABLE_LEFT_TO_RIGHT && col == *l { Some((col, *r)) } else if ENABLE_RIGHT_TO_LEFT && col == *r { @@ -787,24 +834,18 @@ fn infer_join_predicates_impl< } else { None } - }) + }); + saw_non_replaceable_ref |= replacement.is_none(); + replacement }) - .collect(); - - if join_col_replacements.is_empty() { - continue; - } + .collect::>(); - // For non-inner joins, predicates that reference any non-replaceable - // columns cannot be inferred on the other side. Skip the null-restriction - // helper entirely in that common mixed-reference case. - if !inferred_predicates.is_inner_join - && join_col_replacements.len() != column_refs.len() + if join_cols_to_replace.is_empty() + || (!inferred_predicates.is_inner_join && saw_non_replaceable_ref) { continue; } - let join_cols_to_replace = join_col_replacements.into_iter().collect(); inferred_predicates .try_build_predicate(predicate.clone(), &join_cols_to_replace)?; } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 329271a067ee8..0dafe6d342d53 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -22,53 +22,7 @@ mod null_restriction; use std::collections::{BTreeSet, HashMap, HashSet}; use std::sync::Arc; -#[cfg(test)] -use std::cell::Cell; - use crate::analyzer::type_coercion::TypeCoercionRewriter; - -/// Null restriction evaluation mode for optimizer tests. 
-#[cfg(test)] -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub(crate) enum NullRestrictionEvalMode { - Auto, - AuthoritativeOnly, -} - -#[cfg(test)] -thread_local! { - static NULL_RESTRICTION_EVAL_MODE: Cell = - const { Cell::new(NullRestrictionEvalMode::Auto) }; -} - -#[cfg(test)] -pub(crate) fn set_null_restriction_eval_mode_for_test(mode: NullRestrictionEvalMode) { - NULL_RESTRICTION_EVAL_MODE.with(|eval_mode| eval_mode.set(mode)); -} - -#[cfg(test)] -fn null_restriction_eval_mode() -> NullRestrictionEvalMode { - NULL_RESTRICTION_EVAL_MODE.with(Cell::get) -} - -#[cfg(test)] -pub(crate) fn with_null_restriction_eval_mode_for_test( - mode: NullRestrictionEvalMode, - f: impl FnOnce() -> T, -) -> T { - struct NullRestrictionEvalModeReset(NullRestrictionEvalMode); - - impl Drop for NullRestrictionEvalModeReset { - fn drop(&mut self) { - set_null_restriction_eval_mode_for_test(self.0); - } - } - - let previous_mode = null_restriction_eval_mode(); - set_null_restriction_eval_mode_for_test(mode); - let _reset = NullRestrictionEvalModeReset(previous_mode); - f() -} use arrow::array::{Array, RecordBatch, new_null_array}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::cast::as_boolean_array; @@ -84,15 +38,15 @@ use log::{debug, trace}; /// as it was initially placed here and then moved elsewhere. 
pub use datafusion_expr::expr_rewriter::NamePreserver; +#[cfg(test)] +use self::test_eval_mode::{ + NullRestrictionEvalMode, null_restriction_eval_mode, + set_null_restriction_eval_mode_for_test, with_null_restriction_eval_mode_for_test, +}; + /// Returns true if `expr` contains all columns in `schema_cols` pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet) -> bool { - let column_refs = expr.column_refs(); - // note can't use HashSet::intersect because of different types (owned vs References) - schema_cols - .iter() - .filter(|c| column_refs.contains(c)) - .count() - == column_refs.len() + column_refs_all_in(&expr.column_refs(), |column| schema_cols.contains(column)) } pub(crate) fn replace_qualified_name( @@ -116,6 +70,13 @@ pub fn log_plan(description: &str, plan: &LogicalPlan) { trace!("{description}::\n{}\n", plan.display_indent_schema()); } +pub(super) fn column_refs_all_in<'a>( + column_refs: &HashSet<&'a Column>, + mut contains: impl FnMut(&Column) -> bool, +) -> bool { + column_refs.iter().all(|column| contains(column)) +} + /// Determine whether a predicate can restrict NULLs. e.g. /// `c0 > 8` return true; /// `c0 IS NULL` return false. @@ -137,7 +98,7 @@ pub fn is_restrict_null_predicate<'a>( // contains a placeholder for the join key columns. Callers treat such errors as // non-restricting (false) via `matches!(_, Ok(true))`, so we return false early // and avoid the expensive physical-expression compilation pipeline entirely. - if !null_restriction::all_columns_allowed(&column_refs, &join_cols) { + if !column_refs_all_in(&column_refs, |column| join_cols.contains(&column)) { return Ok(false); } @@ -220,24 +181,8 @@ fn authoritative_restrict_null_predicate<'a>( predicate: Expr, join_cols_of_predicate: impl IntoIterator, ) -> Result { - Ok( - match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? 
{ - ColumnarValue::Array(array) => { - if array.len() == 1 { - let boolean_array = as_boolean_array(&array)?; - boolean_array.is_null(0) || !boolean_array.value(0) - } else { - false - } - } - ColumnarValue::Scalar(scalar) => matches!( - scalar, - ScalarValue::Boolean(None) - | ScalarValue::Boolean(Some(false)) - | ScalarValue::Null - ), - }, - ) + evaluate_expr_with_null_column(predicate, join_cols_of_predicate) + .and_then(is_false_or_null_boolean_result) } fn coerce(expr: Expr, schema: &DFSchema) -> Result { @@ -245,6 +190,65 @@ fn coerce(expr: Expr, schema: &DFSchema) -> Result { expr.rewrite(&mut expr_rewrite).data() } +fn is_false_or_null_boolean_result(result: ColumnarValue) -> Result { + Ok(match result { + ColumnarValue::Array(array) if array.len() == 1 => { + let boolean_array = as_boolean_array(&array)?; + boolean_array.is_null(0) || !boolean_array.value(0) + } + ColumnarValue::Array(_) => false, + ColumnarValue::Scalar(scalar) => matches!( + scalar, + ScalarValue::Boolean(None) + | ScalarValue::Boolean(Some(false)) + | ScalarValue::Null + ), + }) +} + +#[cfg(test)] +mod test_eval_mode { + use std::cell::Cell; + + /// Null restriction evaluation mode for optimizer tests. + #[derive(Copy, Clone, Debug, PartialEq, Eq)] + pub(crate) enum NullRestrictionEvalMode { + Auto, + AuthoritativeOnly, + } + + thread_local! 
{ + static NULL_RESTRICTION_EVAL_MODE: Cell = + const { Cell::new(NullRestrictionEvalMode::Auto) }; + } + + pub(crate) fn set_null_restriction_eval_mode_for_test(mode: NullRestrictionEvalMode) { + NULL_RESTRICTION_EVAL_MODE.with(|eval_mode| eval_mode.set(mode)); + } + + pub(crate) fn null_restriction_eval_mode() -> NullRestrictionEvalMode { + NULL_RESTRICTION_EVAL_MODE.with(Cell::get) + } + + pub(crate) fn with_null_restriction_eval_mode_for_test( + mode: NullRestrictionEvalMode, + f: impl FnOnce() -> T, + ) -> T { + struct NullRestrictionEvalModeReset(NullRestrictionEvalMode); + + impl Drop for NullRestrictionEvalModeReset { + fn drop(&mut self) { + set_null_restriction_eval_mode_for_test(self.0); + } + } + + let previous_mode = null_restriction_eval_mode(); + set_null_restriction_eval_mode_for_test(mode); + let _reset = NullRestrictionEvalModeReset(previous_mode); + f() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/optimizer/src/utils/null_restriction.rs b/datafusion/optimizer/src/utils/null_restriction.rs index 3fc2dbab3f8d0..2e13bd7d9fd0c 100644 --- a/datafusion/optimizer/src/utils/null_restriction.rs +++ b/datafusion/optimizer/src/utils/null_restriction.rs @@ -32,6 +32,12 @@ enum NullSubstitutionValue { Boolean(bool), } +impl NullSubstitutionValue { + fn is_null(self) -> bool { + matches!(self, Self::Null) + } +} + pub(super) fn syntactic_restrict_null_predicate( predicate: &Expr, join_cols: &HashSet<&Column>, @@ -45,15 +51,6 @@ pub(super) fn syntactic_restrict_null_predicate( } } -pub(super) fn all_columns_allowed( - column_refs: &HashSet<&Column>, - allowed_columns: &HashSet<&Column>, -) -> bool { - column_refs - .iter() - .all(|column| allowed_columns.contains(*column)) -} - fn not(value: Option) -> Option { match value { Some(NullSubstitutionValue::Boolean(value)) => { @@ -117,6 +114,12 @@ fn null_if_contains_null( .then_some(NullSubstitutionValue::Null) } +fn strict_null_only( + value: Option, +) -> Option { + 
value.filter(|value| value.is_null()) +} + fn syntactic_null_substitution_value( expr: &Expr, join_cols: &HashSet<&Column>, @@ -155,17 +158,16 @@ fn syntactic_null_substitution_value( syntactic_null_substitution_value(between.low.as_ref(), join_cols), syntactic_null_substitution_value(between.high.as_ref(), join_cols), ]), - Expr::Cast(cast) => { - syntactic_null_substitution_value(cast.expr.as_ref(), join_cols) - .filter(|value| matches!(value, NullSubstitutionValue::Null)) - } - Expr::TryCast(try_cast) => { - syntactic_null_substitution_value(try_cast.expr.as_ref(), join_cols) - .filter(|value| matches!(value, NullSubstitutionValue::Null)) - } + Expr::Cast(cast) => strict_null_only(syntactic_null_substitution_value( + cast.expr.as_ref(), + join_cols, + )), + Expr::TryCast(try_cast) => strict_null_only(syntactic_null_substitution_value( + try_cast.expr.as_ref(), + join_cols, + )), Expr::Negative(expr) => { - syntactic_null_substitution_value(expr.as_ref(), join_cols) - .filter(|value| matches!(value, NullSubstitutionValue::Null)) + strict_null_only(syntactic_null_substitution_value(expr.as_ref(), join_cols)) } Expr::Like(like) | Expr::SimilarTo(like) => null_if_contains_null([ syntactic_null_substitution_value(like.expr.as_ref(), join_cols), From 5effa4c9d466ddd7e82f6018a44fa8a665cfc15c Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 19:49:55 +0800 Subject: [PATCH 50/63] Refactor push_down_filter for simplicity and efficiency Eliminate JoinInputShape/PredicateDestination layers and integrate scalar-subquery handling into the main join-pushdown flow. Revert join-inference replacement logic to an explicit loop. Remove unnecessary helper functions and simplify the null-restriction evaluator for cleaner code. 
--- datafusion/optimizer/src/push_down_filter.rs | 173 ++++++------------ datafusion/optimizer/src/utils.rs | 46 ++--- .../optimizer/src/utils/null_restriction.rs | 14 +- 3 files changed, 76 insertions(+), 157 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index df9d02ff52619..f7ccc7dea951c 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -285,81 +285,30 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Ok(is_evaluate) } -#[derive(Clone, Copy)] -struct JoinInputShape<'a> { - base_plan: &'a LogicalPlan, - is_derived_relation: bool, -} - -fn classify_join_input(plan: &LogicalPlan) -> JoinInputShape<'_> { +fn classify_join_input(plan: &LogicalPlan) -> (bool, bool) { match plan { LogicalPlan::SubqueryAlias(subquery_alias) => { - let JoinInputShape { base_plan, .. } = + let (is_scalar_aggregate, _) = classify_join_input(subquery_alias.input.as_ref()); - JoinInputShape { - base_plan, - is_derived_relation: true, - } + (is_scalar_aggregate, true) } LogicalPlan::Projection(projection) => { - let shape = classify_join_input(projection.input.as_ref()); - JoinInputShape { - is_derived_relation: shape.is_derived_relation, - ..shape - } + classify_join_input(projection.input.as_ref()) } - _ => JoinInputShape { - base_plan: plan, - is_derived_relation: false, - }, + LogicalPlan::Aggregate(aggregate) => (aggregate.group_expr.is_empty(), false), + _ => (false, false), } } -fn is_scalar_aggregate_subquery(shape: JoinInputShape<'_>) -> bool { - matches!( - shape.base_plan, - LogicalPlan::Aggregate(aggregate) if aggregate.group_expr.is_empty() - ) -} - fn is_scalar_subquery_cross_join(join: &Join) -> bool { - let left_shape = classify_join_input(join.left.as_ref()); - let right_shape = classify_join_input(join.right.as_ref()); + let (left_scalar_aggregate, left_is_derived_relation) = + classify_join_input(join.left.as_ref()); + let 
(right_scalar_aggregate, right_is_derived_relation) = + classify_join_input(join.right.as_ref()); join.on.is_empty() && join.filter.is_none() - && ((is_scalar_aggregate_subquery(left_shape) && right_shape.is_derived_relation) - || (is_scalar_aggregate_subquery(right_shape) - && left_shape.is_derived_relation)) -} - -// Keep post-join filters above certain scalar-subquery cross joins to preserve -// behavior for the window-over-scalar-subquery regression shape. -fn should_keep_filter_above_scalar_subquery_cross_join( - mut checker: ColumnChecker<'_>, - predicate: &Expr, -) -> bool { - !checker.is_left_only(predicate) && !checker.is_right_only(predicate) -} - -enum PredicateDestination { - Left, - Right, - Keep, -} - -fn classify_predicate_destination( - checker: &mut ColumnChecker<'_>, - predicate: &Expr, - allow_left: bool, - allow_right: bool, -) -> PredicateDestination { - if allow_left && checker.is_left_only(predicate) { - PredicateDestination::Left - } else if allow_right && checker.is_right_only(predicate) { - PredicateDestination::Right - } else { - PredicateDestination::Keep - } + && ((left_scalar_aggregate && right_is_derived_relation) + || (right_scalar_aggregate && left_is_derived_relation)) } /// examine OR clause to see if any useful clauses can be extracted and push down. 
@@ -506,41 +455,28 @@ fn push_down_all_join( let keep_mixed_scalar_subquery_filters = is_inner_join && is_scalar_subquery_cross_join(&join); for predicate in predicates { - match classify_predicate_destination( - &mut checker, - &predicate, - left_preserved, - right_preserved, - ) { - PredicateDestination::Left => left_push.push(predicate), - PredicateDestination::Right => right_push.push(predicate), - PredicateDestination::Keep => { - let should_keep_above_join = keep_mixed_scalar_subquery_filters - && should_keep_filter_above_scalar_subquery_cross_join( - ColumnChecker::new(left_schema, right_schema), - &predicate, - ); - - if is_inner_join - && !should_keep_above_join - && can_evaluate_as_join_condition(&predicate)? - { - // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate - // and convert to the join on condition - join_conditions.push(predicate); - } else { - keep_predicates.push(predicate); - } - } + if left_preserved && checker.is_left_only(&predicate) { + left_push.push(predicate); + } else if right_preserved && checker.is_right_only(&predicate) { + right_push.push(predicate); + } else if is_inner_join + && !keep_mixed_scalar_subquery_filters + && can_evaluate_as_join_condition(&predicate)? 
+ { + // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate + // and convert to the join on condition + join_conditions.push(predicate); + } else { + keep_predicates.push(predicate); } } // Push predicates inferred from the join expression for predicate in inferred_join_predicates { - match classify_predicate_destination(&mut checker, &predicate, true, true) { - PredicateDestination::Left => left_push.push(predicate), - PredicateDestination::Right => right_push.push(predicate), - PredicateDestination::Keep => {} + if checker.is_left_only(&predicate) { + left_push.push(predicate); + } else if checker.is_right_only(&predicate) { + right_push.push(predicate); } } @@ -549,15 +485,12 @@ fn push_down_all_join( if !on_filter.is_empty() { for on in on_filter { - match classify_predicate_destination( - &mut checker, - &on, - on_left_preserved, - on_right_preserved, - ) { - PredicateDestination::Left => left_push.push(on), - PredicateDestination::Right => right_push.push(on), - PredicateDestination::Keep => on_filter_join_conditions.push(on), + if on_left_preserved && checker.is_left_only(&on) { + left_push.push(on) + } else if on_right_preserved && checker.is_right_only(&on) { + right_push.push(on) + } else { + on_filter_join_conditions.push(on) } } } @@ -821,24 +754,26 @@ fn infer_join_predicates_impl< inferred_predicates: &mut InferredPredicates, ) -> Result<()> { for predicate in input_predicates { - let column_refs = predicate.column_refs(); + let mut join_cols_to_replace = HashMap::new(); let mut saw_non_replaceable_ref = false; - let join_cols_to_replace = column_refs - .iter() - .filter_map(|&col| { - let replacement = join_col_keys.iter().find_map(|(l, r)| { - if ENABLE_LEFT_TO_RIGHT && col == *l { - Some((col, *r)) - } else if ENABLE_RIGHT_TO_LEFT && col == *r { - Some((col, *l)) - } else { - None - } - }); - saw_non_replaceable_ref |= replacement.is_none(); - replacement - }) - .collect::>(); + + for &col in 
&predicate.column_refs() { + let replacement = join_col_keys.iter().find_map(|(l, r)| { + if ENABLE_LEFT_TO_RIGHT && col == *l { + Some((col, *r)) + } else if ENABLE_RIGHT_TO_LEFT && col == *r { + Some((col, *l)) + } else { + None + } + }); + + if let Some((source, target)) = replacement { + join_cols_to_replace.insert(source, target); + } else { + saw_non_replaceable_ref = true; + } + } if join_cols_to_replace.is_empty() || (!inferred_predicates.is_inner_join && saw_non_replaceable_ref) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 0dafe6d342d53..ae8ec74606ff2 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -46,7 +46,9 @@ use self::test_eval_mode::{ /// Returns true if `expr` contains all columns in `schema_cols` pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet) -> bool { - column_refs_all_in(&expr.column_refs(), |column| schema_cols.contains(column)) + expr.column_refs() + .iter() + .all(|column| schema_cols.contains(*column)) } pub(crate) fn replace_qualified_name( @@ -70,13 +72,6 @@ pub fn log_plan(description: &str, plan: &LogicalPlan) { trace!("{description}::\n{}\n", plan.display_indent_schema()); } -pub(super) fn column_refs_all_in<'a>( - column_refs: &HashSet<&'a Column>, - mut contains: impl FnMut(&Column) -> bool, -) -> bool { - column_refs.iter().all(|column| contains(column)) -} - /// Determine whether a predicate can restrict NULLs. e.g. /// `c0 > 8` return true; /// `c0 IS NULL` return false. @@ -98,7 +93,7 @@ pub fn is_restrict_null_predicate<'a>( // contains a placeholder for the join key columns. Callers treat such errors as // non-restricting (false) via `matches!(_, Ok(true))`, so we return false early // and avoid the expensive physical-expression compilation pipeline entirely. 
- if !column_refs_all_in(&column_refs, |column| join_cols.contains(&column)) { + if !column_refs.iter().all(|column| join_cols.contains(*column)) { return Ok(false); } @@ -181,8 +176,21 @@ fn authoritative_restrict_null_predicate<'a>( predicate: Expr, join_cols_of_predicate: impl IntoIterator, ) -> Result { - evaluate_expr_with_null_column(predicate, join_cols_of_predicate) - .and_then(is_false_or_null_boolean_result) + Ok( + match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? { + ColumnarValue::Array(array) if array.len() == 1 => { + let boolean_array = as_boolean_array(&array)?; + boolean_array.is_null(0) || !boolean_array.value(0) + } + ColumnarValue::Array(_) => false, + ColumnarValue::Scalar(scalar) => matches!( + scalar, + ScalarValue::Boolean(None) + | ScalarValue::Boolean(Some(false)) + | ScalarValue::Null + ), + }, + ) } fn coerce(expr: Expr, schema: &DFSchema) -> Result { @@ -190,22 +198,6 @@ fn coerce(expr: Expr, schema: &DFSchema) -> Result { expr.rewrite(&mut expr_rewrite).data() } -fn is_false_or_null_boolean_result(result: ColumnarValue) -> Result { - Ok(match result { - ColumnarValue::Array(array) if array.len() == 1 => { - let boolean_array = as_boolean_array(&array)?; - boolean_array.is_null(0) || !boolean_array.value(0) - } - ColumnarValue::Array(_) => false, - ColumnarValue::Scalar(scalar) => matches!( - scalar, - ScalarValue::Boolean(None) - | ScalarValue::Boolean(Some(false)) - | ScalarValue::Null - ), - }) -} - #[cfg(test)] mod test_eval_mode { use std::cell::Cell; diff --git a/datafusion/optimizer/src/utils/null_restriction.rs b/datafusion/optimizer/src/utils/null_restriction.rs index 2e13bd7d9fd0c..5f6b9eef794ea 100644 --- a/datafusion/optimizer/src/utils/null_restriction.rs +++ b/datafusion/optimizer/src/utils/null_restriction.rs @@ -32,21 +32,13 @@ enum NullSubstitutionValue { Boolean(bool), } -impl NullSubstitutionValue { - fn is_null(self) -> bool { - matches!(self, Self::Null) - } -} - pub(super) fn 
syntactic_restrict_null_predicate( predicate: &Expr, join_cols: &HashSet<&Column>, ) -> Option { match syntactic_null_substitution_value(predicate, join_cols) { - Some(NullSubstitutionValue::Boolean(true)) => Some(false), - Some(NullSubstitutionValue::Boolean(false) | NullSubstitutionValue::Null) => { - Some(true) - } + Some(NullSubstitutionValue::Boolean(value)) => Some(!value), + Some(NullSubstitutionValue::Null) => Some(true), Some(NullSubstitutionValue::NonNull) | None => None, } } @@ -117,7 +109,7 @@ fn null_if_contains_null( fn strict_null_only( value: Option, ) -> Option { - value.filter(|value| value.is_null()) + value.filter(|value| matches!(value, NullSubstitutionValue::Null)) } fn syntactic_null_substitution_value( From 0563f6c493f8fe34f4b0518a1272255f41ecb2ce Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 19:57:11 +0800 Subject: [PATCH 51/63] Refactor utility layers and improve code compactness Flatten single-use helper layers and tighten small utility helpers. In push_down_filter, fold the join-input classifier and bind predicate.column_refs() once for join inference. Shorten test-only import surface in utils.rs and compact mapping helpers in null_restriction.rs without changing semantics. 
--- datafusion/optimizer/src/push_down_filter.rs | 31 +++++++++---------- datafusion/optimizer/src/utils.rs | 5 +-- .../optimizer/src/utils/null_restriction.rs | 18 +++++------ 3 files changed, 24 insertions(+), 30 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index f7ccc7dea951c..871dbeb215dd4 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -285,26 +285,22 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Ok(is_evaluate) } -fn classify_join_input(plan: &LogicalPlan) -> (bool, bool) { - match plan { - LogicalPlan::SubqueryAlias(subquery_alias) => { - let (is_scalar_aggregate, _) = - classify_join_input(subquery_alias.input.as_ref()); - (is_scalar_aggregate, true) - } - LogicalPlan::Projection(projection) => { - classify_join_input(projection.input.as_ref()) +fn is_scalar_subquery_cross_join(join: &Join) -> bool { + fn classify(plan: &LogicalPlan) -> (bool, bool) { + match plan { + LogicalPlan::SubqueryAlias(subquery_alias) => { + let (is_scalar_aggregate, _) = classify(subquery_alias.input.as_ref()); + (is_scalar_aggregate, true) + } + LogicalPlan::Projection(projection) => classify(projection.input.as_ref()), + LogicalPlan::Aggregate(aggregate) => (aggregate.group_expr.is_empty(), false), + _ => (false, false), } - LogicalPlan::Aggregate(aggregate) => (aggregate.group_expr.is_empty(), false), - _ => (false, false), } -} -fn is_scalar_subquery_cross_join(join: &Join) -> bool { - let (left_scalar_aggregate, left_is_derived_relation) = - classify_join_input(join.left.as_ref()); + let (left_scalar_aggregate, left_is_derived_relation) = classify(join.left.as_ref()); let (right_scalar_aggregate, right_is_derived_relation) = - classify_join_input(join.right.as_ref()); + classify(join.right.as_ref()); join.on.is_empty() && join.filter.is_none() && ((left_scalar_aggregate && right_is_derived_relation) @@ -754,10 +750,11 @@ 
fn infer_join_predicates_impl< inferred_predicates: &mut InferredPredicates, ) -> Result<()> { for predicate in input_predicates { + let column_refs = predicate.column_refs(); let mut join_cols_to_replace = HashMap::new(); let mut saw_non_replaceable_ref = false; - for &col in &predicate.column_refs() { + for &col in &column_refs { let replacement = join_col_keys.iter().find_map(|(l, r)| { if ENABLE_LEFT_TO_RIGHT && col == *l { Some((col, *r)) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index ae8ec74606ff2..1741f6beae025 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -39,10 +39,7 @@ use log::{debug, trace}; pub use datafusion_expr::expr_rewriter::NamePreserver; #[cfg(test)] -use self::test_eval_mode::{ - NullRestrictionEvalMode, null_restriction_eval_mode, - set_null_restriction_eval_mode_for_test, with_null_restriction_eval_mode_for_test, -}; +use self::test_eval_mode::*; /// Returns true if `expr` contains all columns in `schema_cols` pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet) -> bool { diff --git a/datafusion/optimizer/src/utils/null_restriction.rs b/datafusion/optimizer/src/utils/null_restriction.rs index 5f6b9eef794ea..4c5eecd560edd 100644 --- a/datafusion/optimizer/src/utils/null_restriction.rs +++ b/datafusion/optimizer/src/utils/null_restriction.rs @@ -36,11 +36,13 @@ pub(super) fn syntactic_restrict_null_predicate( predicate: &Expr, join_cols: &HashSet<&Column>, ) -> Option { - match syntactic_null_substitution_value(predicate, join_cols) { - Some(NullSubstitutionValue::Boolean(value)) => Some(!value), - Some(NullSubstitutionValue::Null) => Some(true), - Some(NullSubstitutionValue::NonNull) | None => None, - } + syntactic_null_substitution_value(predicate, join_cols).and_then( + |value| match value { + NullSubstitutionValue::Boolean(value) => Some(!value), + NullSubstitutionValue::Null => Some(true), + NullSubstitutionValue::NonNull => None, + }, + 
) } fn not(value: Option) -> Option { @@ -49,7 +51,7 @@ fn not(value: Option) -> Option { Some(NullSubstitutionValue::Boolean(!value)) } Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Null), - Some(NullSubstitutionValue::NonNull) | None => None, + _ => None, } } @@ -90,9 +92,7 @@ fn null_check_value( Some(NullSubstitutionValue::Null) => { Some(NullSubstitutionValue::Boolean(!is_not_null)) } - Some(NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_)) => { - Some(NullSubstitutionValue::Boolean(is_not_null)) - } + Some(_) => Some(NullSubstitutionValue::Boolean(is_not_null)), None => None, } } From 8e1a2f2e2e6be9f8b722a79222d1a20ce1fd3260 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 20:04:44 +0800 Subject: [PATCH 52/63] Refactor utility layers and simplify null handling Tighten scalar-subquery guard, remove redundant empty-loop scaffolding, and share more of the null-evaluation flow. Simplify small null-restriction helpers and match arms. 
--- datafusion/optimizer/src/push_down_filter.rs | 100 +++++++++++------- datafusion/optimizer/src/utils.rs | 44 +++++--- .../optimizer/src/utils/null_restriction.rs | 48 ++++----- 3 files changed, 112 insertions(+), 80 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 871dbeb215dd4..0b1495045c631 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -286,21 +286,26 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { } fn is_scalar_subquery_cross_join(join: &Join) -> bool { - fn classify(plan: &LogicalPlan) -> (bool, bool) { + fn is_scalar_aggregate_or_derived_relation(plan: &LogicalPlan) -> (bool, bool) { match plan { LogicalPlan::SubqueryAlias(subquery_alias) => { - let (is_scalar_aggregate, _) = classify(subquery_alias.input.as_ref()); + let (is_scalar_aggregate, _) = is_scalar_aggregate_or_derived_relation( + subquery_alias.input.as_ref(), + ); (is_scalar_aggregate, true) } - LogicalPlan::Projection(projection) => classify(projection.input.as_ref()), + LogicalPlan::Projection(projection) => { + is_scalar_aggregate_or_derived_relation(projection.input.as_ref()) + } LogicalPlan::Aggregate(aggregate) => (aggregate.group_expr.is_empty(), false), _ => (false, false), } } - let (left_scalar_aggregate, left_is_derived_relation) = classify(join.left.as_ref()); + let (left_scalar_aggregate, left_is_derived_relation) = + is_scalar_aggregate_or_derived_relation(join.left.as_ref()); let (right_scalar_aggregate, right_is_derived_relation) = - classify(join.right.as_ref()); + is_scalar_aggregate_or_derived_relation(join.right.as_ref()); join.on.is_empty() && join.filter.is_none() && ((left_scalar_aggregate && right_is_derived_relation) @@ -451,9 +456,12 @@ fn push_down_all_join( let keep_mixed_scalar_subquery_filters = is_inner_join && is_scalar_subquery_cross_join(&join); for predicate in predicates { - if left_preserved && 
checker.is_left_only(&predicate) { + let left_only = left_preserved && checker.is_left_only(&predicate); + let right_only = + !left_only && right_preserved && checker.is_right_only(&predicate); + if left_only { left_push.push(predicate); - } else if right_preserved && checker.is_right_only(&predicate) { + } else if right_only { right_push.push(predicate); } else if is_inner_join && !keep_mixed_scalar_subquery_filters @@ -479,43 +487,63 @@ fn push_down_all_join( let mut on_filter_join_conditions = vec![]; let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type); - if !on_filter.is_empty() { - for on in on_filter { - if on_left_preserved && checker.is_left_only(&on) { - left_push.push(on) - } else if on_right_preserved && checker.is_right_only(&on) { - right_push.push(on) - } else { - on_filter_join_conditions.push(on) - } + for on in on_filter { + if on_left_preserved && checker.is_left_only(&on) { + left_push.push(on) + } else if on_right_preserved && checker.is_right_only(&on) { + right_push.push(on) + } else { + on_filter_join_conditions.push(on) } } // Extract from OR clause, generate new predicates for both side of join if possible. // We only track the unpushable predicates above. 
- if left_preserved { - left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema)); - left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema)); - } - if right_preserved { - right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema)); - right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema)); - } + let extend_or_clauses = + |target: &mut Vec, filters: &[Expr], schema: &DFSchema, preserved| { + if preserved { + target.extend(extract_or_clauses_for_join(filters, schema)); + } + }; + extend_or_clauses( + &mut left_push, + &keep_predicates, + left_schema, + left_preserved, + ); + extend_or_clauses( + &mut left_push, + &join_conditions, + left_schema, + left_preserved, + ); + extend_or_clauses( + &mut right_push, + &keep_predicates, + right_schema, + right_preserved, + ); + extend_or_clauses( + &mut right_push, + &join_conditions, + right_schema, + right_preserved, + ); // For predicates from join filter, we should check with if a join side is preserved // in term of join filtering. - if on_left_preserved { - left_push.extend(extract_or_clauses_for_join( - &on_filter_join_conditions, - left_schema, - )); - } - if on_right_preserved { - right_push.extend(extract_or_clauses_for_join( - &on_filter_join_conditions, - right_schema, - )); - } + extend_or_clauses( + &mut left_push, + &on_filter_join_conditions, + left_schema, + on_left_preserved, + ); + extend_or_clauses( + &mut right_push, + &on_filter_join_conditions, + right_schema, + on_right_preserved, + ); if let Some(predicate) = conjunction(left_push) { join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?)); diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 1741f6beae025..1fa09a181f579 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -38,9 +38,6 @@ use log::{debug, trace}; /// as it was initially placed here and then moved elsewhere. 
pub use datafusion_expr::expr_rewriter::NamePreserver; -#[cfg(test)] -use self::test_eval_mode::*; - /// Returns true if `expr` contains all columns in `schema_cols` pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet) -> bool { expr.column_refs() @@ -76,7 +73,7 @@ pub fn is_restrict_null_predicate<'a>( predicate: Expr, join_cols_of_predicate: impl IntoIterator, ) -> Result { - if matches!(predicate, Expr::Column(_)) { + if is_column_expr(&predicate) { return Ok(true); } @@ -96,8 +93,8 @@ pub fn is_restrict_null_predicate<'a>( #[cfg(test)] if matches!( - null_restriction_eval_mode(), - NullRestrictionEvalMode::AuthoritativeOnly + test_eval_mode::null_restriction_eval_mode(), + test_eval_mode::NullRestrictionEvalMode::AuthoritativeOnly ) { return authoritative_restrict_null_predicate(predicate, join_cols); } @@ -130,16 +127,16 @@ pub fn evaluates_to_null<'a>( predicate: Expr, null_columns: impl IntoIterator, ) -> Result { - if matches!(predicate, Expr::Column(_)) { + if is_column_expr(&predicate) { return Ok(true); } - Ok( - match evaluate_expr_with_null_column(predicate, null_columns)? { + evaluate_with_null_columns(predicate, null_columns, |result| { + Ok(match result { ColumnarValue::Array(_) => false, ColumnarValue::Scalar(scalar) => scalar.is_null(), - }, - ) + }) + }) } fn evaluate_expr_with_null_column<'a>( @@ -173,8 +170,8 @@ fn authoritative_restrict_null_predicate<'a>( predicate: Expr, join_cols_of_predicate: impl IntoIterator, ) -> Result { - Ok( - match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? 
{ + evaluate_with_null_columns(predicate, join_cols_of_predicate, |result| { + Ok(match result { ColumnarValue::Array(array) if array.len() == 1 => { let boolean_array = as_boolean_array(&array)?; boolean_array.is_null(0) || !boolean_array.value(0) @@ -186,8 +183,8 @@ fn authoritative_restrict_null_predicate<'a>( | ScalarValue::Boolean(Some(false)) | ScalarValue::Null ), - }, - ) + }) + }) } fn coerce(expr: Expr, schema: &DFSchema) -> Result { @@ -195,6 +192,18 @@ fn coerce(expr: Expr, schema: &DFSchema) -> Result { expr.rewrite(&mut expr_rewrite).data() } +fn is_column_expr(expr: &Expr) -> bool { + matches!(expr, Expr::Column(_)) +} + +fn evaluate_with_null_columns<'a, T>( + predicate: Expr, + null_columns: impl IntoIterator, + f: impl FnOnce(ColumnarValue) -> Result, +) -> Result { + f(evaluate_expr_with_null_column(predicate, null_columns)?) +} + #[cfg(test)] mod test_eval_mode { use std::cell::Cell; @@ -243,6 +252,11 @@ mod tests { use super::*; use std::panic::{AssertUnwindSafe, catch_unwind}; + use crate::utils::test_eval_mode::{ + NullRestrictionEvalMode, null_restriction_eval_mode, + set_null_restriction_eval_mode_for_test, + with_null_restriction_eval_mode_for_test, + }; use datafusion_expr::{ Operator, binary_expr, case, col, in_list, is_null, lit, when, }; diff --git a/datafusion/optimizer/src/utils/null_restriction.rs b/datafusion/optimizer/src/utils/null_restriction.rs index 4c5eecd560edd..4a65d7812da97 100644 --- a/datafusion/optimizer/src/utils/null_restriction.rs +++ b/datafusion/optimizer/src/utils/null_restriction.rs @@ -36,13 +36,11 @@ pub(super) fn syntactic_restrict_null_predicate( predicate: &Expr, join_cols: &HashSet<&Column>, ) -> Option { - syntactic_null_substitution_value(predicate, join_cols).and_then( - |value| match value { - NullSubstitutionValue::Boolean(value) => Some(!value), - NullSubstitutionValue::Null => Some(true), - NullSubstitutionValue::NonNull => None, - }, - ) + match syntactic_null_substitution_value(predicate, 
join_cols) { + Some(NullSubstitutionValue::Boolean(value)) => Some(!value), + Some(NullSubstitutionValue::Null) => Some(true), + _ => None, + } } fn not(value: Option) -> Option { @@ -88,13 +86,10 @@ fn null_check_value( value: Option, is_not_null: bool, ) -> Option { - match value { - Some(NullSubstitutionValue::Null) => { - Some(NullSubstitutionValue::Boolean(!is_not_null)) - } - Some(_) => Some(NullSubstitutionValue::Boolean(is_not_null)), - None => None, - } + value.map(|value| match value { + NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(!is_not_null), + _ => NullSubstitutionValue::Boolean(is_not_null), + }) } fn null_if_contains_null( @@ -106,12 +101,6 @@ fn null_if_contains_null( .then_some(NullSubstitutionValue::Null) } -fn strict_null_only( - value: Option, -) -> Option { - value.filter(|value| matches!(value, NullSubstitutionValue::Null)) -} - fn syntactic_null_substitution_value( expr: &Expr, join_cols: &HashSet<&Column>, @@ -150,16 +139,17 @@ fn syntactic_null_substitution_value( syntactic_null_substitution_value(between.low.as_ref(), join_cols), syntactic_null_substitution_value(between.high.as_ref(), join_cols), ]), - Expr::Cast(cast) => strict_null_only(syntactic_null_substitution_value( - cast.expr.as_ref(), - join_cols, - )), - Expr::TryCast(try_cast) => strict_null_only(syntactic_null_substitution_value( - try_cast.expr.as_ref(), - join_cols, - )), + Expr::Cast(cast) => { + syntactic_null_substitution_value(cast.expr.as_ref(), join_cols) + .filter(|value| matches!(value, NullSubstitutionValue::Null)) + } + Expr::TryCast(try_cast) => { + syntactic_null_substitution_value(try_cast.expr.as_ref(), join_cols) + .filter(|value| matches!(value, NullSubstitutionValue::Null)) + } Expr::Negative(expr) => { - strict_null_only(syntactic_null_substitution_value(expr.as_ref(), join_cols)) + syntactic_null_substitution_value(expr.as_ref(), join_cols) + .filter(|value| matches!(value, NullSubstitutionValue::Null)) } Expr::Like(like) | 
Expr::SimilarTo(like) => null_if_contains_null([ syntactic_null_substitution_value(like.expr.as_ref(), join_cols), From a3bcb57870932b49423e314fbe339d3f3701ff93 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 20:10:31 +0800 Subject: [PATCH 53/63] Revert to 787d0a4d3 --- datafusion/optimizer/src/push_down_filter.rs | 119 +++++++----------- datafusion/optimizer/src/utils.rs | 47 +++---- .../optimizer/src/utils/null_restriction.rs | 42 ++++--- 3 files changed, 91 insertions(+), 117 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 0b1495045c631..f7ccc7dea951c 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -285,27 +285,26 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Ok(is_evaluate) } -fn is_scalar_subquery_cross_join(join: &Join) -> bool { - fn is_scalar_aggregate_or_derived_relation(plan: &LogicalPlan) -> (bool, bool) { - match plan { - LogicalPlan::SubqueryAlias(subquery_alias) => { - let (is_scalar_aggregate, _) = is_scalar_aggregate_or_derived_relation( - subquery_alias.input.as_ref(), - ); - (is_scalar_aggregate, true) - } - LogicalPlan::Projection(projection) => { - is_scalar_aggregate_or_derived_relation(projection.input.as_ref()) - } - LogicalPlan::Aggregate(aggregate) => (aggregate.group_expr.is_empty(), false), - _ => (false, false), +fn classify_join_input(plan: &LogicalPlan) -> (bool, bool) { + match plan { + LogicalPlan::SubqueryAlias(subquery_alias) => { + let (is_scalar_aggregate, _) = + classify_join_input(subquery_alias.input.as_ref()); + (is_scalar_aggregate, true) + } + LogicalPlan::Projection(projection) => { + classify_join_input(projection.input.as_ref()) } + LogicalPlan::Aggregate(aggregate) => (aggregate.group_expr.is_empty(), false), + _ => (false, false), } +} +fn is_scalar_subquery_cross_join(join: &Join) -> bool { let (left_scalar_aggregate, 
left_is_derived_relation) = - is_scalar_aggregate_or_derived_relation(join.left.as_ref()); + classify_join_input(join.left.as_ref()); let (right_scalar_aggregate, right_is_derived_relation) = - is_scalar_aggregate_or_derived_relation(join.right.as_ref()); + classify_join_input(join.right.as_ref()); join.on.is_empty() && join.filter.is_none() && ((left_scalar_aggregate && right_is_derived_relation) @@ -456,12 +455,9 @@ fn push_down_all_join( let keep_mixed_scalar_subquery_filters = is_inner_join && is_scalar_subquery_cross_join(&join); for predicate in predicates { - let left_only = left_preserved && checker.is_left_only(&predicate); - let right_only = - !left_only && right_preserved && checker.is_right_only(&predicate); - if left_only { + if left_preserved && checker.is_left_only(&predicate) { left_push.push(predicate); - } else if right_only { + } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); } else if is_inner_join && !keep_mixed_scalar_subquery_filters @@ -487,63 +483,43 @@ fn push_down_all_join( let mut on_filter_join_conditions = vec![]; let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type); - for on in on_filter { - if on_left_preserved && checker.is_left_only(&on) { - left_push.push(on) - } else if on_right_preserved && checker.is_right_only(&on) { - right_push.push(on) - } else { - on_filter_join_conditions.push(on) + if !on_filter.is_empty() { + for on in on_filter { + if on_left_preserved && checker.is_left_only(&on) { + left_push.push(on) + } else if on_right_preserved && checker.is_right_only(&on) { + right_push.push(on) + } else { + on_filter_join_conditions.push(on) + } } } // Extract from OR clause, generate new predicates for both side of join if possible. // We only track the unpushable predicates above. 
- let extend_or_clauses = - |target: &mut Vec, filters: &[Expr], schema: &DFSchema, preserved| { - if preserved { - target.extend(extract_or_clauses_for_join(filters, schema)); - } - }; - extend_or_clauses( - &mut left_push, - &keep_predicates, - left_schema, - left_preserved, - ); - extend_or_clauses( - &mut left_push, - &join_conditions, - left_schema, - left_preserved, - ); - extend_or_clauses( - &mut right_push, - &keep_predicates, - right_schema, - right_preserved, - ); - extend_or_clauses( - &mut right_push, - &join_conditions, - right_schema, - right_preserved, - ); + if left_preserved { + left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema)); + left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema)); + } + if right_preserved { + right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema)); + right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema)); + } // For predicates from join filter, we should check with if a join side is preserved // in term of join filtering. 
- extend_or_clauses( - &mut left_push, - &on_filter_join_conditions, - left_schema, - on_left_preserved, - ); - extend_or_clauses( - &mut right_push, - &on_filter_join_conditions, - right_schema, - on_right_preserved, - ); + if on_left_preserved { + left_push.extend(extract_or_clauses_for_join( + &on_filter_join_conditions, + left_schema, + )); + } + if on_right_preserved { + right_push.extend(extract_or_clauses_for_join( + &on_filter_join_conditions, + right_schema, + )); + } if let Some(predicate) = conjunction(left_push) { join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?)); @@ -778,11 +754,10 @@ fn infer_join_predicates_impl< inferred_predicates: &mut InferredPredicates, ) -> Result<()> { for predicate in input_predicates { - let column_refs = predicate.column_refs(); let mut join_cols_to_replace = HashMap::new(); let mut saw_non_replaceable_ref = false; - for &col in &column_refs { + for &col in &predicate.column_refs() { let replacement = join_col_keys.iter().find_map(|(l, r)| { if ENABLE_LEFT_TO_RIGHT && col == *l { Some((col, *r)) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 1fa09a181f579..ae8ec74606ff2 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -38,6 +38,12 @@ use log::{debug, trace}; /// as it was initially placed here and then moved elsewhere. 
pub use datafusion_expr::expr_rewriter::NamePreserver; +#[cfg(test)] +use self::test_eval_mode::{ + NullRestrictionEvalMode, null_restriction_eval_mode, + set_null_restriction_eval_mode_for_test, with_null_restriction_eval_mode_for_test, +}; + /// Returns true if `expr` contains all columns in `schema_cols` pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet) -> bool { expr.column_refs() @@ -73,7 +79,7 @@ pub fn is_restrict_null_predicate<'a>( predicate: Expr, join_cols_of_predicate: impl IntoIterator, ) -> Result { - if is_column_expr(&predicate) { + if matches!(predicate, Expr::Column(_)) { return Ok(true); } @@ -93,8 +99,8 @@ pub fn is_restrict_null_predicate<'a>( #[cfg(test)] if matches!( - test_eval_mode::null_restriction_eval_mode(), - test_eval_mode::NullRestrictionEvalMode::AuthoritativeOnly + null_restriction_eval_mode(), + NullRestrictionEvalMode::AuthoritativeOnly ) { return authoritative_restrict_null_predicate(predicate, join_cols); } @@ -127,16 +133,16 @@ pub fn evaluates_to_null<'a>( predicate: Expr, null_columns: impl IntoIterator, ) -> Result { - if is_column_expr(&predicate) { + if matches!(predicate, Expr::Column(_)) { return Ok(true); } - evaluate_with_null_columns(predicate, null_columns, |result| { - Ok(match result { + Ok( + match evaluate_expr_with_null_column(predicate, null_columns)? { ColumnarValue::Array(_) => false, ColumnarValue::Scalar(scalar) => scalar.is_null(), - }) - }) + }, + ) } fn evaluate_expr_with_null_column<'a>( @@ -170,8 +176,8 @@ fn authoritative_restrict_null_predicate<'a>( predicate: Expr, join_cols_of_predicate: impl IntoIterator, ) -> Result { - evaluate_with_null_columns(predicate, join_cols_of_predicate, |result| { - Ok(match result { + Ok( + match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? 
{ ColumnarValue::Array(array) if array.len() == 1 => { let boolean_array = as_boolean_array(&array)?; boolean_array.is_null(0) || !boolean_array.value(0) @@ -183,8 +189,8 @@ fn authoritative_restrict_null_predicate<'a>( | ScalarValue::Boolean(Some(false)) | ScalarValue::Null ), - }) - }) + }, + ) } fn coerce(expr: Expr, schema: &DFSchema) -> Result { @@ -192,18 +198,6 @@ fn coerce(expr: Expr, schema: &DFSchema) -> Result { expr.rewrite(&mut expr_rewrite).data() } -fn is_column_expr(expr: &Expr) -> bool { - matches!(expr, Expr::Column(_)) -} - -fn evaluate_with_null_columns<'a, T>( - predicate: Expr, - null_columns: impl IntoIterator, - f: impl FnOnce(ColumnarValue) -> Result, -) -> Result { - f(evaluate_expr_with_null_column(predicate, null_columns)?) -} - #[cfg(test)] mod test_eval_mode { use std::cell::Cell; @@ -252,11 +246,6 @@ mod tests { use super::*; use std::panic::{AssertUnwindSafe, catch_unwind}; - use crate::utils::test_eval_mode::{ - NullRestrictionEvalMode, null_restriction_eval_mode, - set_null_restriction_eval_mode_for_test, - with_null_restriction_eval_mode_for_test, - }; use datafusion_expr::{ Operator, binary_expr, case, col, in_list, is_null, lit, when, }; diff --git a/datafusion/optimizer/src/utils/null_restriction.rs b/datafusion/optimizer/src/utils/null_restriction.rs index 4a65d7812da97..5f6b9eef794ea 100644 --- a/datafusion/optimizer/src/utils/null_restriction.rs +++ b/datafusion/optimizer/src/utils/null_restriction.rs @@ -39,7 +39,7 @@ pub(super) fn syntactic_restrict_null_predicate( match syntactic_null_substitution_value(predicate, join_cols) { Some(NullSubstitutionValue::Boolean(value)) => Some(!value), Some(NullSubstitutionValue::Null) => Some(true), - _ => None, + Some(NullSubstitutionValue::NonNull) | None => None, } } @@ -49,7 +49,7 @@ fn not(value: Option) -> Option { Some(NullSubstitutionValue::Boolean(!value)) } Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Null), - _ => None, + 
Some(NullSubstitutionValue::NonNull) | None => None, } } @@ -86,10 +86,15 @@ fn null_check_value( value: Option, is_not_null: bool, ) -> Option { - value.map(|value| match value { - NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(!is_not_null), - _ => NullSubstitutionValue::Boolean(is_not_null), - }) + match value { + Some(NullSubstitutionValue::Null) => { + Some(NullSubstitutionValue::Boolean(!is_not_null)) + } + Some(NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_)) => { + Some(NullSubstitutionValue::Boolean(is_not_null)) + } + None => None, + } } fn null_if_contains_null( @@ -101,6 +106,12 @@ fn null_if_contains_null( .then_some(NullSubstitutionValue::Null) } +fn strict_null_only( + value: Option, +) -> Option { + value.filter(|value| matches!(value, NullSubstitutionValue::Null)) +} + fn syntactic_null_substitution_value( expr: &Expr, join_cols: &HashSet<&Column>, @@ -139,17 +150,16 @@ fn syntactic_null_substitution_value( syntactic_null_substitution_value(between.low.as_ref(), join_cols), syntactic_null_substitution_value(between.high.as_ref(), join_cols), ]), - Expr::Cast(cast) => { - syntactic_null_substitution_value(cast.expr.as_ref(), join_cols) - .filter(|value| matches!(value, NullSubstitutionValue::Null)) - } - Expr::TryCast(try_cast) => { - syntactic_null_substitution_value(try_cast.expr.as_ref(), join_cols) - .filter(|value| matches!(value, NullSubstitutionValue::Null)) - } + Expr::Cast(cast) => strict_null_only(syntactic_null_substitution_value( + cast.expr.as_ref(), + join_cols, + )), + Expr::TryCast(try_cast) => strict_null_only(syntactic_null_substitution_value( + try_cast.expr.as_ref(), + join_cols, + )), Expr::Negative(expr) => { - syntactic_null_substitution_value(expr.as_ref(), join_cols) - .filter(|value| matches!(value, NullSubstitutionValue::Null)) + strict_null_only(syntactic_null_substitution_value(expr.as_ref(), join_cols)) } Expr::Like(like) | Expr::SimilarTo(like) => null_if_contains_null([ 
syntactic_null_substitution_value(like.expr.as_ref(), join_cols), From 37bcc07559366af95a5b70ea4f550620215957c2 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 21:38:03 +0800 Subject: [PATCH 54/63] Refactor tests for clarity and reusability Extract shared helpers in push_down_filter_regressions.rs and push_down_filter.rs to reduce code duplication. Consolidate optimizer-delta test assertions and create specific plan builders for common expressions. Add a utility in utils.rs to evaluate predicates under different null-restriction modes, streamlining mode-comparison tests and enhancing maintainability. --- .../tests/sql/push_down_filter_regressions.rs | 45 +++---- datafusion/optimizer/src/push_down_filter.rs | 110 ++++++++---------- datafusion/optimizer/src/utils.rs | 59 +++++----- 3 files changed, 90 insertions(+), 124 deletions(-) diff --git a/datafusion/core/tests/sql/push_down_filter_regressions.rs b/datafusion/core/tests/sql/push_down_filter_regressions.rs index a1ff8293c97a1..5ad53f33c8b98 100644 --- a/datafusion/core/tests/sql/push_down_filter_regressions.rs +++ b/datafusion/core/tests/sql/push_down_filter_regressions.rs @@ -33,6 +33,9 @@ const WINDOW_SCALAR_SUBQUERY_SQL: &str = r#" ) "#; +const WINDOW_SCALAR_SUBQUERY_EXPECTED: &[&str] = + &["+----+", "| rn |", "+----+", "| 1 |", "+----+"]; + fn sqllogictest_style_ctx(push_down_filter_enabled: bool) -> SessionContext { let ctx = SessionContext::new_with_config(SessionConfig::new().with_target_partitions(4)); @@ -56,30 +59,20 @@ async fn capture_window_scalar_subquery_plans( )) } -#[tokio::test] -async fn window_scalar_subquery_regression() -> Result<()> { - let ctx = SessionContext::new(); +async fn assert_window_scalar_subquery(ctx: SessionContext) -> Result<()> { let results = ctx.sql(WINDOW_SCALAR_SUBQUERY_SQL).await?.collect().await?; - - assert_batches_eq!( - &["+----+", "| rn |", "+----+", "| 1 |", "+----+",], - &results - ); - + assert_batches_eq!(WINDOW_SCALAR_SUBQUERY_EXPECTED, 
&results); Ok(()) } #[tokio::test] -async fn window_scalar_subquery_sqllogictest_style_regression() -> Result<()> { - let ctx = sqllogictest_style_ctx(true); - let results = ctx.sql(WINDOW_SCALAR_SUBQUERY_SQL).await?.collect().await?; - - assert_batches_eq!( - &["+----+", "| rn |", "+----+", "| 1 |", "+----+",], - &results - ); +async fn window_scalar_subquery_regression() -> Result<()> { + assert_window_scalar_subquery(SessionContext::new()).await +} - Ok(()) +#[tokio::test] +async fn window_scalar_subquery_sqllogictest_style_regression() -> Result<()> { + assert_window_scalar_subquery(sqllogictest_style_ctx(true)).await } #[tokio::test] @@ -212,28 +205,18 @@ async fn window_scalar_subquery_optimizer_delta() -> Result<()> { let (disabled_optimized, disabled_physical) = capture_window_scalar_subquery_plans(false).await?; + assert_eq!(enabled_optimized, disabled_optimized); + assert_eq!(enabled_physical, disabled_physical); + assert!( enabled_optimized .contains("Filter: s.acctbal > __scalar_sq_1.avg(suppliers.acctbal)") ); assert!(enabled_optimized.contains("Cross Join:")); - assert!( - disabled_optimized - .contains("Filter: s.acctbal > __scalar_sq_1.avg(suppliers.acctbal)") - ); - assert!(disabled_optimized.contains("Cross Join:")); - assert!( enabled_physical.contains("FilterExec: acctbal@1 > avg(suppliers.acctbal)@2") ); assert!(enabled_physical.contains("CrossJoinExec")); - assert!( - disabled_physical.contains("FilterExec: acctbal@1 > avg(suppliers.acctbal)@2") - ); - assert!(disabled_physical.contains("CrossJoinExec")); - - assert_eq!(enabled_optimized, disabled_optimized); - assert_eq!(enabled_physical, disabled_physical); Ok(()) } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index f7ccc7dea951c..32e502242b01b 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1522,6 +1522,53 @@ mod tests { use super::*; + fn scalar_subquery_right_plan() 
-> Result { + LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) + .project(vec![col("a").alias("acctbal")])? + .aggregate( + Vec::::new(), + vec![avg(col("acctbal")).alias("avg_acctbal")], + )? + .alias("__scalar_sq_1")? + .build() + } + + fn row_number_window_expr() -> Expr { + Expr::from(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::row_number::row_number_udwf(), + ), + vec![], + )) + .partition_by(vec![col("s.nation")]) + .order_by(vec![col("s.acctbal").sort(false, true)]) + .build() + .unwrap() + } + + fn window_over_scalar_subquery_cross_join_plan( + with_project_wrapper: bool, + ) -> Result { + let left = { + let builder = LogicalPlanBuilder::from(test_table_scan()?) + .project(vec![col("a").alias("nation"), col("b").alias("acctbal")])? + .alias("s")?; + let builder = if with_project_wrapper { + builder.project(vec![col("s.nation"), col("s.acctbal")])? + } else { + builder + }; + builder.build()? + }; + + LogicalPlanBuilder::from(left) + .cross_join(scalar_subquery_right_plan()?)? + .filter(col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")))? + .project(vec![col("s.nation"), col("s.acctbal")])? + .window(vec![row_number_window_expr()])? + .build() + } + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} macro_rules! assert_optimized_plan_equal { @@ -2443,36 +2490,7 @@ mod tests { #[test] fn window_over_scalar_subquery_cross_join_keeps_filter_above_join() -> Result<()> { - let left = LogicalPlanBuilder::from(test_table_scan()?) - .project(vec![col("a").alias("nation"), col("b").alias("acctbal")])? - .alias("s")? - .build()?; - let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) - .project(vec![col("a").alias("acctbal")])? - .aggregate( - Vec::::new(), - vec![avg(col("acctbal")).alias("avg_acctbal")], - )? - .alias("__scalar_sq_1")? 
- .build()?; - - let window = Expr::from(WindowFunction::new( - WindowFunctionDefinition::WindowUDF( - datafusion_functions_window::row_number::row_number_udwf(), - ), - vec![], - )) - .partition_by(vec![col("s.nation")]) - .order_by(vec![col("s.acctbal").sort(false, true)]) - .build() - .unwrap(); - - let plan = LogicalPlanBuilder::from(left) - .cross_join(right)? - .filter(col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")))? - .project(vec![col("s.nation"), col("s.acctbal")])? - .window(vec![window])? - .build()?; + let plan = window_over_scalar_subquery_cross_join_plan(false)?; assert_optimized_plan_equal!( plan, @@ -2495,37 +2513,7 @@ mod tests { #[test] fn window_over_scalar_subquery_cross_join_with_project_wrapper_keeps_filter_above_join() -> Result<()> { - let left = LogicalPlanBuilder::from(test_table_scan()?) - .project(vec![col("a").alias("nation"), col("b").alias("acctbal")])? - .alias("s")? - .project(vec![col("s.nation"), col("s.acctbal")])? - .build()?; - let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) - .project(vec![col("a").alias("acctbal")])? - .aggregate( - Vec::::new(), - vec![avg(col("acctbal")).alias("avg_acctbal")], - )? - .alias("__scalar_sq_1")? - .build()?; - - let window = Expr::from(WindowFunction::new( - WindowFunctionDefinition::WindowUDF( - datafusion_functions_window::row_number::row_number_udwf(), - ), - vec![], - )) - .partition_by(vec![col("s.nation")]) - .order_by(vec![col("s.acctbal").sort(false, true)]) - .build() - .unwrap(); - - let plan = LogicalPlanBuilder::from(left) - .cross_join(right)? - .filter(col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")))? - .project(vec![col("s.nation"), col("s.acctbal")])? - .window(vec![window])? 
- .build()?; + let plan = window_over_scalar_subquery_cross_join_plan(true)?; assert_optimized_plan_equal!( plan, diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index ae8ec74606ff2..a56d03a0c02ba 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -250,6 +250,23 @@ mod tests { Operator, binary_expr, case, col, in_list, is_null, lit, when, }; + fn restrict_null_predicate_in_modes( + predicate: Expr, + join_cols: &[Column], + ) -> Result<(bool, bool)> { + let auto_result = with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::Auto, + || is_restrict_null_predicate(predicate.clone(), join_cols.iter()), + )?; + + let authoritative_result = with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::AuthoritativeOnly, + || is_restrict_null_predicate(predicate.clone(), join_cols.iter()), + )?; + + Ok((auto_result, authoritative_result)) + } + #[test] fn expr_is_restrict_null_predicate() -> Result<()> { let test_cases = vec![ @@ -465,27 +482,13 @@ mod tests { #[test] fn null_restriction_eval_mode_auto_vs_authoritative_only() -> Result<()> { let predicate = binary_expr(col("a"), Operator::Gt, lit(8i64)); - let join_cols_of_predicate = predicate.column_refs(); - - let auto_result = with_null_restriction_eval_mode_for_test( - NullRestrictionEvalMode::Auto, - || { - is_restrict_null_predicate( - predicate.clone(), - join_cols_of_predicate.iter().copied(), - ) - }, - )?; - - let authoritative_result = with_null_restriction_eval_mode_for_test( - NullRestrictionEvalMode::AuthoritativeOnly, - || { - is_restrict_null_predicate( - predicate.clone(), - join_cols_of_predicate.iter().copied(), - ) - }, - )?; + let join_cols_of_predicate = predicate + .column_refs() + .into_iter() + .cloned() + .collect::>(); + let (auto_result, authoritative_result) = + restrict_null_predicate_in_modes(predicate, &join_cols_of_predicate)?; assert_eq!(auto_result, authoritative_result); @@ -496,17 +499,9 @@ 
mod tests { fn mixed_reference_predicate_remains_fast_pathed_in_authoritative_mode() -> Result<()> { let predicate = binary_expr(col("a"), Operator::Gt, col("b")); - let column_a = Column::from_name("a"); - - let auto_result = with_null_restriction_eval_mode_for_test( - NullRestrictionEvalMode::Auto, - || is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a)), - )?; - - let authoritative_only_result = with_null_restriction_eval_mode_for_test( - NullRestrictionEvalMode::AuthoritativeOnly, - || is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a)), - )?; + let join_cols = vec![Column::from_name("a")]; + let (auto_result, authoritative_only_result) = + restrict_null_predicate_in_modes(predicate.clone(), &join_cols)?; assert!(!auto_result, "{predicate}"); assert!(!authoritative_only_result, "{predicate}"); From c936b6a3f90e07f716b2e7df9ea9b9558d4cdd4e Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 27 Mar 2026 21:57:29 +0800 Subject: [PATCH 55/63] Refactor is_restrict_null_predicate in utils.rs Restore is_restrict_null_predicate to its original behavior by removing the broader syntactic null-restriction path and its associated test scaffolding. Maintain early rejection for the push_down_filter caller while eliminating extra tree-walk work on common paths. Confirmed changes by running tests and formatting. --- datafusion/optimizer/src/utils.rs | 217 +-------------- .../optimizer/src/utils/null_restriction.rs | 262 ------------------ 2 files changed, 5 insertions(+), 474 deletions(-) delete mode 100644 datafusion/optimizer/src/utils/null_restriction.rs diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index a56d03a0c02ba..2520cbbb5fa61 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -17,8 +17,6 @@ //! 
Utility functions leveraged by the query optimizer rules -mod null_restriction; - use std::collections::{BTreeSet, HashMap, HashSet}; use std::sync::Arc; @@ -38,12 +36,6 @@ use log::{debug, trace}; /// as it was initially placed here and then moved elsewhere. pub use datafusion_expr::expr_rewriter::NamePreserver; -#[cfg(test)] -use self::test_eval_mode::{ - NullRestrictionEvalMode, null_restriction_eval_mode, - set_null_restriction_eval_mode_for_test, with_null_restriction_eval_mode_for_test, -}; - /// Returns true if `expr` contains all columns in `schema_cols` pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet) -> bool { expr.column_refs() @@ -86,40 +78,17 @@ pub fn is_restrict_null_predicate<'a>( // Collect join columns so they can be used in both the fast-path check and the // fallback evaluation path below. let join_cols: HashSet<&Column> = join_cols_of_predicate.into_iter().collect(); - let column_refs = predicate.column_refs(); - // Fast path: if the predicate references columns outside the join key set, // `evaluate_expr_with_null_column` would fail because the null schema only // contains a placeholder for the join key columns. Callers treat such errors as // non-restricting (false) via `matches!(_, Ok(true))`, so we return false early // and avoid the expensive physical-expression compilation pipeline entirely. 
- if !column_refs.iter().all(|column| join_cols.contains(*column)) { - return Ok(false); - } - - #[cfg(test)] - if matches!( - null_restriction_eval_mode(), - NullRestrictionEvalMode::AuthoritativeOnly - ) { - return authoritative_restrict_null_predicate(predicate, join_cols); - } - - if let Some(is_restricting) = - null_restriction::syntactic_restrict_null_predicate(&predicate, &join_cols) + if !predicate + .column_refs() + .iter() + .all(|column| join_cols.contains(*column)) { - #[cfg(debug_assertions)] - { - let authoritative = authoritative_restrict_null_predicate( - predicate.clone(), - join_cols.iter().copied(), - )?; - debug_assert_eq!( - is_restricting, authoritative, - "syntactic fast path disagrees with authoritative null-restriction evaluation for predicate: {predicate}" - ); - } - return Ok(is_restricting); + return Ok(false); } authoritative_restrict_null_predicate(predicate, join_cols) @@ -198,75 +167,13 @@ fn coerce(expr: Expr, schema: &DFSchema) -> Result { expr.rewrite(&mut expr_rewrite).data() } -#[cfg(test)] -mod test_eval_mode { - use std::cell::Cell; - - /// Null restriction evaluation mode for optimizer tests. - #[derive(Copy, Clone, Debug, PartialEq, Eq)] - pub(crate) enum NullRestrictionEvalMode { - Auto, - AuthoritativeOnly, - } - - thread_local! 
{ - static NULL_RESTRICTION_EVAL_MODE: Cell = - const { Cell::new(NullRestrictionEvalMode::Auto) }; - } - - pub(crate) fn set_null_restriction_eval_mode_for_test(mode: NullRestrictionEvalMode) { - NULL_RESTRICTION_EVAL_MODE.with(|eval_mode| eval_mode.set(mode)); - } - - pub(crate) fn null_restriction_eval_mode() -> NullRestrictionEvalMode { - NULL_RESTRICTION_EVAL_MODE.with(Cell::get) - } - - pub(crate) fn with_null_restriction_eval_mode_for_test( - mode: NullRestrictionEvalMode, - f: impl FnOnce() -> T, - ) -> T { - struct NullRestrictionEvalModeReset(NullRestrictionEvalMode); - - impl Drop for NullRestrictionEvalModeReset { - fn drop(&mut self) { - set_null_restriction_eval_mode_for_test(self.0); - } - } - - let previous_mode = null_restriction_eval_mode(); - set_null_restriction_eval_mode_for_test(mode); - let _reset = NullRestrictionEvalModeReset(previous_mode); - f() - } -} - #[cfg(test)] mod tests { use super::*; - use std::panic::{AssertUnwindSafe, catch_unwind}; - use datafusion_expr::{ Operator, binary_expr, case, col, in_list, is_null, lit, when, }; - fn restrict_null_predicate_in_modes( - predicate: Expr, - join_cols: &[Column], - ) -> Result<(bool, bool)> { - let auto_result = with_null_restriction_eval_mode_for_test( - NullRestrictionEvalMode::Auto, - || is_restrict_null_predicate(predicate.clone(), join_cols.iter()), - )?; - - let authoritative_result = with_null_restriction_eval_mode_for_test( - NullRestrictionEvalMode::AuthoritativeOnly, - || is_restrict_null_predicate(predicate.clone(), join_cols.iter()), - )?; - - Ok((auto_result, authoritative_result)) - } - #[test] fn expr_is_restrict_null_predicate() -> Result<()> { let test_cases = vec![ @@ -409,118 +316,4 @@ mod tests { Ok(()) } - - #[test] - fn syntactic_fast_path_matches_authoritative_evaluator() -> Result<()> { - let test_cases = vec![ - is_null(col("a")), - Expr::IsNotNull(Box::new(col("a"))), - binary_expr(col("a"), Operator::Gt, lit(8i64)), - binary_expr(col("a"), Operator::Eq, 
lit(ScalarValue::Null)), - binary_expr(col("a"), Operator::And, lit(true)), - binary_expr(col("a"), Operator::Or, lit(false)), - Expr::Not(Box::new(col("a").is_true())), - col("a").is_true(), - col("a").is_false(), - col("a").is_unknown(), - col("a").is_not_true(), - col("a").is_not_false(), - col("a").is_not_unknown(), - col("a").between(lit(1i64), lit(10i64)), - binary_expr( - when(Expr::IsNotNull(Box::new(col("a"))), col("a")) - .otherwise(col("b"))?, - Operator::Gt, - lit(2i64), - ), - case(col("a")) - .when(lit(1i64), lit(true)) - .otherwise(lit(false))?, - case(col("a")) - .when(lit(0i64), lit(false)) - .otherwise(lit(true))?, - binary_expr( - case(col("a")) - .when(lit(0i64), lit(true)) - .otherwise(lit(false))?, - Operator::Or, - lit(false), - ), - binary_expr( - case(lit(1i64)) - .when(lit(1i64), lit(ScalarValue::Null)) - .otherwise(lit(false))?, - Operator::IsNotDistinctFrom, - lit(true), - ), - ]; - - for predicate in test_cases { - let join_cols = predicate.column_refs(); - if let Some(syntactic) = null_restriction::syntactic_restrict_null_predicate( - &predicate, &join_cols, - ) { - let authoritative = authoritative_restrict_null_predicate( - predicate.clone(), - join_cols.iter().copied(), - ) - .unwrap_or_else(|error| { - panic!( - "authoritative evaluator failed for predicate `{predicate}`: {error}" - ) - }); - assert_eq!( - syntactic, authoritative, - "syntactic fast path disagrees with authoritative evaluator for predicate: {predicate}", - ); - } - } - - Ok(()) - } - - #[test] - fn null_restriction_eval_mode_auto_vs_authoritative_only() -> Result<()> { - let predicate = binary_expr(col("a"), Operator::Gt, lit(8i64)); - let join_cols_of_predicate = predicate - .column_refs() - .into_iter() - .cloned() - .collect::>(); - let (auto_result, authoritative_result) = - restrict_null_predicate_in_modes(predicate, &join_cols_of_predicate)?; - - assert_eq!(auto_result, authoritative_result); - - Ok(()) - } - - #[test] - fn 
mixed_reference_predicate_remains_fast_pathed_in_authoritative_mode() -> Result<()> - { - let predicate = binary_expr(col("a"), Operator::Gt, col("b")); - let join_cols = vec![Column::from_name("a")]; - let (auto_result, authoritative_only_result) = - restrict_null_predicate_in_modes(predicate.clone(), &join_cols)?; - - assert!(!auto_result, "{predicate}"); - assert!(!authoritative_only_result, "{predicate}"); - - Ok(()) - } - - #[test] - fn null_restriction_eval_mode_guard_restores_on_panic() { - set_null_restriction_eval_mode_for_test(NullRestrictionEvalMode::Auto); - - let result = catch_unwind(AssertUnwindSafe(|| { - with_null_restriction_eval_mode_for_test( - NullRestrictionEvalMode::AuthoritativeOnly, - || panic!("intentional panic to verify test mode reset"), - ) - })); - - assert!(result.is_err()); - assert_eq!(null_restriction_eval_mode(), NullRestrictionEvalMode::Auto); - } } diff --git a/datafusion/optimizer/src/utils/null_restriction.rs b/datafusion/optimizer/src/utils/null_restriction.rs deleted file mode 100644 index 5f6b9eef794ea..0000000000000 --- a/datafusion/optimizer/src/utils/null_restriction.rs +++ /dev/null @@ -1,262 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
Syntactic null-restriction evaluator used by optimizer fast paths. - -use std::collections::HashSet; - -use datafusion_common::{Column, ScalarValue}; -use datafusion_expr::{BinaryExpr, Expr, Operator}; - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum NullSubstitutionValue { - /// SQL NULL after substituting join columns with NULL. - Null, - /// Known to be non-null, but value is otherwise unknown. - NonNull, - /// A known boolean outcome from SQL three-valued logic. - Boolean(bool), -} - -pub(super) fn syntactic_restrict_null_predicate( - predicate: &Expr, - join_cols: &HashSet<&Column>, -) -> Option { - match syntactic_null_substitution_value(predicate, join_cols) { - Some(NullSubstitutionValue::Boolean(value)) => Some(!value), - Some(NullSubstitutionValue::Null) => Some(true), - Some(NullSubstitutionValue::NonNull) | None => None, - } -} - -fn not(value: Option) -> Option { - match value { - Some(NullSubstitutionValue::Boolean(value)) => { - Some(NullSubstitutionValue::Boolean(!value)) - } - Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Null), - Some(NullSubstitutionValue::NonNull) | None => None, - } -} - -fn binary_boolean_value( - left: Option, - right: Option, - when_short_circuit: bool, -) -> Option { - let short_circuit = Some(NullSubstitutionValue::Boolean(when_short_circuit)); - let identity = Some(NullSubstitutionValue::Boolean(!when_short_circuit)); - - if left == short_circuit || right == short_circuit { - return short_circuit; - } - - match (left, right) { - (value, other) if value == identity => other, - (other, value) if value == identity => other, - (Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { - Some(NullSubstitutionValue::Null) - } - (Some(NullSubstitutionValue::NonNull), _) - | (_, Some(NullSubstitutionValue::NonNull)) - | (None, _) - | (_, None) => None, - (left, right) => { - debug_assert_eq!(left, right); - left - } - } -} - -fn null_check_value( - value: Option, - is_not_null: bool, -) -> 
Option { - match value { - Some(NullSubstitutionValue::Null) => { - Some(NullSubstitutionValue::Boolean(!is_not_null)) - } - Some(NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_)) => { - Some(NullSubstitutionValue::Boolean(is_not_null)) - } - None => None, - } -} - -fn null_if_contains_null( - values: impl IntoIterator>, -) -> Option { - values - .into_iter() - .any(|value| matches!(value, Some(NullSubstitutionValue::Null))) - .then_some(NullSubstitutionValue::Null) -} - -fn strict_null_only( - value: Option, -) -> Option { - value.filter(|value| matches!(value, NullSubstitutionValue::Null)) -} - -fn syntactic_null_substitution_value( - expr: &Expr, - join_cols: &HashSet<&Column>, -) -> Option { - // This evaluator intentionally supports a strict subset of expressions: - // aliases/columns/literals, boolean combinators (NOT/AND/OR), null checks - // (IS [NOT] NULL), BETWEEN, strict-null-preserving unary operators - // (CAST/TRY_CAST/NEGATIVE), LIKE/SIMILAR TO, and binary operators handled in - // `syntactic_binary_value`. - // - // Returning `None` means "defer to the authoritative evaluator" rather than - // "not null-restricting". Any unsupported expression variant must return - // `None` so callers can safely fall back to full expression evaluation. 
- match expr { - Expr::Alias(alias) => { - syntactic_null_substitution_value(alias.expr.as_ref(), join_cols) - } - Expr::Column(column) => join_cols - .contains(column) - .then_some(NullSubstitutionValue::Null), - Expr::Literal(value, _) => Some(scalar_to_null_substitution_value(value)), - Expr::BinaryExpr(binary_expr) => syntactic_binary_value(binary_expr, join_cols), - Expr::Not(expr) => { - not(syntactic_null_substitution_value(expr.as_ref(), join_cols)) - } - Expr::IsNull(expr) => null_check_value( - syntactic_null_substitution_value(expr.as_ref(), join_cols), - false, - ), - Expr::IsNotNull(expr) => null_check_value( - syntactic_null_substitution_value(expr.as_ref(), join_cols), - true, - ), - Expr::Between(between) => null_if_contains_null([ - syntactic_null_substitution_value(between.expr.as_ref(), join_cols), - syntactic_null_substitution_value(between.low.as_ref(), join_cols), - syntactic_null_substitution_value(between.high.as_ref(), join_cols), - ]), - Expr::Cast(cast) => strict_null_only(syntactic_null_substitution_value( - cast.expr.as_ref(), - join_cols, - )), - Expr::TryCast(try_cast) => strict_null_only(syntactic_null_substitution_value( - try_cast.expr.as_ref(), - join_cols, - )), - Expr::Negative(expr) => { - strict_null_only(syntactic_null_substitution_value(expr.as_ref(), join_cols)) - } - Expr::Like(like) | Expr::SimilarTo(like) => null_if_contains_null([ - syntactic_null_substitution_value(like.expr.as_ref(), join_cols), - syntactic_null_substitution_value(like.pattern.as_ref(), join_cols), - ]), - Expr::Exists { .. 
} - | Expr::InList(_) - | Expr::InSubquery(_) - | Expr::SetComparison(_) - | Expr::ScalarSubquery(_) - | Expr::OuterReferenceColumn(_, _) - | Expr::Placeholder(_) - | Expr::ScalarVariable(_, _) - | Expr::Unnest(_) - | Expr::GroupingSet(_) - | Expr::WindowFunction(_) - | Expr::ScalarFunction(_) - | Expr::Case(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) => None, - Expr::AggregateFunction(_) => None, - // TODO: remove the next line after `Expr::Wildcard` is removed - #[expect(deprecated)] - Expr::Wildcard { .. } => None, - } -} - -fn scalar_to_null_substitution_value(value: &ScalarValue) -> NullSubstitutionValue { - match value { - _ if value.is_null() => NullSubstitutionValue::Null, - ScalarValue::Boolean(Some(value)) => NullSubstitutionValue::Boolean(*value), - _ => NullSubstitutionValue::NonNull, - } -} - -fn is_strict_null_binary_op(op: Operator) -> bool { - matches!( - op, - Operator::Eq - | Operator::NotEq - | Operator::Lt - | Operator::LtEq - | Operator::Gt - | Operator::GtEq - | Operator::Plus - | Operator::Minus - | Operator::Multiply - | Operator::Divide - | Operator::Modulo - | Operator::RegexMatch - | Operator::RegexIMatch - | Operator::RegexNotMatch - | Operator::RegexNotIMatch - | Operator::LikeMatch - | Operator::ILikeMatch - | Operator::NotLikeMatch - | Operator::NotILikeMatch - | Operator::BitwiseAnd - | Operator::BitwiseOr - | Operator::BitwiseXor - | Operator::BitwiseShiftRight - | Operator::BitwiseShiftLeft - | Operator::StringConcat - | Operator::AtArrow - | Operator::ArrowAt - | Operator::Arrow - | Operator::LongArrow - | Operator::HashArrow - | Operator::HashLongArrow - | Operator::AtAt - | Operator::IntegerDivide - | Operator::HashMinus - | Operator::AtQuestion - | Operator::Question - | Operator::QuestionAnd - | Operator::QuestionPipe - | Operator::Colon - ) -} - -fn syntactic_binary_value( - binary_expr: &BinaryExpr, - join_cols: &HashSet<&Column>, 
-) -> Option { - let left = syntactic_null_substitution_value(binary_expr.left.as_ref(), join_cols); - let right = syntactic_null_substitution_value(binary_expr.right.as_ref(), join_cols); - - match binary_expr.op { - Operator::And => binary_boolean_value(left, right, false), - Operator::Or => binary_boolean_value(left, right, true), - Operator::IsDistinctFrom | Operator::IsNotDistinctFrom => None, - op => is_strict_null_binary_op(op) - .then(|| null_if_contains_null([left, right])) - .flatten(), - } -} From bacb0e37c37b981b2bfc96013c6c0731c2e81815 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sat, 28 Mar 2026 22:03:39 +0800 Subject: [PATCH 56/63] benchmark 7 improvements, 2 regressions From 847a8fa70d9c23fc3f5208b41fea1240237ebf8e Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sat, 28 Mar 2026 22:21:39 +0800 Subject: [PATCH 57/63] Restore syntactic null-restriction fast path Reinstate the deleted fast path for is_restrict_null_predicate in null_restriction.rs. Implement a two-stage evaluation process: return early false for mixed-reference predicates, perform syntactic evaluation for supported join-key-only predicates, and ensure authoritative fallback is applied only when necessary. --- .../optimizer/src/utils/null_restriction.rs | 261 ++++++++++++++++++ 1 file changed, 261 insertions(+) create mode 100644 datafusion/optimizer/src/utils/null_restriction.rs diff --git a/datafusion/optimizer/src/utils/null_restriction.rs b/datafusion/optimizer/src/utils/null_restriction.rs new file mode 100644 index 0000000000000..6e9920af80acc --- /dev/null +++ b/datafusion/optimizer/src/utils/null_restriction.rs @@ -0,0 +1,261 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Syntactic null-restriction evaluator used by optimizer fast paths. + +use std::collections::HashSet; + +use datafusion_common::{Column, ScalarValue}; +use datafusion_expr::{BinaryExpr, Expr, Operator}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum NullSubstitutionValue { + /// SQL NULL after substituting join columns with NULL. + Null, + /// Known to be non-null, but value is otherwise unknown. + NonNull, + /// A known boolean outcome from SQL three-valued logic. 
+ Boolean(bool), +} + +pub(super) fn all_columns_allowed( + column_refs: &HashSet<&Column>, + allowed_columns: &HashSet<&Column>, +) -> bool { + column_refs + .iter() + .all(|column| allowed_columns.contains(*column)) +} + +pub(super) fn syntactic_restrict_null_predicate( + predicate: &Expr, + join_cols: &HashSet<&Column>, +) -> Option { + match syntactic_null_substitution_value(predicate, join_cols) { + Some(NullSubstitutionValue::Boolean(value)) => Some(!value), + Some(NullSubstitutionValue::Null) => Some(true), + Some(NullSubstitutionValue::NonNull) | None => None, + } +} + +fn not(value: Option) -> Option { + match value { + Some(NullSubstitutionValue::Boolean(value)) => { + Some(NullSubstitutionValue::Boolean(!value)) + } + Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Null), + Some(NullSubstitutionValue::NonNull) | None => None, + } +} + +fn binary_boolean_value( + left: Option, + right: Option, + when_short_circuit: bool, +) -> Option { + let short_circuit = Some(NullSubstitutionValue::Boolean(when_short_circuit)); + let identity = Some(NullSubstitutionValue::Boolean(!when_short_circuit)); + + if left == short_circuit || right == short_circuit { + return short_circuit; + } + + match (left, right) { + (value, other) if value == identity => other, + (other, value) if value == identity => other, + (Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { + Some(NullSubstitutionValue::Null) + } + (Some(NullSubstitutionValue::NonNull), _) + | (_, Some(NullSubstitutionValue::NonNull)) + | (None, _) + | (_, None) => None, + (left, right) => { + debug_assert_eq!(left, right); + left + } + } +} + +fn null_check_value( + value: Option, + is_not_null: bool, +) -> Option { + match value { + Some(NullSubstitutionValue::Null) => { + Some(NullSubstitutionValue::Boolean(!is_not_null)) + } + Some(NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_)) => { + Some(NullSubstitutionValue::Boolean(is_not_null)) + } + None => None, 
+ } +} + +fn null_if_contains_null( + values: impl IntoIterator>, +) -> Option { + values + .into_iter() + .any(|value| matches!(value, Some(NullSubstitutionValue::Null))) + .then_some(NullSubstitutionValue::Null) +} + +fn strict_null_only( + value: Option, +) -> Option { + value.filter(|value| matches!(value, NullSubstitutionValue::Null)) +} + +fn syntactic_null_substitution_value( + expr: &Expr, + join_cols: &HashSet<&Column>, +) -> Option { + match expr { + Expr::Alias(alias) => { + syntactic_null_substitution_value(alias.expr.as_ref(), join_cols) + } + Expr::Column(column) => join_cols + .contains(column) + .then_some(NullSubstitutionValue::Null), + Expr::Literal(value, _) => Some(scalar_to_null_substitution_value(value)), + Expr::BinaryExpr(binary_expr) => syntactic_binary_value(binary_expr, join_cols), + Expr::Not(expr) => { + not(syntactic_null_substitution_value(expr.as_ref(), join_cols)) + } + Expr::IsNull(expr) => null_check_value( + syntactic_null_substitution_value(expr.as_ref(), join_cols), + false, + ), + Expr::IsNotNull(expr) => null_check_value( + syntactic_null_substitution_value(expr.as_ref(), join_cols), + true, + ), + Expr::Between(between) => null_if_contains_null([ + syntactic_null_substitution_value(between.expr.as_ref(), join_cols), + syntactic_null_substitution_value(between.low.as_ref(), join_cols), + syntactic_null_substitution_value(between.high.as_ref(), join_cols), + ]), + Expr::Cast(cast) => strict_null_only(syntactic_null_substitution_value( + cast.expr.as_ref(), + join_cols, + )), + Expr::TryCast(try_cast) => strict_null_only(syntactic_null_substitution_value( + try_cast.expr.as_ref(), + join_cols, + )), + Expr::Negative(expr) => { + strict_null_only(syntactic_null_substitution_value(expr.as_ref(), join_cols)) + } + Expr::Like(like) | Expr::SimilarTo(like) => null_if_contains_null([ + syntactic_null_substitution_value(like.expr.as_ref(), join_cols), + syntactic_null_substitution_value(like.pattern.as_ref(), join_cols), + ]), + 
Expr::Exists { .. } + | Expr::InList(_) + | Expr::InSubquery(_) + | Expr::SetComparison(_) + | Expr::ScalarSubquery(_) + | Expr::OuterReferenceColumn(_, _) + | Expr::Placeholder(_) + | Expr::ScalarVariable(_, _) + | Expr::Unnest(_) + | Expr::GroupingSet(_) + | Expr::WindowFunction(_) + | Expr::ScalarFunction(_) + | Expr::Case(_) + | Expr::IsTrue(_) + | Expr::IsFalse(_) + | Expr::IsUnknown(_) + | Expr::IsNotTrue(_) + | Expr::IsNotFalse(_) + | Expr::IsNotUnknown(_) => None, + Expr::AggregateFunction(_) => None, + #[expect(deprecated)] + Expr::Wildcard { .. } => None, + } +} + +fn scalar_to_null_substitution_value(value: &ScalarValue) -> NullSubstitutionValue { + match value { + _ if value.is_null() => NullSubstitutionValue::Null, + ScalarValue::Boolean(Some(value)) => NullSubstitutionValue::Boolean(*value), + _ => NullSubstitutionValue::NonNull, + } +} + +fn is_strict_null_binary_op(op: Operator) -> bool { + matches!( + op, + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq + | Operator::Plus + | Operator::Minus + | Operator::Multiply + | Operator::Divide + | Operator::Modulo + | Operator::RegexMatch + | Operator::RegexIMatch + | Operator::RegexNotMatch + | Operator::RegexNotIMatch + | Operator::LikeMatch + | Operator::ILikeMatch + | Operator::NotLikeMatch + | Operator::NotILikeMatch + | Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseXor + | Operator::BitwiseShiftRight + | Operator::BitwiseShiftLeft + | Operator::StringConcat + | Operator::AtArrow + | Operator::ArrowAt + | Operator::Arrow + | Operator::LongArrow + | Operator::HashArrow + | Operator::HashLongArrow + | Operator::AtAt + | Operator::IntegerDivide + | Operator::HashMinus + | Operator::AtQuestion + | Operator::Question + | Operator::QuestionAnd + | Operator::QuestionPipe + | Operator::Colon + ) +} + +fn syntactic_binary_value( + binary_expr: &BinaryExpr, + join_cols: &HashSet<&Column>, +) -> Option { + let left = 
syntactic_null_substitution_value(binary_expr.left.as_ref(), join_cols); + let right = syntactic_null_substitution_value(binary_expr.right.as_ref(), join_cols); + + match binary_expr.op { + Operator::And => binary_boolean_value(left, right, false), + Operator::Or => binary_boolean_value(left, right, true), + Operator::IsDistinctFrom | Operator::IsNotDistinctFrom => None, + op => is_strict_null_binary_op(op) + .then(|| null_if_contains_null([left, right])) + .flatten(), + } +} From 6ef1c84bf23dc9b0fddc55d84496ee32232d8d55 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 30 Mar 2026 12:03:58 +0800 Subject: [PATCH 58/63] benchmark 13 improvements From 5dd668ae128a3776715ed67035a484ae9971acf9 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 30 Mar 2026 15:50:35 +0800 Subject: [PATCH 59/63] Fix metric type casing in BaselineMetrics --- datafusion/physical-expr-common/src/metrics/baseline.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr-common/src/metrics/baseline.rs b/datafusion/physical-expr-common/src/metrics/baseline.rs index 0de8e26494931..40280b6bddbba 100644 --- a/datafusion/physical-expr-common/src/metrics/baseline.rs +++ b/datafusion/physical-expr-common/src/metrics/baseline.rs @@ -90,7 +90,7 @@ impl BaselineMetrics { .with_type(super::MetricType::SUMMARY) .output_bytes(partition), output_batches: MetricBuilder::new(metrics) - .with_type(super::MetricType::DEV) + .with_type(super::MetricType::Dev) .output_batches(partition), } } From 80750ead2085b5b54f062b8a6ee68c945715aa90 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 30 Mar 2026 17:21:37 +0800 Subject: [PATCH 60/63] Fix mixed boolean handling in null_restriction Update binary_boolean_value to return None for mixed boolean states instead of asserting. This change allows is_restrict_null_predicate to use the authoritative evaluator, preventing panics from unsupported cases. 
Add regression tests to ensure proper behavior in unsupported boolean-wrapper scenarios, maintaining consistent functionality between auto mode and authoritative mode. --- datafusion/optimizer/src/utils.rs | 298 +++++++++++++++++- .../optimizer/src/utils/null_restriction.rs | 7 +- 2 files changed, 290 insertions(+), 15 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 2520cbbb5fa61..b2ad3309cb488 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -17,10 +17,58 @@ //! Utility functions leveraged by the query optimizer rules +mod null_restriction; + use std::collections::{BTreeSet, HashMap, HashSet}; use std::sync::Arc; +#[cfg(test)] +use std::cell::Cell; + use crate::analyzer::type_coercion::TypeCoercionRewriter; + +/// Null restriction evaluation mode for optimizer tests. +#[cfg(test)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub(crate) enum NullRestrictionEvalMode { + Auto, + AuthoritativeOnly, +} + +#[cfg(test)] +thread_local! 
{ + static NULL_RESTRICTION_EVAL_MODE: Cell = + const { Cell::new(NullRestrictionEvalMode::Auto) }; +} + +#[cfg(test)] +pub(crate) fn set_null_restriction_eval_mode_for_test(mode: NullRestrictionEvalMode) { + NULL_RESTRICTION_EVAL_MODE.with(|eval_mode| eval_mode.set(mode)); +} + +#[cfg(test)] +fn null_restriction_eval_mode() -> NullRestrictionEvalMode { + NULL_RESTRICTION_EVAL_MODE.with(Cell::get) +} + +#[cfg(test)] +pub(crate) fn with_null_restriction_eval_mode_for_test( + mode: NullRestrictionEvalMode, + f: impl FnOnce() -> T, +) -> T { + struct NullRestrictionEvalModeReset(NullRestrictionEvalMode); + + impl Drop for NullRestrictionEvalModeReset { + fn drop(&mut self) { + set_null_restriction_eval_mode_for_test(self.0); + } + } + + let previous_mode = null_restriction_eval_mode(); + set_null_restriction_eval_mode_for_test(mode); + let _reset = NullRestrictionEvalModeReset(previous_mode); + f() +} use arrow::array::{Array, RecordBatch, new_null_array}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::cast::as_boolean_array; @@ -38,9 +86,13 @@ pub use datafusion_expr::expr_rewriter::NamePreserver; /// Returns true if `expr` contains all columns in `schema_cols` pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet) -> bool { - expr.column_refs() + let column_refs = expr.column_refs(); + // note can't use HashSet::intersect because of different types (owned vs References) + schema_cols .iter() - .all(|column| schema_cols.contains(*column)) + .filter(|c| column_refs.contains(c)) + .count() + == column_refs.len() } pub(crate) fn replace_qualified_name( @@ -78,19 +130,42 @@ pub fn is_restrict_null_predicate<'a>( // Collect join columns so they can be used in both the fast-path check and the // fallback evaluation path below. 
let join_cols: HashSet<&Column> = join_cols_of_predicate.into_iter().collect(); + let column_refs = predicate.column_refs(); + // Fast path: if the predicate references columns outside the join key set, // `evaluate_expr_with_null_column` would fail because the null schema only // contains a placeholder for the join key columns. Callers treat such errors as // non-restricting (false) via `matches!(_, Ok(true))`, so we return false early // and avoid the expensive physical-expression compilation pipeline entirely. - if !predicate - .column_refs() - .iter() - .all(|column| join_cols.contains(*column)) - { + if !null_restriction::all_columns_allowed(&column_refs, &join_cols) { return Ok(false); } + #[cfg(test)] + if matches!( + null_restriction_eval_mode(), + NullRestrictionEvalMode::AuthoritativeOnly + ) { + return authoritative_restrict_null_predicate(predicate, join_cols); + } + + if let Some(is_restricting) = + null_restriction::syntactic_restrict_null_predicate(&predicate, &join_cols) + { + #[cfg(debug_assertions)] + { + let authoritative = authoritative_restrict_null_predicate( + predicate.clone(), + join_cols.iter().copied(), + )?; + debug_assert_eq!( + is_restricting, authoritative, + "syntactic fast path disagrees with authoritative null-restriction evaluation for predicate: {predicate}" + ); + } + return Ok(is_restricting); + } + authoritative_restrict_null_predicate(predicate, join_cols) } @@ -147,11 +222,14 @@ fn authoritative_restrict_null_predicate<'a>( ) -> Result { Ok( match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? 
{ - ColumnarValue::Array(array) if array.len() == 1 => { - let boolean_array = as_boolean_array(&array)?; - boolean_array.is_null(0) || !boolean_array.value(0) + ColumnarValue::Array(array) => { + if array.len() == 1 { + let boolean_array = as_boolean_array(&array)?; + boolean_array.is_null(0) || !boolean_array.value(0) + } else { + false + } } - ColumnarValue::Array(_) => false, ColumnarValue::Scalar(scalar) => matches!( scalar, ScalarValue::Boolean(None) @@ -170,6 +248,8 @@ fn coerce(expr: Expr, schema: &DFSchema) -> Result { #[cfg(test)] mod tests { use super::*; + use std::panic::{AssertUnwindSafe, catch_unwind}; + use datafusion_expr::{ Operator, binary_expr, case, col, in_list, is_null, lit, when, }; @@ -316,4 +396,200 @@ mod tests { Ok(()) } + + #[test] + fn syntactic_fast_path_matches_authoritative_evaluator() -> Result<()> { + let test_cases = vec![ + is_null(col("a")), + Expr::IsNotNull(Box::new(col("a"))), + binary_expr(col("a"), Operator::Gt, lit(8i64)), + binary_expr(col("a"), Operator::Eq, lit(ScalarValue::Null)), + binary_expr(col("a"), Operator::And, lit(true)), + binary_expr(col("a"), Operator::Or, lit(false)), + Expr::Not(Box::new(col("a").is_true())), + col("a").is_true(), + col("a").is_false(), + col("a").is_unknown(), + col("a").is_not_true(), + col("a").is_not_false(), + col("a").is_not_unknown(), + col("a").between(lit(1i64), lit(10i64)), + binary_expr( + when(Expr::IsNotNull(Box::new(col("a"))), col("a")) + .otherwise(col("b"))?, + Operator::Gt, + lit(2i64), + ), + case(col("a")) + .when(lit(1i64), lit(true)) + .otherwise(lit(false))?, + case(col("a")) + .when(lit(0i64), lit(false)) + .otherwise(lit(true))?, + binary_expr( + case(col("a")) + .when(lit(0i64), lit(true)) + .otherwise(lit(false))?, + Operator::Or, + lit(false), + ), + binary_expr( + case(lit(1i64)) + .when(lit(1i64), lit(ScalarValue::Null)) + .otherwise(lit(false))?, + Operator::IsNotDistinctFrom, + lit(true), + ), + binary_expr(col("a").is_true(), Operator::And, lit(true)), + 
binary_expr(col("a").is_false(), Operator::Or, lit(false)), + binary_expr(col("a").is_unknown(), Operator::And, is_null(col("a"))), + binary_expr( + Expr::Not(Box::new(col("a").is_not_unknown())), + Operator::Or, + Expr::IsNotNull(Box::new(col("a"))), + ), + ]; + + for predicate in test_cases { + let join_cols = predicate.column_refs(); + if let Some(syntactic) = null_restriction::syntactic_restrict_null_predicate( + &predicate, &join_cols, + ) { + let authoritative = authoritative_restrict_null_predicate( + predicate.clone(), + join_cols.iter().copied(), + ) + .unwrap_or_else(|error| { + panic!( + "authoritative evaluator failed for predicate `{predicate}`: {error}" + ) + }); + assert_eq!( + syntactic, authoritative, + "syntactic fast path disagrees with authoritative evaluator for predicate: {predicate}", + ); + } + } + + Ok(()) + } + + #[test] + fn unsupported_boolean_wrappers_defer_to_authoritative_evaluator() -> Result<()> { + let predicates = vec![ + binary_expr(col("a").is_true(), Operator::And, lit(true)), + binary_expr(col("a").is_false(), Operator::Or, lit(false)), + binary_expr(col("a").is_unknown(), Operator::And, is_null(col("a"))), + binary_expr( + Expr::Not(Box::new(col("a").is_not_unknown())), + Operator::Or, + Expr::IsNotNull(Box::new(col("a"))), + ), + ]; + + for predicate in predicates { + let join_cols = predicate.column_refs(); + assert!( + null_restriction::syntactic_restrict_null_predicate( + &predicate, &join_cols + ) + .is_none(), + "syntactic fast path should defer for predicate: {predicate}", + ); + + let auto_result = with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::Auto, + || { + is_restrict_null_predicate( + predicate.clone(), + join_cols.iter().copied(), + ) + }, + )?; + + let authoritative_result = with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::AuthoritativeOnly, + || { + is_restrict_null_predicate( + predicate.clone(), + join_cols.iter().copied(), + ) + }, + )?; + + assert_eq!( + 
auto_result, authoritative_result, + "auto mode should defer to authoritative evaluation for predicate: {predicate}", + ); + } + + Ok(()) + } + + #[test] + fn null_restriction_eval_mode_auto_vs_authoritative_only() -> Result<()> { + let predicate = binary_expr(col("a"), Operator::Gt, lit(8i64)); + let join_cols_of_predicate = predicate.column_refs(); + + let auto_result = with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::Auto, + || { + is_restrict_null_predicate( + predicate.clone(), + join_cols_of_predicate.iter().copied(), + ) + }, + )?; + + let authoritative_result = with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::AuthoritativeOnly, + || { + is_restrict_null_predicate( + predicate.clone(), + join_cols_of_predicate.iter().copied(), + ) + }, + )?; + + assert_eq!(auto_result, authoritative_result); + + Ok(()) + } + + #[test] + fn mixed_reference_predicate_remains_fast_pathed_in_authoritative_mode() -> Result<()> + { + let predicate = binary_expr(col("a"), Operator::Gt, col("b")); + let column_a = Column::from_name("a"); + + let auto_result = with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::Auto, + || is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a)), + )?; + + let authoritative_only_result = with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::AuthoritativeOnly, + || is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a)), + )?; + + assert!(!auto_result, "{predicate}"); + assert!(!authoritative_only_result, "{predicate}"); + + Ok(()) + } + + #[test] + fn null_restriction_eval_mode_guard_restores_on_panic() { + set_null_restriction_eval_mode_for_test(NullRestrictionEvalMode::Auto); + + let result = catch_unwind(AssertUnwindSafe(|| { + with_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::AuthoritativeOnly, + || panic!("intentional panic to verify test mode reset"), + ) + })); + + assert!(result.is_err()); + 
assert_eq!(null_restriction_eval_mode(), NullRestrictionEvalMode::Auto); + } } diff --git a/datafusion/optimizer/src/utils/null_restriction.rs b/datafusion/optimizer/src/utils/null_restriction.rs index 6e9920af80acc..39100c66686f7 100644 --- a/datafusion/optimizer/src/utils/null_restriction.rs +++ b/datafusion/optimizer/src/utils/null_restriction.rs @@ -84,10 +84,9 @@ fn binary_boolean_value( | (_, Some(NullSubstitutionValue::NonNull)) | (None, _) | (_, None) => None, - (left, right) => { - debug_assert_eq!(left, right); - left - } + // Any remaining mixed state is outside the reduced lattice this syntactic + // evaluator can model soundly. Defer to the authoritative evaluator. + _ => None, } } From f8204b6bcb74c7b7b7f6a20f6b491159300bd3ea Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 30 Mar 2026 17:35:54 +0800 Subject: [PATCH 61/63] Fix metric type casing in BaselineMetrics --- datafusion/physical-expr-common/src/metrics/baseline.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr-common/src/metrics/baseline.rs b/datafusion/physical-expr-common/src/metrics/baseline.rs index 40280b6bddbba..0de8e26494931 100644 --- a/datafusion/physical-expr-common/src/metrics/baseline.rs +++ b/datafusion/physical-expr-common/src/metrics/baseline.rs @@ -90,7 +90,7 @@ impl BaselineMetrics { .with_type(super::MetricType::SUMMARY) .output_bytes(partition), output_batches: MetricBuilder::new(metrics) - .with_type(super::MetricType::Dev) + .with_type(super::MetricType::DEV) .output_batches(partition), } } From bf36facddc8cc380356635818f06173d0cf58600 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Tue, 31 Mar 2026 11:41:45 +0800 Subject: [PATCH 62/63] formatting --- datafusion/core/benches/sql_planner_extended.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/datafusion/core/benches/sql_planner_extended.rs b/datafusion/core/benches/sql_planner_extended.rs index 7db50b19bf566..767134bb5bafd 100644 --- 
a/datafusion/core/benches/sql_planner_extended.rs +++ b/datafusion/core/benches/sql_planner_extended.rs @@ -357,12 +357,8 @@ fn build_non_case_left_join_df_with_push_down_filter( fn find_filter_predicates(plan: &LogicalPlan) -> Vec { match plan { LogicalPlan::Filter(filter) => split_conjunction_owned(filter.predicate.clone()), - LogicalPlan::Projection(projection) => { - find_filter_predicates(projection.input.as_ref()) - } - other => { - panic!("expected benchmark query plan to contain a Filter, found {other:?}") - } + LogicalPlan::Projection(projection) => find_filter_predicates(projection.input.as_ref()), + other => panic!("expected benchmark query plan to contain a Filter, found {other:?}"), } } @@ -379,8 +375,7 @@ fn assert_case_heavy_left_join_inference_candidates( for predicate in predicates { let column_refs = predicate.column_refs(); assert!( - column_refs.contains(&&left_join_key) - || column_refs.contains(&&right_join_key), + column_refs.contains(&&left_join_key) || column_refs.contains(&&right_join_key), "benchmark predicate should reference a join key: {predicate}" ); assert!( From 1f2b598c56787ab0bf97bb26c90993a4f5575529 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Tue, 31 Mar 2026 11:43:47 +0800 Subject: [PATCH 63/63] refactor: improve join condition evaluation and filter management - Refactored `classify_join_input` to `strip_plan_wrappers` for clearer handling of logical plan wrappers. - Introduced helper functions `is_scalar_aggregate_subquery` and `is_derived_relation` to check for specific logical plan structures. - Enhanced `is_scalar_subquery_cross_join` to streamline the evaluation logic of joins with scalar subqueries. - Added `should_keep_filter_above_scalar_subquery_cross_join` to manage filter preservation based on join conditions. - Adjusted the predicate handling in `push_down_all_join` to utilize the new helper functions, improving filter application logic. 
- Updated tests for window operations over scalar subquery cross joins to maintain correct behavior and refactored test setup for clarity. --- datafusion/optimizer/src/push_down_filter.rs | 206 ++++++++++-------- datafusion/optimizer/src/utils.rs | 60 ----- .../optimizer/src/utils/null_restriction.rs | 7 +- 3 files changed, 125 insertions(+), 148 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index a4adc18c3ec27..622780cad5085 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -285,30 +285,53 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Ok(is_evaluate) } -fn classify_join_input(plan: &LogicalPlan) -> (bool, bool) { +fn strip_plan_wrappers(plan: &LogicalPlan) -> (&LogicalPlan, bool) { match plan { LogicalPlan::SubqueryAlias(subquery_alias) => { - let (is_scalar_aggregate, _) = - classify_join_input(subquery_alias.input.as_ref()); - (is_scalar_aggregate, true) + let (plan, _) = strip_plan_wrappers(subquery_alias.input.as_ref()); + (plan, true) } LogicalPlan::Projection(projection) => { - classify_join_input(projection.input.as_ref()) + let (plan, is_derived_relation) = + strip_plan_wrappers(projection.input.as_ref()); + (plan, is_derived_relation) } - LogicalPlan::Aggregate(aggregate) => (aggregate.group_expr.is_empty(), false), - _ => (false, false), + _ => (plan, false), } } +fn is_scalar_aggregate_subquery(plan: &LogicalPlan) -> bool { + matches!( + strip_plan_wrappers(plan).0, + LogicalPlan::Aggregate(aggregate) if aggregate.group_expr.is_empty() + ) +} + +fn is_derived_relation(plan: &LogicalPlan) -> bool { + strip_plan_wrappers(plan).1 +} + fn is_scalar_subquery_cross_join(join: &Join) -> bool { - let (left_scalar_aggregate, left_is_derived_relation) = - classify_join_input(join.left.as_ref()); - let (right_scalar_aggregate, right_is_derived_relation) = - classify_join_input(join.right.as_ref()); 
join.on.is_empty() && join.filter.is_none() - && ((left_scalar_aggregate && right_is_derived_relation) - || (right_scalar_aggregate && left_is_derived_relation)) + && ((is_scalar_aggregate_subquery(join.left.as_ref()) + && is_derived_relation(join.right.as_ref())) + || (is_scalar_aggregate_subquery(join.right.as_ref()) + && is_derived_relation(join.left.as_ref()))) +} + +// Keep post-join filters above certain scalar-subquery cross joins to preserve +// behavior for the window-over-scalar-subquery regression shape. +fn should_keep_filter_above_scalar_subquery_cross_join( + join: &Join, + predicate: &Expr, +) -> bool { + if !is_scalar_subquery_cross_join(join) { + return false; + } + + let mut checker = ColumnChecker::new(join.left.schema(), join.right.schema()); + !checker.is_left_only(predicate) && !checker.is_right_only(predicate) } /// examine OR clause to see if any useful clauses can be extracted and push down. @@ -452,15 +475,13 @@ fn push_down_all_join( let mut keep_predicates = vec![]; let mut join_conditions = vec![]; let mut checker = ColumnChecker::new(left_schema, right_schema); - let keep_mixed_scalar_subquery_filters = - is_inner_join && is_scalar_subquery_cross_join(&join); for predicate in predicates { if left_preserved && checker.is_left_only(&predicate) { left_push.push(predicate); } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); } else if is_inner_join - && !keep_mixed_scalar_subquery_filters + && !should_keep_filter_above_scalar_subquery_cross_join(&join, &predicate) && can_evaluate_as_join_condition(&predicate)? 
{ // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate @@ -754,33 +775,36 @@ fn infer_join_predicates_impl< inferred_predicates: &mut InferredPredicates, ) -> Result<()> { for predicate in input_predicates { - let mut join_cols_to_replace = HashMap::new(); - let mut saw_non_replaceable_ref = false; - - for &col in &predicate.column_refs() { - let replacement = join_col_keys.iter().find_map(|(l, r)| { - if ENABLE_LEFT_TO_RIGHT && col == *l { - Some((col, *r)) - } else if ENABLE_RIGHT_TO_LEFT && col == *r { - Some((col, *l)) - } else { - None - } - }); + let column_refs = predicate.column_refs(); + let join_col_replacements: Vec<_> = column_refs + .iter() + .filter_map(|&col| { + join_col_keys.iter().find_map(|(l, r)| { + if ENABLE_LEFT_TO_RIGHT && col == *l { + Some((col, *r)) + } else if ENABLE_RIGHT_TO_LEFT && col == *r { + Some((col, *l)) + } else { + None + } + }) + }) + .collect(); - if let Some((source, target)) = replacement { - join_cols_to_replace.insert(source, target); - } else { - saw_non_replaceable_ref = true; - } + if join_col_replacements.is_empty() { + continue; } - if join_cols_to_replace.is_empty() - || (!inferred_predicates.is_inner_join && saw_non_replaceable_ref) + // For non-inner joins, predicates that reference any non-replaceable + // columns cannot be inferred on the other side. Skip the null-restriction + // helper entirely in that common mixed-reference case. + if !inferred_predicates.is_inner_join + && join_col_replacements.len() != column_refs.len() { continue; } + let join_cols_to_replace = join_col_replacements.into_iter().collect(); inferred_predicates .try_build_predicate(predicate.clone(), &join_cols_to_replace)?; } @@ -1529,53 +1553,6 @@ mod tests { use super::*; - fn scalar_subquery_right_plan() -> Result { - LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) - .project(vec![col("a").alias("acctbal")])? 
- .aggregate( - Vec::::new(), - vec![avg(col("acctbal")).alias("avg_acctbal")], - )? - .alias("__scalar_sq_1")? - .build() - } - - fn row_number_window_expr() -> Expr { - Expr::from(WindowFunction::new( - WindowFunctionDefinition::WindowUDF( - datafusion_functions_window::row_number::row_number_udwf(), - ), - vec![], - )) - .partition_by(vec![col("s.nation")]) - .order_by(vec![col("s.acctbal").sort(false, true)]) - .build() - .unwrap() - } - - fn window_over_scalar_subquery_cross_join_plan( - with_project_wrapper: bool, - ) -> Result { - let left = { - let builder = LogicalPlanBuilder::from(test_table_scan()?) - .project(vec![col("a").alias("nation"), col("b").alias("acctbal")])? - .alias("s")?; - let builder = if with_project_wrapper { - builder.project(vec![col("s.nation"), col("s.acctbal")])? - } else { - builder - }; - builder.build()? - }; - - LogicalPlanBuilder::from(left) - .cross_join(scalar_subquery_right_plan()?)? - .filter(col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")))? - .project(vec![col("s.nation"), col("s.acctbal")])? - .window(vec![row_number_window_expr()])? - .build() - } - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} macro_rules! assert_optimized_plan_equal { @@ -2497,7 +2474,36 @@ mod tests { #[test] fn window_over_scalar_subquery_cross_join_keeps_filter_above_join() -> Result<()> { - let plan = window_over_scalar_subquery_cross_join_plan(false)?; + let left = LogicalPlanBuilder::from(test_table_scan()?) + .project(vec![col("a").alias("nation"), col("b").alias("acctbal")])? + .alias("s")? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) + .project(vec![col("a").alias("acctbal")])? + .aggregate( + Vec::::new(), + vec![avg(col("acctbal")).alias("avg_acctbal")], + )? + .alias("__scalar_sq_1")? 
+ .build()?; + + let window = Expr::from(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::row_number::row_number_udwf(), + ), + vec![], + )) + .partition_by(vec![col("s.nation")]) + .order_by(vec![col("s.acctbal").sort(false, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(left) + .cross_join(right)? + .filter(col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")))? + .project(vec![col("s.nation"), col("s.acctbal")])? + .window(vec![window])? + .build()?; assert_optimized_plan_equal!( plan, @@ -2520,7 +2526,37 @@ mod tests { #[test] fn window_over_scalar_subquery_cross_join_with_project_wrapper_keeps_filter_above_join() -> Result<()> { - let plan = window_over_scalar_subquery_cross_join_plan(true)?; + let left = LogicalPlanBuilder::from(test_table_scan()?) + .project(vec![col("a").alias("nation"), col("b").alias("acctbal")])? + .alias("s")? + .project(vec![col("s.nation"), col("s.acctbal")])? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) + .project(vec![col("a").alias("acctbal")])? + .aggregate( + Vec::::new(), + vec![avg(col("acctbal")).alias("avg_acctbal")], + )? + .alias("__scalar_sq_1")? + .build()?; + + let window = Expr::from(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::row_number::row_number_udwf(), + ), + vec![], + )) + .partition_by(vec![col("s.nation")]) + .order_by(vec![col("s.acctbal").sort(false, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(left) + .cross_join(right)? + .filter(col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")))? + .project(vec![col("s.nation"), col("s.acctbal")])? + .window(vec![window])? 
+ .build()?; assert_optimized_plan_equal!( plan, diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index b2ad3309cb488..329271a067ee8 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -440,14 +440,6 @@ mod tests { Operator::IsNotDistinctFrom, lit(true), ), - binary_expr(col("a").is_true(), Operator::And, lit(true)), - binary_expr(col("a").is_false(), Operator::Or, lit(false)), - binary_expr(col("a").is_unknown(), Operator::And, is_null(col("a"))), - binary_expr( - Expr::Not(Box::new(col("a").is_not_unknown())), - Operator::Or, - Expr::IsNotNull(Box::new(col("a"))), - ), ]; for predicate in test_cases { @@ -474,58 +466,6 @@ mod tests { Ok(()) } - #[test] - fn unsupported_boolean_wrappers_defer_to_authoritative_evaluator() -> Result<()> { - let predicates = vec![ - binary_expr(col("a").is_true(), Operator::And, lit(true)), - binary_expr(col("a").is_false(), Operator::Or, lit(false)), - binary_expr(col("a").is_unknown(), Operator::And, is_null(col("a"))), - binary_expr( - Expr::Not(Box::new(col("a").is_not_unknown())), - Operator::Or, - Expr::IsNotNull(Box::new(col("a"))), - ), - ]; - - for predicate in predicates { - let join_cols = predicate.column_refs(); - assert!( - null_restriction::syntactic_restrict_null_predicate( - &predicate, &join_cols - ) - .is_none(), - "syntactic fast path should defer for predicate: {predicate}", - ); - - let auto_result = with_null_restriction_eval_mode_for_test( - NullRestrictionEvalMode::Auto, - || { - is_restrict_null_predicate( - predicate.clone(), - join_cols.iter().copied(), - ) - }, - )?; - - let authoritative_result = with_null_restriction_eval_mode_for_test( - NullRestrictionEvalMode::AuthoritativeOnly, - || { - is_restrict_null_predicate( - predicate.clone(), - join_cols.iter().copied(), - ) - }, - )?; - - assert_eq!( - auto_result, authoritative_result, - "auto mode should defer to authoritative evaluation for predicate: {predicate}", - ); - } - - 
Ok(()) - } - #[test] fn null_restriction_eval_mode_auto_vs_authoritative_only() -> Result<()> { let predicate = binary_expr(col("a"), Operator::Gt, lit(8i64)); diff --git a/datafusion/optimizer/src/utils/null_restriction.rs b/datafusion/optimizer/src/utils/null_restriction.rs index 39100c66686f7..6e9920af80acc 100644 --- a/datafusion/optimizer/src/utils/null_restriction.rs +++ b/datafusion/optimizer/src/utils/null_restriction.rs @@ -84,9 +84,10 @@ fn binary_boolean_value( | (_, Some(NullSubstitutionValue::NonNull)) | (None, _) | (_, None) => None, - // Any remaining mixed state is outside the reduced lattice this syntactic - // evaluator can model soundly. Defer to the authoritative evaluator. - _ => None, + (left, right) => { + debug_assert_eq!(left, right); + left + } } }