Skip to content

Commit ff034f2

Browse files
committed
Refactor utility layers and simplify null handling
Tighten scalar-subquery guard, remove redundant empty-loop scaffolding, and share more of the null-evaluation flow. Simplify small null-restriction helpers and match arms.
1 parent 9b48073 commit ff034f2

File tree

3 files changed

+112
-80
lines changed

3 files changed

+112
-80
lines changed

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 64 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -286,21 +286,26 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
286286
}
287287

288288
fn is_scalar_subquery_cross_join(join: &Join) -> bool {
289-
fn classify(plan: &LogicalPlan) -> (bool, bool) {
289+
fn is_scalar_aggregate_or_derived_relation(plan: &LogicalPlan) -> (bool, bool) {
290290
match plan {
291291
LogicalPlan::SubqueryAlias(subquery_alias) => {
292-
let (is_scalar_aggregate, _) = classify(subquery_alias.input.as_ref());
292+
let (is_scalar_aggregate, _) = is_scalar_aggregate_or_derived_relation(
293+
subquery_alias.input.as_ref(),
294+
);
293295
(is_scalar_aggregate, true)
294296
}
295-
LogicalPlan::Projection(projection) => classify(projection.input.as_ref()),
297+
LogicalPlan::Projection(projection) => {
298+
is_scalar_aggregate_or_derived_relation(projection.input.as_ref())
299+
}
296300
LogicalPlan::Aggregate(aggregate) => (aggregate.group_expr.is_empty(), false),
297301
_ => (false, false),
298302
}
299303
}
300304

301-
let (left_scalar_aggregate, left_is_derived_relation) = classify(join.left.as_ref());
305+
let (left_scalar_aggregate, left_is_derived_relation) =
306+
is_scalar_aggregate_or_derived_relation(join.left.as_ref());
302307
let (right_scalar_aggregate, right_is_derived_relation) =
303-
classify(join.right.as_ref());
308+
is_scalar_aggregate_or_derived_relation(join.right.as_ref());
304309
join.on.is_empty()
305310
&& join.filter.is_none()
306311
&& ((left_scalar_aggregate && right_is_derived_relation)
@@ -451,9 +456,12 @@ fn push_down_all_join(
451456
let keep_mixed_scalar_subquery_filters =
452457
is_inner_join && is_scalar_subquery_cross_join(&join);
453458
for predicate in predicates {
454-
if left_preserved && checker.is_left_only(&predicate) {
459+
let left_only = left_preserved && checker.is_left_only(&predicate);
460+
let right_only =
461+
!left_only && right_preserved && checker.is_right_only(&predicate);
462+
if left_only {
455463
left_push.push(predicate);
456-
} else if right_preserved && checker.is_right_only(&predicate) {
464+
} else if right_only {
457465
right_push.push(predicate);
458466
} else if is_inner_join
459467
&& !keep_mixed_scalar_subquery_filters
@@ -479,43 +487,63 @@ fn push_down_all_join(
479487
let mut on_filter_join_conditions = vec![];
480488
let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
481489

482-
if !on_filter.is_empty() {
483-
for on in on_filter {
484-
if on_left_preserved && checker.is_left_only(&on) {
485-
left_push.push(on)
486-
} else if on_right_preserved && checker.is_right_only(&on) {
487-
right_push.push(on)
488-
} else {
489-
on_filter_join_conditions.push(on)
490-
}
490+
for on in on_filter {
491+
if on_left_preserved && checker.is_left_only(&on) {
492+
left_push.push(on)
493+
} else if on_right_preserved && checker.is_right_only(&on) {
494+
right_push.push(on)
495+
} else {
496+
on_filter_join_conditions.push(on)
491497
}
492498
}
493499

494500
// Extract from OR clauses and generate new predicates for both sides of the join where possible.
495501
// We only track the unpushable predicates above.
496-
if left_preserved {
497-
left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
498-
left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
499-
}
500-
if right_preserved {
501-
right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
502-
right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
503-
}
502+
let extend_or_clauses =
503+
|target: &mut Vec<Expr>, filters: &[Expr], schema: &DFSchema, preserved| {
504+
if preserved {
505+
target.extend(extract_or_clauses_for_join(filters, schema));
506+
}
507+
};
508+
extend_or_clauses(
509+
&mut left_push,
510+
&keep_predicates,
511+
left_schema,
512+
left_preserved,
513+
);
514+
extend_or_clauses(
515+
&mut left_push,
516+
&join_conditions,
517+
left_schema,
518+
left_preserved,
519+
);
520+
extend_or_clauses(
521+
&mut right_push,
522+
&keep_predicates,
523+
right_schema,
524+
right_preserved,
525+
);
526+
extend_or_clauses(
527+
&mut right_push,
528+
&join_conditions,
529+
right_schema,
530+
right_preserved,
531+
);
504532

505533
// For predicates from the join filter, we should check whether a join side is preserved
506534
// in terms of join filtering.
507-
if on_left_preserved {
508-
left_push.extend(extract_or_clauses_for_join(
509-
&on_filter_join_conditions,
510-
left_schema,
511-
));
512-
}
513-
if on_right_preserved {
514-
right_push.extend(extract_or_clauses_for_join(
515-
&on_filter_join_conditions,
516-
right_schema,
517-
));
518-
}
535+
extend_or_clauses(
536+
&mut left_push,
537+
&on_filter_join_conditions,
538+
left_schema,
539+
on_left_preserved,
540+
);
541+
extend_or_clauses(
542+
&mut right_push,
543+
&on_filter_join_conditions,
544+
right_schema,
545+
on_right_preserved,
546+
);
519547

520548
if let Some(predicate) = conjunction(left_push) {
521549
join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));

datafusion/optimizer/src/utils.rs

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,6 @@ use log::{debug, trace};
3838
/// as it was initially placed here and then moved elsewhere.
3939
pub use datafusion_expr::expr_rewriter::NamePreserver;
4040

41-
#[cfg(test)]
42-
use self::test_eval_mode::*;
43-
4441
/// Returns true if `expr` contains all columns in `schema_cols`
4542
pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet<Column>) -> bool {
4643
expr.column_refs()
@@ -76,7 +73,7 @@ pub fn is_restrict_null_predicate<'a>(
7673
predicate: Expr,
7774
join_cols_of_predicate: impl IntoIterator<Item = &'a Column>,
7875
) -> Result<bool> {
79-
if matches!(predicate, Expr::Column(_)) {
76+
if is_column_expr(&predicate) {
8077
return Ok(true);
8178
}
8279

@@ -96,8 +93,8 @@ pub fn is_restrict_null_predicate<'a>(
9693

9794
#[cfg(test)]
9895
if matches!(
99-
null_restriction_eval_mode(),
100-
NullRestrictionEvalMode::AuthoritativeOnly
96+
test_eval_mode::null_restriction_eval_mode(),
97+
test_eval_mode::NullRestrictionEvalMode::AuthoritativeOnly
10198
) {
10299
return authoritative_restrict_null_predicate(predicate, join_cols);
103100
}
@@ -130,16 +127,16 @@ pub fn evaluates_to_null<'a>(
130127
predicate: Expr,
131128
null_columns: impl IntoIterator<Item = &'a Column>,
132129
) -> Result<bool> {
133-
if matches!(predicate, Expr::Column(_)) {
130+
if is_column_expr(&predicate) {
134131
return Ok(true);
135132
}
136133

137-
Ok(
138-
match evaluate_expr_with_null_column(predicate, null_columns)? {
134+
evaluate_with_null_columns(predicate, null_columns, |result| {
135+
Ok(match result {
139136
ColumnarValue::Array(_) => false,
140137
ColumnarValue::Scalar(scalar) => scalar.is_null(),
141-
},
142-
)
138+
})
139+
})
143140
}
144141

145142
fn evaluate_expr_with_null_column<'a>(
@@ -173,8 +170,8 @@ fn authoritative_restrict_null_predicate<'a>(
173170
predicate: Expr,
174171
join_cols_of_predicate: impl IntoIterator<Item = &'a Column>,
175172
) -> Result<bool> {
176-
Ok(
177-
match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? {
173+
evaluate_with_null_columns(predicate, join_cols_of_predicate, |result| {
174+
Ok(match result {
178175
ColumnarValue::Array(array) if array.len() == 1 => {
179176
let boolean_array = as_boolean_array(&array)?;
180177
boolean_array.is_null(0) || !boolean_array.value(0)
@@ -186,15 +183,27 @@ fn authoritative_restrict_null_predicate<'a>(
186183
| ScalarValue::Boolean(Some(false))
187184
| ScalarValue::Null
188185
),
189-
},
190-
)
186+
})
187+
})
191188
}
192189

193190
fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
194191
let mut expr_rewrite = TypeCoercionRewriter { schema };
195192
expr.rewrite(&mut expr_rewrite).data()
196193
}
197194

195+
fn is_column_expr(expr: &Expr) -> bool {
196+
matches!(expr, Expr::Column(_))
197+
}
198+
199+
fn evaluate_with_null_columns<'a, T>(
200+
predicate: Expr,
201+
null_columns: impl IntoIterator<Item = &'a Column>,
202+
f: impl FnOnce(ColumnarValue) -> Result<T>,
203+
) -> Result<T> {
204+
f(evaluate_expr_with_null_column(predicate, null_columns)?)
205+
}
206+
198207
#[cfg(test)]
199208
mod test_eval_mode {
200209
use std::cell::Cell;
@@ -243,6 +252,11 @@ mod tests {
243252
use super::*;
244253
use std::panic::{AssertUnwindSafe, catch_unwind};
245254

255+
use crate::utils::test_eval_mode::{
256+
NullRestrictionEvalMode, null_restriction_eval_mode,
257+
set_null_restriction_eval_mode_for_test,
258+
with_null_restriction_eval_mode_for_test,
259+
};
246260
use datafusion_expr::{
247261
Operator, binary_expr, case, col, in_list, is_null, lit, when,
248262
};

datafusion/optimizer/src/utils/null_restriction.rs

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,11 @@ pub(super) fn syntactic_restrict_null_predicate(
3636
predicate: &Expr,
3737
join_cols: &HashSet<&Column>,
3838
) -> Option<bool> {
39-
syntactic_null_substitution_value(predicate, join_cols).and_then(
40-
|value| match value {
41-
NullSubstitutionValue::Boolean(value) => Some(!value),
42-
NullSubstitutionValue::Null => Some(true),
43-
NullSubstitutionValue::NonNull => None,
44-
},
45-
)
39+
match syntactic_null_substitution_value(predicate, join_cols) {
40+
Some(NullSubstitutionValue::Boolean(value)) => Some(!value),
41+
Some(NullSubstitutionValue::Null) => Some(true),
42+
_ => None,
43+
}
4644
}
4745

4846
fn not(value: Option<NullSubstitutionValue>) -> Option<NullSubstitutionValue> {
@@ -88,13 +86,10 @@ fn null_check_value(
8886
value: Option<NullSubstitutionValue>,
8987
is_not_null: bool,
9088
) -> Option<NullSubstitutionValue> {
91-
match value {
92-
Some(NullSubstitutionValue::Null) => {
93-
Some(NullSubstitutionValue::Boolean(!is_not_null))
94-
}
95-
Some(_) => Some(NullSubstitutionValue::Boolean(is_not_null)),
96-
None => None,
97-
}
89+
value.map(|value| match value {
90+
NullSubstitutionValue::Null => NullSubstitutionValue::Boolean(!is_not_null),
91+
_ => NullSubstitutionValue::Boolean(is_not_null),
92+
})
9893
}
9994

10095
fn null_if_contains_null(
@@ -106,12 +101,6 @@ fn null_if_contains_null(
106101
.then_some(NullSubstitutionValue::Null)
107102
}
108103

109-
fn strict_null_only(
110-
value: Option<NullSubstitutionValue>,
111-
) -> Option<NullSubstitutionValue> {
112-
value.filter(|value| matches!(value, NullSubstitutionValue::Null))
113-
}
114-
115104
fn syntactic_null_substitution_value(
116105
expr: &Expr,
117106
join_cols: &HashSet<&Column>,
@@ -150,16 +139,17 @@ fn syntactic_null_substitution_value(
150139
syntactic_null_substitution_value(between.low.as_ref(), join_cols),
151140
syntactic_null_substitution_value(between.high.as_ref(), join_cols),
152141
]),
153-
Expr::Cast(cast) => strict_null_only(syntactic_null_substitution_value(
154-
cast.expr.as_ref(),
155-
join_cols,
156-
)),
157-
Expr::TryCast(try_cast) => strict_null_only(syntactic_null_substitution_value(
158-
try_cast.expr.as_ref(),
159-
join_cols,
160-
)),
142+
Expr::Cast(cast) => {
143+
syntactic_null_substitution_value(cast.expr.as_ref(), join_cols)
144+
.filter(|value| matches!(value, NullSubstitutionValue::Null))
145+
}
146+
Expr::TryCast(try_cast) => {
147+
syntactic_null_substitution_value(try_cast.expr.as_ref(), join_cols)
148+
.filter(|value| matches!(value, NullSubstitutionValue::Null))
149+
}
161150
Expr::Negative(expr) => {
162-
strict_null_only(syntactic_null_substitution_value(expr.as_ref(), join_cols))
151+
syntactic_null_substitution_value(expr.as_ref(), join_cols)
152+
.filter(|value| matches!(value, NullSubstitutionValue::Null))
163153
}
164154
Expr::Like(like) | Expr::SimilarTo(like) => null_if_contains_null([
165155
syntactic_null_substitution_value(like.expr.as_ref(), join_cols),

0 commit comments

Comments
 (0)