Skip to content

Commit b8998c7

Browse files
perf: Convert inner joins to semi joins when equivalent (#22652)
## Which issue does this PR close?

- Closes #22594

## Rationale for this change

This PR extends the `EliminateJoin` rewrite pass to replace inner joins with semi joins in some cases. An inner join `L ⋈ R` can be rewritten to a left semi join `L ⋉ R` if two conditions hold:

1. None of R's columns are referenced above the join
2. (a) each L row matches at most one R row, OR (b) the consumers of the join result are insensitive to duplicates

(And symmetrically with right semi joins.)

## What changes are included in this PR?

* Add `for_each_referenced_index` helper that is used by both `EliminateJoin` and `OptimizeProjections`
* Introduce `LiveColumns` type to track the "live" (referenced by parent) columns of a plan node
* Add inner -> semi join rewrite to `EliminateJoin`
* Add unit and SLT tests for rewrite behavior
* Update SLT test fixtures for plan changes

## Are these changes tested?

Yes; new tests added.

## Are there any user-facing changes?

Some plan changes but no behavioral changes.

---------

Co-authored-by: Daniël Heres <danielheres@gmail.com>
1 parent 3b321a2 commit b8998c7

14 files changed

Lines changed: 1294 additions & 131 deletions

File tree

datafusion/optimizer/src/eliminate_join.rs

Lines changed: 1125 additions & 28 deletions
Large diffs are not rendered by default.

datafusion/optimizer/src/optimize_projections/required_indices.rs

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
//! [`RequiredIndices`] helper for OptimizeProjections
1919
20-
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
20+
use crate::utils::for_each_referenced_index;
21+
use datafusion_common::tree_node::TreeNodeRecursion;
2122
use datafusion_common::{Column, DFSchemaRef, Result};
2223
use datafusion_expr::{Expr, LogicalPlan};
2324

@@ -112,29 +113,8 @@ impl RequiredIndices {
112113
/// * `input_schema`: The input schema to analyze for index requirements.
113114
/// * `expr`: An expression for which we want to find necessary field indices.
114115
fn add_expr(&mut self, input_schema: &DFSchemaRef, expr: &Expr) {
115-
// `apply` does not descend into subqueries, so recurse manually to
116-
// handle those cases.
117-
expr.apply(|e| {
118-
match e {
119-
Expr::Column(c) | Expr::OuterReferenceColumn(_, c) => {
120-
if let Some(idx) = input_schema.maybe_index_of_column(c) {
121-
self.indices.push(idx);
122-
}
123-
}
124-
Expr::ScalarSubquery(sub) => {
125-
self.add_exprs(input_schema, &sub.outer_ref_columns);
126-
}
127-
Expr::Exists(ex) => {
128-
self.add_exprs(input_schema, &ex.subquery.outer_ref_columns);
129-
}
130-
Expr::InSubquery(isq) => {
131-
self.add_exprs(input_schema, &isq.subquery.outer_ref_columns);
132-
}
133-
_ => {}
134-
}
135-
Ok(TreeNodeRecursion::Continue)
136-
})
137-
.expect("traversal is infallible");
116+
for_each_referenced_index(expr, input_schema, |idx| self.indices.push(idx))
117+
.expect("traversal is infallible");
138118
}
139119

140120
/// Like [`Self::add_expr`], but for multiple expressions.

datafusion/optimizer/src/utils.rs

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ use arrow::array::{Array, RecordBatch, new_null_array};
2424
use arrow::datatypes::{DataType, Field, Schema};
2525
use datafusion_common::TableReference;
2626
use datafusion_common::cast::as_boolean_array;
27-
use datafusion_common::tree_node::{TransformedResult, TreeNode};
27+
use datafusion_common::tree_node::{TransformedResult, TreeNode, TreeNodeRecursion};
2828
use datafusion_common::{Column, DFSchema, Result, ScalarValue};
2929
use datafusion_expr::execution_props::ExecutionProps;
30+
use datafusion_expr::expr::{Exists, InSubquery, SetComparison};
3031
use datafusion_expr::expr_rewriter::replace_col;
3132
use datafusion_expr::{ColumnarValue, Expr, logical_plan::LogicalPlan};
3233
use datafusion_physical_expr::create_physical_expr;
@@ -37,6 +38,56 @@ use std::sync::Arc;
3738
/// as it was initially placed here and then moved elsewhere.
3839
pub use datafusion_expr::expr_rewriter::NamePreserver;
3940

41+
/// Invokes `f` with the index, within `schema`, of every column referenced by
42+
/// `expr` — including columns reached through a correlated subquery's outer
43+
/// references. Columns absent from `schema` are skipped.
44+
///
45+
/// A subquery's own plan is intentionally not traversed: its internal columns
46+
/// index into its own schema, not `schema`; only the outer (correlated) columns
47+
/// it references from `schema` are relevant. The comparison expression of an
48+
/// `IN`/set-comparison subquery is reached by the normal expression walk.
49+
///
50+
/// This is the shared primitive behind the top-down "which of a node's output
51+
/// columns does an ancestor still need" analyses, namely
52+
/// [`OptimizeProjections`](crate::optimize_projections::OptimizeProjections)
53+
/// and [`EliminateJoin`](crate::eliminate_join::EliminateJoin). The two keep
54+
/// their own required-index containers (an ordered set vs. a hash set), so this
55+
/// reports indices through a callback rather than populating a shared type.
56+
pub(crate) fn for_each_referenced_index(
57+
expr: &Expr,
58+
schema: &DFSchema,
59+
mut f: impl FnMut(usize),
60+
) -> Result<()> {
61+
visit_referenced_indices(expr, schema, &mut f)
62+
}
63+
64+
fn visit_referenced_indices(
65+
expr: &Expr,
66+
schema: &DFSchema,
67+
f: &mut dyn FnMut(usize),
68+
) -> Result<()> {
69+
expr.apply(|expr| {
70+
match expr {
71+
Expr::Column(column) | Expr::OuterReferenceColumn(_, column) => {
72+
if let Some(idx) = schema.maybe_index_of_column(column) {
73+
f(idx);
74+
}
75+
}
76+
Expr::Exists(Exists { subquery, .. })
77+
| Expr::InSubquery(InSubquery { subquery, .. })
78+
| Expr::SetComparison(SetComparison { subquery, .. })
79+
| Expr::ScalarSubquery(subquery) => {
80+
for outer in &subquery.outer_ref_columns {
81+
visit_referenced_indices(outer, schema, f)?;
82+
}
83+
}
84+
_ => {}
85+
}
86+
Ok(TreeNodeRecursion::Continue)
87+
})?;
88+
Ok(())
89+
}
90+
4091
/// Returns true if `expr` contains all columns in `schema_cols`
4192
pub(crate) fn has_all_column_refs(
4293
expr: &Expr,

datafusion/sqllogictest/test_files/joins.slt

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,19 +1333,57 @@ inner join join_t2 on join_t1.t1_id = join_t2.t2_id
13331333
----
13341334
logical_plan
13351335
01)Aggregate: groupBy=[[join_t1.t1_id]], aggr=[[]]
1336-
02)--Projection: join_t1.t1_id
1337-
03)----Inner Join: join_t1.t1_id = join_t2.t2_id
1338-
04)------TableScan: join_t1 projection=[t1_id]
1339-
05)------TableScan: join_t2 projection=[t2_id]
1336+
02)--LeftSemi Join: join_t1.t1_id = join_t2.t2_id
1337+
03)----TableScan: join_t1 projection=[t1_id]
1338+
04)----TableScan: join_t2 projection=[t2_id]
13401339
physical_plan
13411340
01)AggregateExec: mode=FinalPartitioned, gby=[t1_id@0 as t1_id], aggr=[]
13421341
02)--RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
13431342
03)----AggregateExec: mode=Partial, gby=[t1_id@0 as t1_id], aggr=[]
1344-
04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0]
1343+
04)------HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)]
13451344
05)--------DataSourceExec: partitions=1, partition_sizes=[1]
13461345
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
13471346
07)----------DataSourceExec: partitions=1, partition_sizes=[1]
13481347

1348+
statement ok
1349+
set datafusion.explain.logical_plan_only = true;
1350+
1351+
# A single `count(DISTINCT col)` over a join whose other side is used only as an
1352+
# existence filter can be rewritten to a semi join.
1353+
query TT
1354+
EXPLAIN
1355+
select join_t1.t1_id, count(distinct join_t1.t1_int)
1356+
from join_t1
1357+
inner join join_t2 on join_t1.t1_id = join_t2.t2_id
1358+
group by join_t1.t1_id
1359+
----
1360+
logical_plan
1361+
01)Projection: join_t1.t1_id, count(alias1) AS count(DISTINCT join_t1.t1_int)
1362+
02)--Aggregate: groupBy=[[join_t1.t1_id]], aggr=[[count(alias1)]]
1363+
03)----Aggregate: groupBy=[[join_t1.t1_id, join_t1.t1_int AS alias1]], aggr=[[]]
1364+
04)------LeftSemi Join: join_t1.t1_id = join_t2.t2_id
1365+
05)--------TableScan: join_t1 projection=[t1_id, t1_int]
1366+
06)--------TableScan: join_t2 projection=[t2_id]
1367+
1368+
# A similar query with two DISTINCT aggregates is currently not rewritten
1369+
# TODO: https://github.com/apache/datafusion/issues/22644
1370+
query TT
1371+
EXPLAIN
1372+
select join_t1.t1_id, count(distinct join_t1.t1_int), count(distinct join_t1.t1_name)
1373+
from join_t1
1374+
inner join join_t2 on join_t1.t1_id = join_t2.t2_id
1375+
group by join_t1.t1_id
1376+
----
1377+
logical_plan
1378+
01)Aggregate: groupBy=[[join_t1.t1_id]], aggr=[[count(DISTINCT join_t1.t1_int), count(DISTINCT join_t1.t1_name)]]
1379+
02)--Projection: join_t1.t1_id, join_t1.t1_name, join_t1.t1_int
1380+
03)----Inner Join: join_t1.t1_id = join_t2.t2_id
1381+
04)------TableScan: join_t1 projection=[t1_id, t1_name, t1_int]
1382+
05)------TableScan: join_t2 projection=[t2_id]
1383+
1384+
statement ok
1385+
set datafusion.explain.logical_plan_only = false;
1386+
13491387
# Join on struct
13501388
query TT
13511389
explain select join_t3.s3, join_t4.s4
@@ -1411,10 +1449,9 @@ logical_plan
14111449
01)Projection: count(alias1) AS count(DISTINCT join_t1.t1_id)
14121450
02)--Aggregate: groupBy=[[]], aggr=[[count(alias1)]]
14131451
03)----Aggregate: groupBy=[[join_t1.t1_id AS alias1]], aggr=[[]]
1414-
04)------Projection: join_t1.t1_id
1415-
05)--------Inner Join: join_t1.t1_id = join_t2.t2_id
1416-
06)----------TableScan: join_t1 projection=[t1_id]
1417-
07)----------TableScan: join_t2 projection=[t2_id]
1452+
04)------LeftSemi Join: join_t1.t1_id = join_t2.t2_id
1453+
05)--------TableScan: join_t1 projection=[t1_id]
1454+
06)--------TableScan: join_t2 projection=[t2_id]
14181455
physical_plan
14191456
01)ProjectionExec: expr=[count(alias1)@0 as count(DISTINCT join_t1.t1_id)]
14201457
02)--AggregateExec: mode=Final, gby=[], aggr=[count(alias1)]
@@ -1423,7 +1460,7 @@ physical_plan
14231460
05)--------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[]
14241461
06)----------RepartitionExec: partitioning=Hash([alias1@0], 2), input_partitions=2
14251462
07)------------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[]
1426-
08)--------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0]
1463+
08)--------------HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)]
14271464
09)----------------DataSourceExec: partitions=1, partition_sizes=[1]
14281465
10)----------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
14291466
11)------------------DataSourceExec: partitions=1, partition_sizes=[1]

datafusion/sqllogictest/test_files/subquery.slt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -338,13 +338,13 @@ where c_acctbal < (
338338
logical_plan
339339
01)Sort: customer.c_custkey ASC NULLS LAST
340340
02)--Projection: customer.c_custkey
341-
03)----Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.sum(orders.o_totalprice)
341+
03)----LeftSemi Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.sum(orders.o_totalprice)
342342
04)------TableScan: customer projection=[c_custkey, c_acctbal]
343343
05)------SubqueryAlias: __scalar_sq_1
344344
06)--------Projection: sum(orders.o_totalprice), orders.o_custkey
345345
07)----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]]
346346
08)------------Projection: orders.o_custkey, orders.o_totalprice
347-
09)--------------Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_2.price
347+
09)--------------LeftSemi Join: orders.o_orderkey = __scalar_sq_2.l_orderkey Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_2.price
348348
10)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice]
349349
11)----------------SubqueryAlias: __scalar_sq_2
350350
12)------------------Projection: sum(lineitem.l_extendedprice) AS price, lineitem.l_orderkey
@@ -555,7 +555,7 @@ logical_plan
555555
02)--TableScan: t0 projection=[t0_id, t0_name]
556556
03)--SubqueryAlias: __correlated_sq_2
557557
04)----Projection: t1.t1_name
558-
05)------Inner Join: t1.t1_id = t2.t2_id
558+
05)------LeftSemi Join: t1.t1_id = t2.t2_id
559559
06)--------TableScan: t1 projection=[t1_id, t1_name]
560560
07)--------TableScan: t2 projection=[t2_id]
561561

@@ -568,7 +568,7 @@ logical_plan
568568
02)--TableScan: t0 projection=[t0_id, t0_name]
569569
03)--SubqueryAlias: __correlated_sq_1
570570
04)----Projection: t2.t2_name
571-
05)------Inner Join: t1.t1_id = t2.t2_id
571+
05)------RightSemi Join: t1.t1_id = t2.t2_id
572572
06)--------TableScan: t1 projection=[t1_id]
573573
07)--------SubqueryAlias: t2
574574
08)----------TableScan: t2 projection=[t2_id, t2_name]
@@ -1675,7 +1675,7 @@ where c_acctbal < (
16751675
logical_plan
16761676
01)Sort: customer.c_custkey ASC NULLS LAST
16771677
02)--Projection: customer.c_custkey
1678-
03)----Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice)
1678+
03)----LeftSemi Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice)
16791679
04)------TableScan: customer projection=[c_custkey, c_acctbal]
16801680
05)------SubqueryAlias: __scalar_sq_2
16811681
06)--------Projection: sum(orders.o_totalprice), orders.o_custkey
@@ -1701,7 +1701,7 @@ where c_acctbal < (
17011701
logical_plan
17021702
01)Sort: customer.c_custkey ASC NULLS LAST
17031703
02)--Projection: customer.c_custkey
1704-
03)----Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice)
1704+
03)----LeftSemi Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice)
17051705
04)------TableScan: customer projection=[c_custkey, c_acctbal]
17061706
05)------SubqueryAlias: __scalar_sq_2
17071707
06)--------Projection: sum(orders.o_totalprice), orders.o_custkey
@@ -1746,7 +1746,7 @@ WHERE e1.salary > (
17461746
----
17471747
logical_plan
17481748
01)Projection: e1.employee_name, e1.salary
1749-
02)--Inner Join: e1.dept_id = __scalar_sq_1.dept_id Filter: CAST(e1.salary AS Decimal128(38, 14)) > __scalar_sq_1.avg(e2.salary)
1749+
02)--LeftSemi Join: e1.dept_id = __scalar_sq_1.dept_id Filter: CAST(e1.salary AS Decimal128(38, 14)) > __scalar_sq_1.avg(e2.salary)
17501750
03)----SubqueryAlias: e1
17511751
04)------TableScan: employees projection=[employee_name, dept_id, salary]
17521752
05)----SubqueryAlias: __scalar_sq_1

datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ logical_plan
5454
05)--------Projection: CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15))
5555
06)----------Aggregate: groupBy=[[]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]]
5656
07)------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost
57-
08)--------------Inner Join: supplier.s_nationkey = nation.n_nationkey
57+
08)--------------LeftSemi Join: supplier.s_nationkey = nation.n_nationkey
5858
09)----------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey
5959
10)------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
6060
11)--------------------TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost]
@@ -64,7 +64,7 @@ logical_plan
6464
15)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")]
6565
16)------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]]
6666
17)--------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost
67-
18)----------Inner Join: supplier.s_nationkey = nation.n_nationkey
67+
18)----------LeftSemi Join: supplier.s_nationkey = nation.n_nationkey
6868
19)------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey
6969
20)--------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
7070
21)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost]
@@ -81,7 +81,7 @@ physical_plan
8181
06)----------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)]
8282
07)------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4
8383
08)--------------AggregateExec: mode=Partial, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)]
84-
09)----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)], projection=[ps_partkey@0, ps_availqty@1, ps_supplycost@2]
84+
09)----------------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(s_nationkey@3, n_nationkey@0)], projection=[ps_partkey@0, ps_availqty@1, ps_supplycost@2]
8585
10)------------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4
8686
11)--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)], projection=[ps_partkey@0, ps_availqty@2, ps_supplycost@3, s_nationkey@5]
8787
12)----------------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 4), input_partitions=4
@@ -96,7 +96,7 @@ physical_plan
9696
21)----AggregateExec: mode=Final, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)]
9797
22)------CoalescePartitionsExec
9898
23)--------AggregateExec: mode=Partial, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)]
99-
24)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)], projection=[ps_availqty@0, ps_supplycost@1]
99+
24)----------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(s_nationkey@2, n_nationkey@0)], projection=[ps_availqty@0, ps_supplycost@1]
100100
25)------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4
101101
26)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@0, s_suppkey@0)], projection=[ps_availqty@1, ps_supplycost@2, s_nationkey@4]
102102
27)----------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4

0 commit comments

Comments
 (0)