Skip to content

Commit 043c5db

Browse files
committed
fixes
1 parent 7bfb602 commit 043c5db

2 files changed

Lines changed: 23 additions & 108 deletions

File tree

datafusion/optimizer/src/decorrelate_predicate_subquery.rs

Lines changed: 16 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ use datafusion_expr::logical_plan::{
3737
};
3838
use datafusion_expr::utils::{conjunction, expr_to_columns, split_conjunction_owned};
3939
use datafusion_expr::{
40-
BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, exists,
41-
in_subquery, lit, not, not_exists, not_in_subquery,
40+
Aggregate, BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator,
41+
exists, in_subquery, lit, not, not_exists, not_in_subquery,
4242
};
4343

4444
use log::debug;
@@ -198,10 +198,9 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
198198
}
199199
}
200200

201-
let new_plan = LogicalPlanBuilder::from(cur_input)
202-
.aggregate(new_group_exprs, new_aggr_exprs)?
203-
.build()?;
204-
return Ok(Transformed::yes(new_plan));
201+
let new_agg =
202+
Aggregate::try_new(Arc::new(cur_input), new_group_exprs, new_aggr_exprs)?;
203+
return Ok(Transformed::yes(LogicalPlan::Aggregate(new_agg)));
205204
}
206205

207206
// Handle Projection nodes with subqueries in expressions
@@ -678,45 +677,6 @@ mod tests {
678677
))
679678
}
680679

681-
/// Aggregation with CASE WHEN ... IN (subquery) should be decorrelated under the Aggregate
682-
#[test]
683-
fn aggregate_case_in_subquery() -> Result<()> {
684-
let table_scan = test_table_scan_with_name("distinct_source")?;
685-
use datafusion_expr::expr_fn::when;
686-
use datafusion_functions_aggregate::expr_fn::max as agg_max;
687-
688-
let agg_b: Expr = agg_max(col("distinct_source.b"));
689-
let subq = LogicalPlanBuilder::from(table_scan.clone())
690-
.aggregate(Vec::<Expr>::new(), vec![agg_b])?
691-
.project(vec![col("max(distinct_source.b)")])?
692-
.build()?;
693-
694-
let case_expr = when(
695-
in_subquery(col("distinct_source.b"), Arc::new(subq)),
696-
lit(1),
697-
)
698-
.otherwise(lit(0))?;
699-
700-
let plan = LogicalPlanBuilder::from(table_scan)
701-
.aggregate(
702-
vec![col("distinct_source.a").alias("primary_key")],
703-
vec![
704-
agg_max(case_expr).alias("is_in_most_recent_task"),
705-
agg_max(col("distinct_source.c")).alias("max_timestamp"),
706-
],
707-
)?
708-
.build()?;
709-
710-
use crate::{OptimizerContext, OptimizerRule};
711-
let optimized = DecorrelatePredicateSubquery::new()
712-
.rewrite(plan, &OptimizerContext::new())?
713-
.data;
714-
let lp = optimized.display_indent().to_string();
715-
assert!(lp.contains("Aggregate:"));
716-
assert!(lp.contains("Left"));
717-
Ok(())
718-
}
719-
720680
/// Test for several IN subquery expressions
721681
#[test]
722682
fn in_subquery_multiple() -> Result<()> {
@@ -834,9 +794,10 @@ mod tests {
834794
LeftMark Join: Filter: a = __correlated_sq_1.ua [a:Int32;N, mark:Boolean]
835795
Projection: column1 AS a [a:Int32;N]
836796
Values: (Int32(1)), (Int32(2)) [column1:Int32;N]
837-
SubqueryAlias: __correlated_sq_1 [ua:Int32;N]
838-
Projection: column1 AS ua [ua:Int32;N]
839-
Values: (Int32(2)) [column1:Int32;N]
797+
Projection: __correlated_sq_1.ua [ua:Int32;N]
798+
SubqueryAlias: __correlated_sq_1 [ua:Int32;N]
799+
Projection: column1 AS ua [ua:Int32;N]
800+
Values: (Int32(2)) [column1:Int32;N]
840801
"
841802
)
842803
}
@@ -1924,14 +1885,13 @@ mod tests {
19241885
plan,
19251886
@r"
19261887
Projection: customer.c_custkey [c_custkey:Int64]
1927-
Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1928-
Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1929-
LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1930-
TableScan: customer [c_custkey:Int64, c_name:Utf8]
1931-
SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1932-
Projection: orders.o_custkey [o_custkey:Int64]
1933-
Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1934-
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1888+
Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1889+
LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]
1890+
TableScan: customer [c_custkey:Int64, c_name:Utf8]
1891+
SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]
1892+
Projection: orders.o_custkey [o_custkey:Int64]
1893+
Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1894+
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
19351895
"
19361896
)
19371897
}

datafusion/optimizer/src/optimize_projections/mod.rs

Lines changed: 7 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -402,9 +402,8 @@ fn optimize_projections(
402402
}
403403
LogicalPlan::Join(join) => {
404404
let left_len = join.left.schema().fields().len();
405-
let right_len = join.right.schema().fields().len();
406405
let (left_req_indices, right_req_indices) =
407-
split_join_requirements(left_len, right_len, indices, &join.join_type);
406+
split_join_requirements(left_len, indices, &join.join_type);
408407
let left_indices =
409408
left_req_indices.with_plan_exprs(&plan, join.left.schema())?;
410409
let right_indices =
@@ -729,29 +728,21 @@ fn rewrite_expr(expr: Expr, input: &Projection) -> Result<Transformed<Expr>> {
729728
/// adjusted based on the join type.
730729
fn split_join_requirements(
731730
left_len: usize,
732-
right_len: usize,
733731
indices: RequiredIndices,
734732
join_type: &JoinType,
735733
) -> (RequiredIndices, RequiredIndices) {
736734
match join_type {
737735
// In these cases requirements are split between left/right children:
738-
JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
736+
JoinType::Inner
737+
| JoinType::Left
738+
| JoinType::Right
739+
| JoinType::Full
740+
| JoinType::LeftMark
741+
| JoinType::RightMark => {
739742
// Decrease right side indices by `left_len` so that they point to valid
740743
// positions within the right child:
741744
indices.split_off(left_len)
742745
}
743-
JoinType::LeftMark => {
744-
// LeftMark output: [left_cols(0..left_len), mark]
745-
// The mark column is synthetic (produced by the join itself),
746-
// so discard it and route only to the left child.
747-
let (left_indices, _mark) = indices.split_off(left_len);
748-
(left_indices, RequiredIndices::new())
749-
}
750-
JoinType::RightMark => {
751-
// Same as LeftMark, but for the right child.
752-
let (right_indices, _mark) = indices.split_off(right_len);
753-
(RequiredIndices::new(), right_indices)
754-
}
755746
// All requirements can be re-routed to left child directly.
756747
JoinType::LeftAnti | JoinType::LeftSemi => (indices, RequiredIndices::new()),
757748
// All requirements can be re-routed to right side directly.
@@ -2366,42 +2357,6 @@ mod tests {
23662357
}
23672358

23682359
// Regression test for https://github.com/apache/datafusion/issues/20083
2369-
// Optimizer must not fail when LeftMark joins from EXISTS OR EXISTS
2370-
// feed into a Left join.
2371-
#[test]
2372-
fn optimize_projections_exists_or_exists_with_outer_join() -> Result<()> {
2373-
use datafusion_expr::utils::disjunction;
2374-
use datafusion_expr::{exists, out_ref_col};
2375-
2376-
let table_a = test_table_scan_with_name("a")?;
2377-
let table_b = test_table_scan_with_name("b")?;
2378-
2379-
let sq_a = Arc::new(
2380-
LogicalPlanBuilder::from(test_table_scan_with_name("sq_a")?)
2381-
.filter(col("sq_a.a").eq(out_ref_col(DataType::UInt32, "a.a")))?
2382-
.project(vec![lit(1)])?
2383-
.build()?,
2384-
);
2385-
2386-
let sq_b = Arc::new(
2387-
LogicalPlanBuilder::from(test_table_scan_with_name("sq_b")?)
2388-
.filter(col("sq_b.b").eq(out_ref_col(DataType::UInt32, "a.b")))?
2389-
.project(vec![lit(1)])?
2390-
.build()?,
2391-
);
2392-
2393-
let plan = LogicalPlanBuilder::from(table_a)
2394-
.filter(disjunction(vec![exists(sq_a), exists(sq_b)]).unwrap())?
2395-
.join(table_b, JoinType::Left, (vec!["a"], vec!["a"]), None)?
2396-
.build()?;
2397-
2398-
let optimizer = Optimizer::new();
2399-
let config = OptimizerContext::new();
2400-
optimizer.optimize(plan, &config, observe)?;
2401-
2402-
Ok(())
2403-
}
2404-
24052360
#[test]
24062361
fn optimize_projections_left_mark_join_with_projection() -> Result<()> {
24072362
let table_a = test_table_scan_with_name("a")?;

0 commit comments

Comments
 (0)