|
17 | 17 |
|
18 | 18 | //! [`DecorrelateLateralJoin`] decorrelates logical plans produced by lateral joins. |
19 | 19 |
|
20 | | -use std::collections::BTreeSet; |
| 20 | +use std::sync::Arc; |
21 | 21 |
|
22 | | -use crate::decorrelate::PullUpCorrelatedExpr; |
| 22 | +use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR}; |
23 | 23 | use crate::optimizer::ApplyOrder; |
| 24 | +use crate::utils::evaluates_to_null; |
24 | 25 | use crate::{OptimizerConfig, OptimizerRule}; |
25 | | -use datafusion_expr::{Join, lit}; |
| 26 | +use datafusion_expr::{Expr, Join, expr}; |
26 | 27 |
|
27 | | -use datafusion_common::Result; |
28 | 28 | use datafusion_common::tree_node::{ |
29 | 29 | Transformed, TransformedResult, TreeNode, TreeNodeRecursion, |
30 | 30 | }; |
31 | | -use datafusion_expr::logical_plan::JoinType; |
| 31 | +use datafusion_common::{Column, DFSchema, Result, TableReference}; |
| 32 | +use datafusion_expr::logical_plan::{JoinType, Subquery}; |
32 | 33 | use datafusion_expr::utils::conjunction; |
33 | | -use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; |
| 34 | +use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, SubqueryAlias}; |
34 | 35 |
|
35 | 36 | /// Optimizer rule for rewriting lateral joins to joins |
36 | 37 | #[derive(Default, Debug)] |
@@ -70,74 +71,206 @@ impl OptimizerRule for DecorrelateLateralJoin { |
70 | 71 | } |
71 | 72 | } |
72 | 73 |
|
73 | | -// Build the decorrelated join based on the original lateral join query. For now, we only support cross/inner |
74 | | -// lateral joins. |
| 74 | +// Build the decorrelated join that replaces the original lateral join. Only
| 75 | +// inner/cross lateral joins are rewritten; anything else is returned unchanged.
75 | 76 | fn rewrite_internal(join: Join) -> Result<Transformed<LogicalPlan>> {
| 77 | + // TODO: Support outer joins
| 78 | + // <https://github.com/apache/datafusion/issues/21199>
76 | 79 | if join.join_type != JoinType::Inner {
77 | 80 | return Ok(Transformed::no(LogicalPlan::Join(join)));
78 | 81 | }
79 | 82 |
|
80 | | - match join.right.apply_with_subqueries(|p| {
81 | | - // TODO: support outer joins
82 | | - if p.contains_outer_reference() {
83 | | - Ok(TreeNodeRecursion::Stop)
84 | | - } else {
85 | | - Ok(TreeNodeRecursion::Continue)
86 | | - }
87 | | - })? {
88 | | - TreeNodeRecursion::Stop => {}
89 | | - TreeNodeRecursion::Continue => {
90 | | - // The left side contains outer references, we need to decorrelate it.
91 | | - return Ok(Transformed::new(
92 | | - LogicalPlan::Join(join),
93 | | - false,
94 | | - TreeNodeRecursion::Jump,
95 | | - ));
96 | | - }
97 | | - TreeNodeRecursion::Jump => {
98 | | - unreachable!("")
99 | | - }
100 | | - }
101 | | -
102 | | - let LogicalPlan::Subquery(subquery) = join.right.as_ref() else {
| 83 | + // The right side is wrapped in a Subquery node (optionally under a SubqueryAlias)
| 84 | + // when it contains outer references. Quickly skip joins that don't have this structure.
| 85 | + let Some((subquery, alias)) = extract_lateral_subquery(join.right.as_ref()) else {
103 | 86 | return Ok(Transformed::no(LogicalPlan::Join(join)));
104 | 87 | };
105 | 88 |
|
106 | | - if join.join_type != JoinType::Inner {
| 89 | + // If the subquery has no outer references, there is nothing to decorrelate.
| 90 | + // A LATERAL with no outer references is just a cross join.
| 91 | + let has_outer_refs = matches!(
| 92 | + subquery.subquery.apply_with_subqueries(|p| {
| 93 | + if p.contains_outer_reference() {
| 94 | + Ok(TreeNodeRecursion::Stop)
| 95 | + } else {
| 96 | + Ok(TreeNodeRecursion::Continue)
| 97 | + }
| 98 | + })?,
| 99 | + TreeNodeRecursion::Stop
| 100 | + );
| 101 | + if !has_outer_refs {
107 | 102 | return Ok(Transformed::no(LogicalPlan::Join(join)));
108 | 103 | }
| 104 | +
109 | 105 | let subquery_plan = subquery.subquery.as_ref();
| 106 | + let original_join_filter = join.filter.clone();
| 107 | +
| 108 | + // Walk the subquery plan bottom-up, extracting correlated filter
| 109 | + // predicates into join conditions and converting scalar aggregates
| 110 | + // into group-by aggregates keyed on the correlation columns.
110 | 111 | let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true);
111 | 112 | let rewritten_subquery = subquery_plan.clone().rewrite(&mut pull_up).data()?;
112 | 113 | if !pull_up.can_pull_up {
113 | 114 | return Ok(Transformed::no(LogicalPlan::Join(join)));
114 | 115 | }
115 | 116 |
|
116 | | - let mut all_correlated_cols = BTreeSet::new();
117 | | - pull_up
118 | | - .correlated_subquery_cols_map
119 | | - .values()
120 | | - .for_each(|cols| all_correlated_cols.extend(cols.clone()));
121 | | - let join_filter_opt = conjunction(pull_up.join_filters);
122 | | - let join_filter = match join_filter_opt {
123 | | - Some(join_filter) => join_filter,
124 | | - None => lit(true),
| 117 | + // TODO: support HAVING in lateral subqueries.
| 118 | + // <https://github.com/apache/datafusion/issues/21198>
| 119 | + if pull_up.pull_up_having_expr.is_some() {
| 120 | + return Ok(Transformed::no(LogicalPlan::Join(join)));
| 121 | + }
| 122 | +
| 123 | + // We apply the correlation predicates (extracted from the subquery's WHERE)
| 124 | + // as the ON clause of the rewritten join. The original ON clause is applied
| 125 | + // as a post-join predicate. Semantically, this is important when the join
| 126 | + // is rewritten as a left join; we only want outer join semantics for the
| 127 | + // correlation predicates (which is required for "count bug" handling), not
| 128 | + // the original join predicates.
| 129 | + let correlation_filter = conjunction(pull_up.join_filters);
| 130 | +
| 131 | + // Look up each aggregate's default value on empty input (e.g., COUNT → 0,
| 132 | + // SUM → NULL). This must happen before wrapping in SubqueryAlias, because
| 133 | + // the map is keyed by LogicalPlan and wrapping changes the plan.
| 134 | + let collected_count_expr_map = pull_up
| 135 | + .collected_count_expr_map
| 136 | + .get(&rewritten_subquery)
| 137 | + .cloned();
| 138 | +
| 139 | + // Re-wrap in SubqueryAlias if the original had one, preserving the alias name.
| 140 | + // The SubqueryAlias re-qualifies all columns with the alias, so we must also
| 141 | + // rewrite column references in both the correlation and ON-clause filters.
| 142 | + let (right_plan, correlation_filter, original_join_filter) =
| 143 | + if let Some(ref alias) = alias {
| 144 | + let inner_schema = Arc::clone(rewritten_subquery.schema());
| 145 | + let right = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new(
| 146 | + Arc::new(rewritten_subquery),
| 147 | + alias.clone(),
| 148 | + )?);
| 149 | + let corr = correlation_filter
| 150 | + .map(|f| requalify_filter(f, &inner_schema, alias))
| 151 | + .transpose()?;
| 152 | + let on = original_join_filter
| 153 | + .map(|f| requalify_filter(f, &inner_schema, alias))
| 154 | + .transpose()?;
| 155 | + (right, corr, on)
| 156 | + } else {
| 157 | + (rewritten_subquery, correlation_filter, original_join_filter)
| 158 | + };
| 159 | +
| 160 | + // Use a left join when a scalar aggregation was pulled up (preserves
| 161 | + // outer rows with no matches), otherwise keep inner join.
| 162 | + // SELECT * FROM t0, LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0); → left join
| 163 | + // SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t0.v0 = t1.v0); → inner join
| 164 | + let join_type = if pull_up.pulled_up_scalar_agg {
| 165 | + JoinType::Left
| 166 | + } else {
| 167 | + JoinType::Inner
125 | 168 | };
126 | | - // -- inner join but the right side always has one row, we need to rewrite it to a left join
127 | | - // SELECT * FROM t0, LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0);
128 | | - // -- inner join but the right side number of rows is related to the filter (join) condition, so keep inner join.
129 | | - // SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t0.v0 = t1.v0);
| 169 | + let left_field_count = join.left.schema().fields().len();
130 | 170 | let new_plan = LogicalPlanBuilder::from(join.left)
131 | | - .join_on(
132 | | - rewritten_subquery,
133 | | - if pull_up.pulled_up_scalar_agg {
134 | | - JoinType::Left
135 | | - } else {
136 | | - JoinType::Inner
137 | | - },
138 | | - Some(join_filter),
139 | | - )?
| 171 | + .join_on(right_plan, join_type, correlation_filter)?
140 | 172 | .build()?;
141 | | - // TODO: handle count(*) bug
| 173 | +
| 174 | + // Handle the count bug: after a left join, unmatched outer rows get NULLs
| 175 | + // for all right-side columns. But COUNT(*) over an empty group should
| 176 | + // return 0, not NULL. Add a projection that wraps affected expressions:
| 177 | + // CASE WHEN <UN_MATCHED_ROW_INDICATOR column> IS NULL THEN <default> ELSE <column> END
| 178 | + let new_plan = if let Some(expr_map) = collected_count_expr_map {
| 179 | + let join_schema = new_plan.schema();
| 180 | + let alias_qualifier = alias.as_ref();
| 181 | + let mut proj_exprs: Vec<Expr> = vec![];
| 182 | +
| 183 | + for (i, (qualifier, field)) in join_schema.iter().enumerate() {
| 184 | + let col = Expr::Column(Column::new(qualifier.cloned(), field.name()));
| 185 | +
| 186 | + // Only compensate right-side (subquery) fields. Left-side fields
| 187 | + // may share a name with an aggregate alias but must not be wrapped.
| 188 | + let name = field.name();
| 189 | + if i >= left_field_count
| 190 | + && let Some(default_value) = expr_map.get(name.as_str())
| 191 | + && !evaluates_to_null(default_value.clone(), default_value.column_refs())?
| 192 | + {
| 193 | + // The default for this column does not evaluate to NULL (e.g., COUNT
| 194 | + // returns 0), so the NULL produced for an unmatched row must be replaced.
| 195 | + let indicator_col =
| 196 | + Column::new(alias_qualifier.cloned(), UN_MATCHED_ROW_INDICATOR);
| 197 | + let case_expr = Expr::Case(expr::Case {
| 198 | + expr: None,
| 199 | + when_then_expr: vec![(
| 200 | + Box::new(Expr::IsNull(Box::new(Expr::Column(indicator_col)))),
| 201 | + Box::new(default_value.clone()),
| 202 | + )],
| 203 | + else_expr: Some(Box::new(col)),
| 204 | + });
| 205 | + proj_exprs.push(Expr::Alias(expr::Alias {
| 206 | + expr: Box::new(case_expr),
| 207 | + relation: qualifier.cloned(),
| 208 | + name: name.to_string(),
| 209 | + metadata: None,
| 210 | + }));
| 211 | + continue;
| 212 | + }
| 213 | + proj_exprs.push(col);
| 214 | + }
| 215 | +
| 216 | + LogicalPlanBuilder::from(new_plan)
| 217 | + .project(proj_exprs)?
| 218 | + .build()?
| 219 | + } else {
| 220 | + new_plan
| 221 | + };
| 222 | +
| 223 | + // Apply the original ON clause as a post-join filter.
| 224 | + let new_plan = if let Some(on_filter) = original_join_filter {
| 225 | + LogicalPlanBuilder::from(new_plan)
| 226 | + .filter(on_filter)?
| 227 | + .build()?
| 228 | + } else {
| 229 | + new_plan
| 230 | + };
| 231 | +
142 | 232 | Ok(Transformed::new(new_plan, true, TreeNodeRecursion::Jump))
143 | 233 | }
| 234 | + |
| 235 | +/// Extract the `Subquery` from a lateral join's right side, plus the alias when it is wrapped in a `SubqueryAlias`; returns `None` for any other plan shape.
| 236 | +fn extract_lateral_subquery(
| 237 | + plan: &LogicalPlan,
| 238 | +) -> Option<(Subquery, Option<TableReference>)> {
| 239 | + match plan {
| 240 | + LogicalPlan::Subquery(sq) => Some((sq.clone(), None)),
| 241 | + LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => {
| 242 | + if let LogicalPlan::Subquery(sq) = input.as_ref() {
| 243 | + Some((sq.clone(), Some(alias.clone())))
| 244 | + } else {
| 245 | + None
| 246 | + }
| 247 | + }
| 248 | + _ => None,
| 249 | + }
| 250 | +}
| 251 | + |
| 252 | +/// Rewrite column references in a join filter so that columns belonging to
| 253 | +/// the inner (right) side use the `SubqueryAlias` qualifier.
| 254 | +///
| 255 | +/// `PullUpCorrelatedExpr` extracts join filters with the inner columns still
| 256 | +/// qualified by their original table names (e.g., `t2.t1_id`), but wrapping the
| 257 | +/// inner plan in a `SubqueryAlias("sub")` re-qualifies them as `sub.t1_id`.
| 258 | +/// Every column that `inner_schema.has_column` accepts is rebuilt with `alias` as
| 259 | +/// its qualifier; others are unchanged. NOTE(review): confirm a left-side column can never resolve in `inner_schema`.
| 260 | +fn requalify_filter(
| 261 | + filter: Expr,
| 262 | + inner_schema: &DFSchema,
| 263 | + alias: &TableReference,
| 264 | +) -> Result<Expr> {
| 265 | + filter
| 266 | + .transform(|expr| {
| 267 | + if let Expr::Column(col) = &expr
| 268 | + && inner_schema.has_column(col)
| 269 | + {
| 270 | + let new_col = Column::new(Some(alias.clone()), col.name.clone());
| 271 | + return Ok(Transformed::yes(Expr::Column(new_col)));
| 272 | + }
| 273 | + Ok(Transformed::no(expr))
| 274 | + })
| 275 | + .data()
| 276 | +}
0 commit comments