Skip to content

Commit 37bcc07

Browse files
committed
Refactor tests for clarity and reusability
Extract shared helpers in push_down_filter_regressions.rs and push_down_filter.rs to reduce code duplication. Consolidate optimizer-delta test assertions and create specific plan builders for common expressions. Add a utility in utils.rs to evaluate predicates under different null-restriction modes, streamlining mode-comparison tests and enhancing maintainability.
1 parent a3bcb57 commit 37bcc07

File tree

3 files changed

+90
-124
lines changed

3 files changed

+90
-124
lines changed

datafusion/core/tests/sql/push_down_filter_regressions.rs

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ const WINDOW_SCALAR_SUBQUERY_SQL: &str = r#"
3333
)
3434
"#;
3535

36+
const WINDOW_SCALAR_SUBQUERY_EXPECTED: &[&str] =
37+
&["+----+", "| rn |", "+----+", "| 1 |", "+----+"];
38+
3639
fn sqllogictest_style_ctx(push_down_filter_enabled: bool) -> SessionContext {
3740
let ctx =
3841
SessionContext::new_with_config(SessionConfig::new().with_target_partitions(4));
@@ -56,30 +59,20 @@ async fn capture_window_scalar_subquery_plans(
5659
))
5760
}
5861

59-
#[tokio::test]
60-
async fn window_scalar_subquery_regression() -> Result<()> {
61-
let ctx = SessionContext::new();
62+
async fn assert_window_scalar_subquery(ctx: SessionContext) -> Result<()> {
6263
let results = ctx.sql(WINDOW_SCALAR_SUBQUERY_SQL).await?.collect().await?;
63-
64-
assert_batches_eq!(
65-
&["+----+", "| rn |", "+----+", "| 1 |", "+----+",],
66-
&results
67-
);
68-
64+
assert_batches_eq!(WINDOW_SCALAR_SUBQUERY_EXPECTED, &results);
6965
Ok(())
7066
}
7167

7268
#[tokio::test]
73-
async fn window_scalar_subquery_sqllogictest_style_regression() -> Result<()> {
74-
let ctx = sqllogictest_style_ctx(true);
75-
let results = ctx.sql(WINDOW_SCALAR_SUBQUERY_SQL).await?.collect().await?;
76-
77-
assert_batches_eq!(
78-
&["+----+", "| rn |", "+----+", "| 1 |", "+----+",],
79-
&results
80-
);
69+
async fn window_scalar_subquery_regression() -> Result<()> {
70+
assert_window_scalar_subquery(SessionContext::new()).await
71+
}
8172

82-
Ok(())
73+
#[tokio::test]
74+
async fn window_scalar_subquery_sqllogictest_style_regression() -> Result<()> {
75+
assert_window_scalar_subquery(sqllogictest_style_ctx(true)).await
8376
}
8477

8578
#[tokio::test]
@@ -212,28 +205,18 @@ async fn window_scalar_subquery_optimizer_delta() -> Result<()> {
212205
let (disabled_optimized, disabled_physical) =
213206
capture_window_scalar_subquery_plans(false).await?;
214207

208+
assert_eq!(enabled_optimized, disabled_optimized);
209+
assert_eq!(enabled_physical, disabled_physical);
210+
215211
assert!(
216212
enabled_optimized
217213
.contains("Filter: s.acctbal > __scalar_sq_1.avg(suppliers.acctbal)")
218214
);
219215
assert!(enabled_optimized.contains("Cross Join:"));
220-
assert!(
221-
disabled_optimized
222-
.contains("Filter: s.acctbal > __scalar_sq_1.avg(suppliers.acctbal)")
223-
);
224-
assert!(disabled_optimized.contains("Cross Join:"));
225-
226216
assert!(
227217
enabled_physical.contains("FilterExec: acctbal@1 > avg(suppliers.acctbal)@2")
228218
);
229219
assert!(enabled_physical.contains("CrossJoinExec"));
230-
assert!(
231-
disabled_physical.contains("FilterExec: acctbal@1 > avg(suppliers.acctbal)@2")
232-
);
233-
assert!(disabled_physical.contains("CrossJoinExec"));
234-
235-
assert_eq!(enabled_optimized, disabled_optimized);
236-
assert_eq!(enabled_physical, disabled_physical);
237220

238221
Ok(())
239222
}

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 49 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1522,6 +1522,53 @@ mod tests {
15221522

15231523
use super::*;
15241524

1525+
fn scalar_subquery_right_plan() -> Result<LogicalPlan> {
1526+
LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
1527+
.project(vec![col("a").alias("acctbal")])?
1528+
.aggregate(
1529+
Vec::<Expr>::new(),
1530+
vec![avg(col("acctbal")).alias("avg_acctbal")],
1531+
)?
1532+
.alias("__scalar_sq_1")?
1533+
.build()
1534+
}
1535+
1536+
fn row_number_window_expr() -> Expr {
1537+
Expr::from(WindowFunction::new(
1538+
WindowFunctionDefinition::WindowUDF(
1539+
datafusion_functions_window::row_number::row_number_udwf(),
1540+
),
1541+
vec![],
1542+
))
1543+
.partition_by(vec![col("s.nation")])
1544+
.order_by(vec![col("s.acctbal").sort(false, true)])
1545+
.build()
1546+
.unwrap()
1547+
}
1548+
1549+
fn window_over_scalar_subquery_cross_join_plan(
1550+
with_project_wrapper: bool,
1551+
) -> Result<LogicalPlan> {
1552+
let left = {
1553+
let builder = LogicalPlanBuilder::from(test_table_scan()?)
1554+
.project(vec![col("a").alias("nation"), col("b").alias("acctbal")])?
1555+
.alias("s")?;
1556+
let builder = if with_project_wrapper {
1557+
builder.project(vec![col("s.nation"), col("s.acctbal")])?
1558+
} else {
1559+
builder
1560+
};
1561+
builder.build()?
1562+
};
1563+
1564+
LogicalPlanBuilder::from(left)
1565+
.cross_join(scalar_subquery_right_plan()?)?
1566+
.filter(col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")))?
1567+
.project(vec![col("s.nation"), col("s.acctbal")])?
1568+
.window(vec![row_number_window_expr()])?
1569+
.build()
1570+
}
1571+
15251572
fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
15261573

15271574
macro_rules! assert_optimized_plan_equal {
@@ -2443,36 +2490,7 @@ mod tests {
24432490

24442491
#[test]
24452492
fn window_over_scalar_subquery_cross_join_keeps_filter_above_join() -> Result<()> {
2446-
let left = LogicalPlanBuilder::from(test_table_scan()?)
2447-
.project(vec![col("a").alias("nation"), col("b").alias("acctbal")])?
2448-
.alias("s")?
2449-
.build()?;
2450-
let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
2451-
.project(vec![col("a").alias("acctbal")])?
2452-
.aggregate(
2453-
Vec::<Expr>::new(),
2454-
vec![avg(col("acctbal")).alias("avg_acctbal")],
2455-
)?
2456-
.alias("__scalar_sq_1")?
2457-
.build()?;
2458-
2459-
let window = Expr::from(WindowFunction::new(
2460-
WindowFunctionDefinition::WindowUDF(
2461-
datafusion_functions_window::row_number::row_number_udwf(),
2462-
),
2463-
vec![],
2464-
))
2465-
.partition_by(vec![col("s.nation")])
2466-
.order_by(vec![col("s.acctbal").sort(false, true)])
2467-
.build()
2468-
.unwrap();
2469-
2470-
let plan = LogicalPlanBuilder::from(left)
2471-
.cross_join(right)?
2472-
.filter(col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")))?
2473-
.project(vec![col("s.nation"), col("s.acctbal")])?
2474-
.window(vec![window])?
2475-
.build()?;
2493+
let plan = window_over_scalar_subquery_cross_join_plan(false)?;
24762494

24772495
assert_optimized_plan_equal!(
24782496
plan,
@@ -2495,37 +2513,7 @@ mod tests {
24952513
#[test]
24962514
fn window_over_scalar_subquery_cross_join_with_project_wrapper_keeps_filter_above_join()
24972515
-> Result<()> {
2498-
let left = LogicalPlanBuilder::from(test_table_scan()?)
2499-
.project(vec![col("a").alias("nation"), col("b").alias("acctbal")])?
2500-
.alias("s")?
2501-
.project(vec![col("s.nation"), col("s.acctbal")])?
2502-
.build()?;
2503-
let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
2504-
.project(vec![col("a").alias("acctbal")])?
2505-
.aggregate(
2506-
Vec::<Expr>::new(),
2507-
vec![avg(col("acctbal")).alias("avg_acctbal")],
2508-
)?
2509-
.alias("__scalar_sq_1")?
2510-
.build()?;
2511-
2512-
let window = Expr::from(WindowFunction::new(
2513-
WindowFunctionDefinition::WindowUDF(
2514-
datafusion_functions_window::row_number::row_number_udwf(),
2515-
),
2516-
vec![],
2517-
))
2518-
.partition_by(vec![col("s.nation")])
2519-
.order_by(vec![col("s.acctbal").sort(false, true)])
2520-
.build()
2521-
.unwrap();
2522-
2523-
let plan = LogicalPlanBuilder::from(left)
2524-
.cross_join(right)?
2525-
.filter(col("s.acctbal").gt(col("__scalar_sq_1.avg_acctbal")))?
2526-
.project(vec![col("s.nation"), col("s.acctbal")])?
2527-
.window(vec![window])?
2528-
.build()?;
2516+
let plan = window_over_scalar_subquery_cross_join_plan(true)?;
25292517

25302518
assert_optimized_plan_equal!(
25312519
plan,

datafusion/optimizer/src/utils.rs

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,23 @@ mod tests {
250250
Operator, binary_expr, case, col, in_list, is_null, lit, when,
251251
};
252252

253+
fn restrict_null_predicate_in_modes(
254+
predicate: Expr,
255+
join_cols: &[Column],
256+
) -> Result<(bool, bool)> {
257+
let auto_result = with_null_restriction_eval_mode_for_test(
258+
NullRestrictionEvalMode::Auto,
259+
|| is_restrict_null_predicate(predicate.clone(), join_cols.iter()),
260+
)?;
261+
262+
let authoritative_result = with_null_restriction_eval_mode_for_test(
263+
NullRestrictionEvalMode::AuthoritativeOnly,
264+
|| is_restrict_null_predicate(predicate.clone(), join_cols.iter()),
265+
)?;
266+
267+
Ok((auto_result, authoritative_result))
268+
}
269+
253270
#[test]
254271
fn expr_is_restrict_null_predicate() -> Result<()> {
255272
let test_cases = vec![
@@ -465,27 +482,13 @@ mod tests {
465482
#[test]
466483
fn null_restriction_eval_mode_auto_vs_authoritative_only() -> Result<()> {
467484
let predicate = binary_expr(col("a"), Operator::Gt, lit(8i64));
468-
let join_cols_of_predicate = predicate.column_refs();
469-
470-
let auto_result = with_null_restriction_eval_mode_for_test(
471-
NullRestrictionEvalMode::Auto,
472-
|| {
473-
is_restrict_null_predicate(
474-
predicate.clone(),
475-
join_cols_of_predicate.iter().copied(),
476-
)
477-
},
478-
)?;
479-
480-
let authoritative_result = with_null_restriction_eval_mode_for_test(
481-
NullRestrictionEvalMode::AuthoritativeOnly,
482-
|| {
483-
is_restrict_null_predicate(
484-
predicate.clone(),
485-
join_cols_of_predicate.iter().copied(),
486-
)
487-
},
488-
)?;
485+
let join_cols_of_predicate = predicate
486+
.column_refs()
487+
.into_iter()
488+
.cloned()
489+
.collect::<Vec<_>>();
490+
let (auto_result, authoritative_result) =
491+
restrict_null_predicate_in_modes(predicate, &join_cols_of_predicate)?;
489492

490493
assert_eq!(auto_result, authoritative_result);
491494

@@ -496,17 +499,9 @@ mod tests {
496499
fn mixed_reference_predicate_remains_fast_pathed_in_authoritative_mode() -> Result<()>
497500
{
498501
let predicate = binary_expr(col("a"), Operator::Gt, col("b"));
499-
let column_a = Column::from_name("a");
500-
501-
let auto_result = with_null_restriction_eval_mode_for_test(
502-
NullRestrictionEvalMode::Auto,
503-
|| is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a)),
504-
)?;
505-
506-
let authoritative_only_result = with_null_restriction_eval_mode_for_test(
507-
NullRestrictionEvalMode::AuthoritativeOnly,
508-
|| is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a)),
509-
)?;
502+
let join_cols = vec![Column::from_name("a")];
503+
let (auto_result, authoritative_only_result) =
504+
restrict_null_predicate_in_modes(predicate.clone(), &join_cols)?;
510505

511506
assert!(!auto_result, "{predicate}");
512507
assert!(!authoritative_only_result, "{predicate}");

0 commit comments

Comments
 (0)