Skip to content

Commit 787d0a4

Browse files
committed
Refactor push_down_filter for simplicity and efficiency
Eliminate JoinInputShape/PredicateDestination layers and integrate scalar-subquery handling into the main join-pushdown flow. Revert join-inference replacement logic to an explicit loop. Remove unnecessary helper functions and simplify the null-restriction evaluator for cleaner code.
1 parent af77fe4 commit 787d0a4

File tree

3 files changed

+76
-157
lines changed

3 files changed

+76
-157
lines changed

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 54 additions & 119 deletions
Original file line number | Diff line number | Diff line change
@@ -285,81 +285,30 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
285285
Ok(is_evaluate)
286286
}
287287

288-
#[derive(Clone, Copy)]
289-
struct JoinInputShape<'a> {
290-
base_plan: &'a LogicalPlan,
291-
is_derived_relation: bool,
292-
}
293-
294-
fn classify_join_input(plan: &LogicalPlan) -> JoinInputShape<'_> {
288+
fn classify_join_input(plan: &LogicalPlan) -> (bool, bool) {
295289
match plan {
296290
LogicalPlan::SubqueryAlias(subquery_alias) => {
297-
let JoinInputShape { base_plan, .. } =
291+
let (is_scalar_aggregate, _) =
298292
classify_join_input(subquery_alias.input.as_ref());
299-
JoinInputShape {
300-
base_plan,
301-
is_derived_relation: true,
302-
}
293+
(is_scalar_aggregate, true)
303294
}
304295
LogicalPlan::Projection(projection) => {
305-
let shape = classify_join_input(projection.input.as_ref());
306-
JoinInputShape {
307-
is_derived_relation: shape.is_derived_relation,
308-
..shape
309-
}
296+
classify_join_input(projection.input.as_ref())
310297
}
311-
_ => JoinInputShape {
312-
base_plan: plan,
313-
is_derived_relation: false,
314-
},
298+
LogicalPlan::Aggregate(aggregate) => (aggregate.group_expr.is_empty(), false),
299+
_ => (false, false),
315300
}
316301
}
317302

318-
fn is_scalar_aggregate_subquery(shape: JoinInputShape<'_>) -> bool {
319-
matches!(
320-
shape.base_plan,
321-
LogicalPlan::Aggregate(aggregate) if aggregate.group_expr.is_empty()
322-
)
323-
}
324-
325303
fn is_scalar_subquery_cross_join(join: &Join) -> bool {
326-
let left_shape = classify_join_input(join.left.as_ref());
327-
let right_shape = classify_join_input(join.right.as_ref());
304+
let (left_scalar_aggregate, left_is_derived_relation) =
305+
classify_join_input(join.left.as_ref());
306+
let (right_scalar_aggregate, right_is_derived_relation) =
307+
classify_join_input(join.right.as_ref());
328308
join.on.is_empty()
329309
&& join.filter.is_none()
330-
&& ((is_scalar_aggregate_subquery(left_shape) && right_shape.is_derived_relation)
331-
|| (is_scalar_aggregate_subquery(right_shape)
332-
&& left_shape.is_derived_relation))
333-
}
334-
335-
// Keep post-join filters above certain scalar-subquery cross joins to preserve
336-
// behavior for the window-over-scalar-subquery regression shape.
337-
fn should_keep_filter_above_scalar_subquery_cross_join(
338-
mut checker: ColumnChecker<'_>,
339-
predicate: &Expr,
340-
) -> bool {
341-
!checker.is_left_only(predicate) && !checker.is_right_only(predicate)
342-
}
343-
344-
enum PredicateDestination {
345-
Left,
346-
Right,
347-
Keep,
348-
}
349-
350-
fn classify_predicate_destination(
351-
checker: &mut ColumnChecker<'_>,
352-
predicate: &Expr,
353-
allow_left: bool,
354-
allow_right: bool,
355-
) -> PredicateDestination {
356-
if allow_left && checker.is_left_only(predicate) {
357-
PredicateDestination::Left
358-
} else if allow_right && checker.is_right_only(predicate) {
359-
PredicateDestination::Right
360-
} else {
361-
PredicateDestination::Keep
362-
}
310+
&& ((left_scalar_aggregate && right_is_derived_relation)
311+
|| (right_scalar_aggregate && left_is_derived_relation))
363312
}
364313

365314
/// examine OR clause to see if any useful clauses can be extracted and push down.
@@ -506,41 +455,28 @@ fn push_down_all_join(
506455
let keep_mixed_scalar_subquery_filters =
507456
is_inner_join && is_scalar_subquery_cross_join(&join);
508457
for predicate in predicates {
509-
match classify_predicate_destination(
510-
&mut checker,
511-
&predicate,
512-
left_preserved,
513-
right_preserved,
514-
) {
515-
PredicateDestination::Left => left_push.push(predicate),
516-
PredicateDestination::Right => right_push.push(predicate),
517-
PredicateDestination::Keep => {
518-
let should_keep_above_join = keep_mixed_scalar_subquery_filters
519-
&& should_keep_filter_above_scalar_subquery_cross_join(
520-
ColumnChecker::new(left_schema, right_schema),
521-
&predicate,
522-
);
523-
524-
if is_inner_join
525-
&& !should_keep_above_join
526-
&& can_evaluate_as_join_condition(&predicate)?
527-
{
528-
// Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
529-
// and convert to the join on condition
530-
join_conditions.push(predicate);
531-
} else {
532-
keep_predicates.push(predicate);
533-
}
534-
}
458+
if left_preserved && checker.is_left_only(&predicate) {
459+
left_push.push(predicate);
460+
} else if right_preserved && checker.is_right_only(&predicate) {
461+
right_push.push(predicate);
462+
} else if is_inner_join
463+
&& !keep_mixed_scalar_subquery_filters
464+
&& can_evaluate_as_join_condition(&predicate)?
465+
{
466+
// Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
467+
// and convert to the join on condition
468+
join_conditions.push(predicate);
469+
} else {
470+
keep_predicates.push(predicate);
535471
}
536472
}
537473

538474
// Push predicates inferred from the join expression
539475
for predicate in inferred_join_predicates {
540-
match classify_predicate_destination(&mut checker, &predicate, true, true) {
541-
PredicateDestination::Left => left_push.push(predicate),
542-
PredicateDestination::Right => right_push.push(predicate),
543-
PredicateDestination::Keep => {}
476+
if checker.is_left_only(&predicate) {
477+
left_push.push(predicate);
478+
} else if checker.is_right_only(&predicate) {
479+
right_push.push(predicate);
544480
}
545481
}
546482

@@ -549,15 +485,12 @@ fn push_down_all_join(
549485

550486
if !on_filter.is_empty() {
551487
for on in on_filter {
552-
match classify_predicate_destination(
553-
&mut checker,
554-
&on,
555-
on_left_preserved,
556-
on_right_preserved,
557-
) {
558-
PredicateDestination::Left => left_push.push(on),
559-
PredicateDestination::Right => right_push.push(on),
560-
PredicateDestination::Keep => on_filter_join_conditions.push(on),
488+
if on_left_preserved && checker.is_left_only(&on) {
489+
left_push.push(on)
490+
} else if on_right_preserved && checker.is_right_only(&on) {
491+
right_push.push(on)
492+
} else {
493+
on_filter_join_conditions.push(on)
561494
}
562495
}
563496
}
@@ -821,24 +754,26 @@ fn infer_join_predicates_impl<
821754
inferred_predicates: &mut InferredPredicates,
822755
) -> Result<()> {
823756
for predicate in input_predicates {
824-
let column_refs = predicate.column_refs();
757+
let mut join_cols_to_replace = HashMap::new();
825758
let mut saw_non_replaceable_ref = false;
826-
let join_cols_to_replace = column_refs
827-
.iter()
828-
.filter_map(|&col| {
829-
let replacement = join_col_keys.iter().find_map(|(l, r)| {
830-
if ENABLE_LEFT_TO_RIGHT && col == *l {
831-
Some((col, *r))
832-
} else if ENABLE_RIGHT_TO_LEFT && col == *r {
833-
Some((col, *l))
834-
} else {
835-
None
836-
}
837-
});
838-
saw_non_replaceable_ref |= replacement.is_none();
839-
replacement
840-
})
841-
.collect::<HashMap<_, _>>();
759+
760+
for &col in &predicate.column_refs() {
761+
let replacement = join_col_keys.iter().find_map(|(l, r)| {
762+
if ENABLE_LEFT_TO_RIGHT && col == *l {
763+
Some((col, *r))
764+
} else if ENABLE_RIGHT_TO_LEFT && col == *r {
765+
Some((col, *l))
766+
} else {
767+
None
768+
}
769+
});
770+
771+
if let Some((source, target)) = replacement {
772+
join_cols_to_replace.insert(source, target);
773+
} else {
774+
saw_non_replaceable_ref = true;
775+
}
776+
}
842777

843778
if join_cols_to_replace.is_empty()
844779
|| (!inferred_predicates.is_inner_join && saw_non_replaceable_ref)

datafusion/optimizer/src/utils.rs

Lines changed: 19 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,9 @@ use self::test_eval_mode::{
4646

4747
/// Returns true if `expr` contains all columns in `schema_cols`
4848
pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet<Column>) -> bool {
49-
column_refs_all_in(&expr.column_refs(), |column| schema_cols.contains(column))
49+
expr.column_refs()
50+
.iter()
51+
.all(|column| schema_cols.contains(*column))
5052
}
5153

5254
pub(crate) fn replace_qualified_name(
@@ -70,13 +72,6 @@ pub fn log_plan(description: &str, plan: &LogicalPlan) {
7072
trace!("{description}::\n{}\n", plan.display_indent_schema());
7173
}
7274

73-
pub(super) fn column_refs_all_in<'a>(
74-
column_refs: &HashSet<&'a Column>,
75-
mut contains: impl FnMut(&Column) -> bool,
76-
) -> bool {
77-
column_refs.iter().all(|column| contains(column))
78-
}
79-
8075
/// Determine whether a predicate can restrict NULLs. e.g.
8176
/// `c0 > 8` return true;
8277
/// `c0 IS NULL` return false.
@@ -98,7 +93,7 @@ pub fn is_restrict_null_predicate<'a>(
9893
// contains a placeholder for the join key columns. Callers treat such errors as
9994
// non-restricting (false) via `matches!(_, Ok(true))`, so we return false early
10095
// and avoid the expensive physical-expression compilation pipeline entirely.
101-
if !column_refs_all_in(&column_refs, |column| join_cols.contains(&column)) {
96+
if !column_refs.iter().all(|column| join_cols.contains(*column)) {
10297
return Ok(false);
10398
}
10499

@@ -181,31 +176,28 @@ fn authoritative_restrict_null_predicate<'a>(
181176
predicate: Expr,
182177
join_cols_of_predicate: impl IntoIterator<Item = &'a Column>,
183178
) -> Result<bool> {
184-
evaluate_expr_with_null_column(predicate, join_cols_of_predicate)
185-
.and_then(is_false_or_null_boolean_result)
179+
Ok(
180+
match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? {
181+
ColumnarValue::Array(array) if array.len() == 1 => {
182+
let boolean_array = as_boolean_array(&array)?;
183+
boolean_array.is_null(0) || !boolean_array.value(0)
184+
}
185+
ColumnarValue::Array(_) => false,
186+
ColumnarValue::Scalar(scalar) => matches!(
187+
scalar,
188+
ScalarValue::Boolean(None)
189+
| ScalarValue::Boolean(Some(false))
190+
| ScalarValue::Null
191+
),
192+
},
193+
)
186194
}
187195

188196
fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
189197
let mut expr_rewrite = TypeCoercionRewriter { schema };
190198
expr.rewrite(&mut expr_rewrite).data()
191199
}
192200

193-
fn is_false_or_null_boolean_result(result: ColumnarValue) -> Result<bool> {
194-
Ok(match result {
195-
ColumnarValue::Array(array) if array.len() == 1 => {
196-
let boolean_array = as_boolean_array(&array)?;
197-
boolean_array.is_null(0) || !boolean_array.value(0)
198-
}
199-
ColumnarValue::Array(_) => false,
200-
ColumnarValue::Scalar(scalar) => matches!(
201-
scalar,
202-
ScalarValue::Boolean(None)
203-
| ScalarValue::Boolean(Some(false))
204-
| ScalarValue::Null
205-
),
206-
})
207-
}
208-
209201
#[cfg(test)]
210202
mod test_eval_mode {
211203
use std::cell::Cell;

datafusion/optimizer/src/utils/null_restriction.rs

Lines changed: 3 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -32,21 +32,13 @@ enum NullSubstitutionValue {
3232
Boolean(bool),
3333
}
3434

35-
impl NullSubstitutionValue {
36-
fn is_null(self) -> bool {
37-
matches!(self, Self::Null)
38-
}
39-
}
40-
4135
pub(super) fn syntactic_restrict_null_predicate(
4236
predicate: &Expr,
4337
join_cols: &HashSet<&Column>,
4438
) -> Option<bool> {
4539
match syntactic_null_substitution_value(predicate, join_cols) {
46-
Some(NullSubstitutionValue::Boolean(true)) => Some(false),
47-
Some(NullSubstitutionValue::Boolean(false) | NullSubstitutionValue::Null) => {
48-
Some(true)
49-
}
40+
Some(NullSubstitutionValue::Boolean(value)) => Some(!value),
41+
Some(NullSubstitutionValue::Null) => Some(true),
5042
Some(NullSubstitutionValue::NonNull) | None => None,
5143
}
5244
}
@@ -117,7 +109,7 @@ fn null_if_contains_null(
117109
fn strict_null_only(
118110
value: Option<NullSubstitutionValue>,
119111
) -> Option<NullSubstitutionValue> {
120-
value.filter(|value| value.is_null())
112+
value.filter(|value| matches!(value, NullSubstitutionValue::Null))
121113
}
122114

123115
fn syntactic_null_substitution_value(

0 commit comments

Comments
 (0)