Skip to content

Commit a40a555

Browse files
committed
support semi joins
1 parent 4a92804 commit a40a555

3 files changed

Lines changed: 589 additions & 29 deletions

File tree

datafusion/optimizer/src/reorder_join/cost.rs

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ use datafusion_expr::{Expr, JoinType, LogicalPlan};
2020

2121
use super::join_graph::Edge;
2222

23+
/// Fraction of preserved-side rows estimated to survive a semi/anti join
24+
/// when column NDV statistics are unavailable. Mirrors DuckDB's
25+
/// `CardinalityEstimator::DEFAULT_SEMI_ANTI_SELECTIVITY = 1/5`.
26+
const DEFAULT_SEMI_ANTI_SELECTIVITY: f64 = 0.2;
27+
2328
pub trait JoinCostEstimator: std::fmt::Debug {
2429
/// Cardinality of `plan`.
2530
///
@@ -32,14 +37,27 @@ pub trait JoinCostEstimator: std::fmt::Debug {
3237

3338
/// Estimated selectivity of joining `left` with `right` via `edge`.
3439
///
35-
/// Default: `1 / max(NDV(left.key), NDV(right.key))` for inner equi-joins
36-
/// when both NDVs are available; otherwise a per-join-type constant.
40+
/// Default: `1 / max(NDV(left.key), NDV(right.key))` for equi-joins
41+
/// (inner and semi/anti) when both NDVs are available; otherwise a
42+
/// per-join-type constant.
3743
fn selectivity(&self, edge: &Edge, left: &LogicalPlan, right: &LogicalPlan) -> f64 {
3844
let fallback = match edge.join_type {
3945
JoinType::Inner => 0.1,
46+
JoinType::LeftSemi
47+
| JoinType::LeftAnti
48+
| JoinType::RightSemi
49+
| JoinType::RightAnti => DEFAULT_SEMI_ANTI_SELECTIVITY,
4050
_ => 1.0,
4151
};
42-
if edge.join_type != JoinType::Inner || edge.on.is_empty() {
52+
let is_eq_join = matches!(
53+
edge.join_type,
54+
JoinType::Inner
55+
| JoinType::LeftSemi
56+
| JoinType::LeftAnti
57+
| JoinType::RightSemi
58+
| JoinType::RightAnti
59+
);
60+
if !is_eq_join || edge.on.is_empty() {
4361
return fallback;
4462
}
4563
// Use only the first equi-pair. Compounding pairwise selectivities
@@ -51,8 +69,26 @@ pub trait JoinCostEstimator: std::fmt::Debug {
5169
};
5270
let ndv_a = ndv_for(self, col_a, left, right);
5371
let ndv_b = ndv_for(self, col_b, left, right);
54-
match (ndv_a, ndv_b) {
55-
(Some(a), Some(b)) if a.max(b) > 0.0 => 1.0 / a.max(b),
72+
match edge.join_type {
73+
JoinType::Inner => match (ndv_a, ndv_b) {
74+
(Some(a), Some(b)) if a.max(b) > 0.0 => 1.0 / a.max(b),
75+
_ => fallback,
76+
},
77+
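// Semi/anti containment estimator: fraction of preserved-side rows
78+
// with a match ≈ `min(NDV_preserved, NDV_filtering) / NDV_preserved`. NOTE(review): an anti join keeps the unmatched rows (the complement of this fraction), but the arms below return the same value for semi and anti — confirm this is intended.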
79+
// Edges normalized by `flatten_joins_recursive` always have
80+
// `on = (preserved_key, filtering_key)`, so the preserved
81+
// NDV is `ndv_a` for Left{Semi,Anti}. RightSemi/RightAnti
82+
// shouldn't appear in graph edges (they get normalized) but
83+
// are handled defensively.
84+
JoinType::LeftSemi | JoinType::LeftAnti => match (ndv_a, ndv_b) {
85+
(Some(a), Some(b)) if a > 0.0 => (a.min(b) / a).min(1.0),
86+
_ => fallback,
87+
},
88+
JoinType::RightSemi | JoinType::RightAnti => match (ndv_a, ndv_b) {
89+
(Some(a), Some(b)) if b > 0.0 => (a.min(b) / b).min(1.0),
90+
_ => fallback,
91+
},
5692
_ => fallback,
5793
}
5894
}
@@ -172,6 +208,31 @@ fn estimate_cardinality(plan: &LogicalPlan, column: Option<&Column>) -> Result<f
172208
}
173209
}
174210
}
211+
// Semi/anti joins do not grow rows: the output cardinality is
212+
// bounded by the preserved side. We size them via the
213+
// `DEFAULT_SEMI_ANTI_SELECTIVITY` heuristic. NDV queries on the
214+
// output route to whichever side is preserved.
215+
LogicalPlan::Join(j)
216+
if matches!(
217+
j.join_type,
218+
JoinType::LeftSemi
219+
| JoinType::LeftAnti
220+
| JoinType::RightSemi
221+
| JoinType::RightAnti
222+
) =>
223+
{
224+
let preserved = match j.join_type {
225+
JoinType::LeftSemi | JoinType::LeftAnti => &j.left,
226+
_ => &j.right,
227+
};
228+
match column {
229+
None => {
230+
let rows = estimate_cardinality(preserved, None)?;
231+
Ok(rows * DEFAULT_SEMI_ANTI_SELECTIVITY)
232+
}
233+
Some(c) => estimate_cardinality(preserved, Some(c)),
234+
}
235+
}
175236
x => {
176237
let inputs = x.inputs();
177238
if inputs.len() == 1 {

0 commit comments

Comments
 (0)