Skip to content

Commit 02cf97d

Browse files
committed
Address PR review feedback
1 parent 6323386 commit 02cf97d

3 files changed

Lines changed: 82 additions & 81 deletions

File tree

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2499,8 +2499,12 @@ impl Filter {
24992499
///
25002500
/// Skips the type-checking and dealiasing done in [Self::try_new].
25012501
/// For internal use in DataFusion only.
2502+
///
2503+
/// **Preconditions:**
2504+
/// - the `predicate` expression returns a boolean value
2505+
/// - the `predicate` expression is not aliased
25022506
#[doc(hidden)]
2503-
pub fn new_unchecked(predicate: Expr, input: Arc<LogicalPlan>) -> Self {
2507+
pub fn new(predicate: Expr, input: Arc<LogicalPlan>) -> Self {
25042508
Self { predicate, input }
25052509
}
25062510

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 54 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -414,15 +414,15 @@ fn push_down_all_join(
414414
// 3) should be kept as filter conditions
415415
let left_schema = join.left.schema();
416416
let right_schema = join.right.schema();
417-
let left_schema_columns = schema_columns(left_schema);
418-
let right_schema_columns = schema_columns(right_schema);
417+
418+
let left_schema_columns = schema_columns(left_schema.as_ref());
419+
let right_schema_columns = schema_columns(right_schema.as_ref());
419420

420421
let mut left_push = vec![];
421422
let mut right_push = vec![];
422423
let mut keep_predicates = vec![];
423424
let mut join_conditions = vec![];
424425
let mut checker = ColumnChecker::new(left_schema, right_schema);
425-
426426
for predicate in predicates {
427427
if left_preserved && checker.is_left_only(&predicate) {
428428
left_push.push(predicate);
@@ -470,7 +470,6 @@ fn push_down_all_join(
470470
&left_schema_columns,
471471
));
472472
}
473-
474473
if right_preserved {
475474
right_push.extend(extract_or_clauses_for_join(
476475
&keep_predicates,
@@ -483,14 +482,13 @@ fn push_down_all_join(
483482
}
484483

485484
// For predicates from the join filter, we should check whether a join side is preserved
486-
// in terms of join filtering.
485+
// in terms of join filtering.
487486
if on_left_preserved {
488487
left_push.extend(extract_or_clauses_for_join(
489488
&on_filter_join_conditions,
490489
&left_schema_columns,
491490
));
492491
}
493-
494492
if on_right_preserved {
495493
right_push.extend(extract_or_clauses_for_join(
496494
&on_filter_join_conditions,
@@ -512,15 +510,11 @@ fn push_down_all_join(
512510
}
513511

514512
if let Some(predicate) = conjunction(left_push) {
515-
join.left = Arc::new(LogicalPlan::Filter(Filter::new_unchecked(
516-
predicate, join.left,
517-
)));
513+
join.left = Arc::new(LogicalPlan::Filter(Filter::new(predicate, join.left)));
518514
}
519515

520516
if let Some(predicate) = conjunction(right_push) {
521-
join.right = Arc::new(LogicalPlan::Filter(Filter::new_unchecked(
522-
predicate, join.right,
523-
)));
517+
join.right = Arc::new(LogicalPlan::Filter(Filter::new(predicate, join.right)));
524518
}
525519

526520
// wrap the join on the filter whose predicates must be kept, if any
@@ -802,15 +796,13 @@ impl OptimizerRule for PushDownFilter {
802796
new_predicates.len()
803797
);
804798
}
805-
806799
if old_predicate_len != new_predicates.len() {
807-
if let Some(new_predicate) = conjunction(new_predicates) {
808-
filter.predicate = new_predicate;
809-
} else {
800+
let Some(new_predicate) = conjunction(new_predicates) else {
810801
// new_predicates is empty - remove the filter entirely
811802
// Return the child plan without the filter
812803
return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
813-
}
804+
};
805+
filter.predicate = new_predicate;
814806
}
815807

816808
// If the child has a fetch (limit) or skip (offset), pushing a filter
@@ -975,7 +967,7 @@ impl OptimizerRule for PushDownFilter {
975967

976968
let push_predicate =
977969
replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
978-
inputs.push(Arc::new(LogicalPlan::Filter(Filter::new_unchecked(
970+
inputs.push(Arc::new(LogicalPlan::Filter(Filter::new(
979971
push_predicate,
980972
input,
981973
))))
@@ -1001,7 +993,7 @@ impl OptimizerRule for PushDownFilter {
1001993
let mut push_predicates = vec![];
1002994
for expr in predicates {
1003995
let cols = expr.column_refs();
1004-
if cols.iter().all(|col| group_expr_columns.contains(col)) {
996+
if cols.iter().all(|c| group_expr_columns.contains(c)) {
1005997
push_predicates.push(replace_cols_by_name(expr, &replace_map)?);
1006998
} else {
1007999
keep_predicates.push(expr);
@@ -1044,8 +1036,8 @@ impl OptimizerRule for PushDownFilter {
10441036
let potential_partition_keys = window
10451037
.window_expr
10461038
.iter()
1047-
.map(|expr| {
1048-
match expr {
1039+
.map(|e| {
1040+
match e {
10491041
Expr::WindowFunction(window_func) => {
10501042
extract_partition_keys(window_func)
10511043
}
@@ -1122,7 +1114,7 @@ impl OptimizerRule for PushDownFilter {
11221114
// Check which non-volatile filters are supported by source
11231115
let supported_filters = scan
11241116
.source
1125-
.supports_filters_pushdown(&non_volatile_filters)?;
1117+
.supports_filters_pushdown(non_volatile_filters.as_slice())?;
11261118
assert_eq_or_internal_err!(
11271119
non_volatile_filters.len(),
11281120
supported_filters.len(),
@@ -1148,7 +1140,7 @@ impl OptimizerRule for PushDownFilter {
11481140
.map(|(&pred, _)| pred);
11491141

11501142
// Add new scan filters
1151-
let new_scan_filters = scan
1143+
let new_scan_filters: Vec<Expr> = scan
11521144
.filters
11531145
.iter()
11541146
.chain(new_scan_filters)
@@ -1189,7 +1181,6 @@ impl OptimizerRule for PushDownFilter {
11891181
filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
11901182
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
11911183
}
1192-
11931184
let prevent_cols =
11941185
extension_plan.node.prevent_predicate_push_down_columns();
11951186

@@ -1203,7 +1194,7 @@ impl OptimizerRule for PushDownFilter {
12031194
!expr
12041195
.column_refs()
12051196
.iter()
1206-
.any(|col| prevent_cols.contains(&col.name))
1197+
.any(|c| prevent_cols.contains(&c.name))
12071198
})
12081199
.collect();
12091200

@@ -1220,10 +1211,10 @@ impl OptimizerRule for PushDownFilter {
12201211
.into_iter()
12211212
.zip(split_conjunction_owned(filter.predicate))
12221213
{
1223-
if push {
1224-
push_predicates.push(expr)
1214+
if !push {
1215+
keep_predicates.push(expr);
12251216
} else {
1226-
keep_predicates.push(expr)
1217+
push_predicates.push(expr);
12271218
}
12281219
}
12291220

@@ -1234,7 +1225,7 @@ impl OptimizerRule for PushDownFilter {
12341225
.inputs()
12351226
.into_iter()
12361227
.map(|child| {
1237-
LogicalPlan::Filter(Filter::new_unchecked(
1228+
LogicalPlan::Filter(Filter::new(
12381229
predicate.clone(),
12391230
Arc::new(child.clone()),
12401231
))
@@ -1317,7 +1308,7 @@ fn rewrite_projection(
13171308
let projection = if let Some(expr) = conjunction(push_predicates) {
13181309
// re-write all filters based on this projection
13191310
// E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
1320-
projection.input = Arc::new(LogicalPlan::Filter(Filter::new_unchecked(
1311+
projection.input = Arc::new(LogicalPlan::Filter(Filter::new(
13211312
replace_cols_by_name(expr, &pushable_map)?,
13221313
projection.input,
13231314
)));
@@ -1330,6 +1321,14 @@ fn rewrite_projection(
13301321
Ok((projection, keep_predicates))
13311322
}
13321323

1324+
/// Creates a new LogicalPlan::Filter node.
1325+
///
1326+
/// Deprecated: use [`Filter::try_new`] directly.
1327+
#[deprecated]
1328+
pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1329+
Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1330+
}
1331+
13331332
impl PushDownFilter {
13341333
#[expect(missing_docs)]
13351334
pub fn new() -> Self {
@@ -1355,12 +1354,12 @@ where
13551354

13561355
/// Replaces columns by name, using the expressions in `replace_map`.
13571356
pub fn replace_cols_by_name(
1358-
expr: Expr,
1357+
e: Expr,
13591358
replace_map: &HashMap<String, impl AsRef<Expr>>,
13601359
) -> Result<Expr> {
1361-
expr.transform_up(|expr| {
1362-
if let Expr::Column(col) = &expr
1363-
&& let Some(new_expr) = replace_map.get(&col.flat_name())
1360+
e.transform_up(|expr| {
1361+
if let Expr::Column(c) = &expr
1362+
&& let Some(new_expr) = replace_map.get(&c.flat_name())
13641363
{
13651364
Ok(Transformed::yes(new_expr.as_ref().clone()))
13661365
} else {
@@ -1380,9 +1379,9 @@ fn unalias(expr: &Expr) -> &Expr {
13801379
}
13811380

13821381
/// check whether the expression uses the columns in `check_map`.
1383-
fn contain<T>(expr: &Expr, check_map: &HashMap<String, T>) -> bool {
1382+
fn contain<T>(e: &Expr, check_map: &HashMap<String, T>) -> bool {
13841383
let mut is_contain = false;
1385-
expr.apply(|expr| {
1384+
e.apply(|expr| {
13861385
if let Expr::Column(c) = &expr
13871386
&& check_map.contains_key(&c.flat_name())
13881387
{
@@ -1398,7 +1397,7 @@ fn contain<T>(expr: &Expr, check_map: &HashMap<String, T>) -> bool {
13981397

13991398
fn with_filters(predicates: Vec<Expr>, plan: LogicalPlan) -> LogicalPlan {
14001399
if let Some(predicate) = conjunction(predicates) {
1401-
LogicalPlan::Filter(Filter::new_unchecked(predicate, Arc::new(plan)))
1400+
LogicalPlan::Filter(Filter::new(predicate, Arc::new(plan)))
14021401
} else {
14031402
plan
14041403
}
@@ -1476,6 +1475,17 @@ mod tests {
14761475
}};
14771476
}
14781477

1478+
/// For testing that we don't return [Transformed::yes] when not necessary,
1479+
/// as it triggers rebuilding parent plan nodes.
1480+
macro_rules! assert_plan_not_transformed {
1481+
($plan:expr) => {{
1482+
let transformed = PushDownFilter::new()
1483+
.rewrite($plan, &OptimizerContext::new())
1484+
.expect("failed to optimize plan");
1485+
assert!(!transformed.transformed);
1486+
}};
1487+
}
1488+
14791489
#[test]
14801490
fn filter_before_projection() -> Result<()> {
14811491
let table_scan = test_table_scan()?;
@@ -1624,10 +1634,8 @@ mod tests {
16241634
.aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
16251635
.filter(col("b").gt(lit(10i64)))?
16261636
.build()?;
1627-
let transformed = PushDownFilter::new()
1628-
.rewrite(plan.clone(), &OptimizerContext::new())
1629-
.expect("failed to optimize plan");
1630-
assert!(!transformed.transformed);
1637+
assert_plan_not_transformed!(plan.clone());
1638+
16311639
// filter of aggregate is after aggregation since they are non-commutative
16321640
assert_optimized_plan_equal!(
16331641
plan,
@@ -1820,10 +1828,7 @@ mod tests {
18201828
.window(vec![window])?
18211829
.filter(col("c").gt(lit(10i64)))?
18221830
.build()?;
1823-
let transformed = PushDownFilter::new()
1824-
.rewrite(plan.clone(), &OptimizerContext::new())
1825-
.expect("failed to optimize plan");
1826-
assert!(!transformed.transformed);
1831+
assert_plan_not_transformed!(plan.clone());
18271832

18281833
assert_optimized_plan_equal!(
18291834
plan,
@@ -3049,10 +3054,7 @@ mod tests {
30493054
Some(filter),
30503055
)?
30513056
.build()?;
3052-
let transformed = PushDownFilter::new()
3053-
.rewrite(plan.clone(), &OptimizerContext::new())
3054-
.expect("failed to optimize plan");
3055-
assert!(!transformed.transformed);
3057+
assert_plan_not_transformed!(plan.clone());
30563058

30573059
// not part of the test, just good to know:
30583060
assert_snapshot!(plan,
@@ -3163,16 +3165,12 @@ mod tests {
31633165
.rewrite(plan, &OptimizerContext::new())
31643166
.expect("failed to optimize plan");
31653167
assert!(optimized.transformed);
3166-
3167-
let optimized_again = PushDownFilter::new()
3168-
.rewrite(optimized.data.clone(), &OptimizerContext::new())
3169-
.expect("failed to optimize plan");
3170-
assert!(!optimized_again.transformed);
3168+
assert_plan_not_transformed!(optimized.data.clone());
31713169

31723170
// Optimizing the same plan multiple times should produce the same plan
31733171
// each time.
31743172
assert_optimized_plan_equal!(
3175-
optimized_again.data,
3173+
optimized.data,
31763174
@r"
31773175
Filter: a = Int64(1)
31783176
TableScan: test, partial_filters=[a = Int64(1)]
@@ -3184,11 +3182,7 @@ mod tests {
31843182
fn filter_with_table_provider_unsupported() -> Result<()> {
31853183
let plan =
31863184
table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
3187-
3188-
let transformed = PushDownFilter::new()
3189-
.rewrite(plan.clone(), &OptimizerContext::new())
3190-
.expect("failed to optimize plan");
3191-
assert!(!transformed.transformed);
3185+
assert_plan_not_transformed!(plan.clone());
31923186

31933187
assert_optimized_plan_equal!(
31943188
plan,

datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,27 +52,30 @@ pub fn simplify_predicates(predicates: Vec<Expr>) -> Result<Vec<Expr>> {
5252
let mut other_predicates = Vec::new();
5353

5454
for pred in predicates {
55-
use Operator::*;
56-
57-
if let Expr::BinaryExpr(BinaryExpr {
58-
left,
59-
op: Gt | GtEq | Lt | LtEq | Eq,
60-
right,
61-
}) = &pred
62-
{
63-
if let (Some(col), Some(_)) =
64-
(extract_column_from_expr(left), right.as_literal())
65-
{
66-
column_predicates.entry(col).or_default().push(pred);
67-
} else if let (Some(_), Some(col)) =
68-
(left.as_literal(), extract_column_from_expr(right))
69-
{
70-
column_predicates.entry(col).or_default().push(pred);
71-
} else {
72-
other_predicates.push(pred);
55+
match &pred {
56+
Expr::BinaryExpr(BinaryExpr {
57+
left,
58+
op:
59+
Operator::Gt
60+
| Operator::GtEq
61+
| Operator::Lt
62+
| Operator::LtEq
63+
| Operator::Eq,
64+
right,
65+
}) => {
66+
if let (Some(col), Some(_)) =
67+
(extract_column_from_expr(left), right.as_literal())
68+
{
69+
column_predicates.entry(col).or_default().push(pred);
70+
} else if let (Some(_), Some(col)) =
71+
(left.as_literal(), extract_column_from_expr(right))
72+
{
73+
column_predicates.entry(col).or_default().push(pred);
74+
} else {
75+
other_predicates.push(pred);
76+
}
7377
}
74-
} else {
75-
other_predicates.push(pred)
78+
_ => other_predicates.push(pred),
7679
}
7780
}
7881

0 commit comments

Comments
 (0)