Skip to content

Commit 05ea11e

Browse files
authored
chore: Cleanup and refactor build_join in ScalarSubqueryToJoin (#22316)
## Which issue does this PR close? - N/A ## Rationale for this change This PR cleans up and refactors `build_join`, which is used as part of rewriting correlated subqueries. ## What changes are included in this PR? * This routine only needs to handle correlated subqueries now, so simplify the code and add an assert to that effect * Clarify variable names * Improve comments * Hoist a few variables outside of loops, when possible * Use `when`, `lit` and `not` helpers to build the `CASE` expression ## Are these changes tested? Yes, covered by existing tests. ## Are there any user-facing changes? No, no functional changes at all.
1 parent 541119e commit 05ea11e

1 file changed

Lines changed: 85 additions & 103 deletions

File tree

datafusion/optimizer/src/scalar_subquery_to_join.rs

Lines changed: 85 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use datafusion_common::{Column, Result, ScalarValue, assert_or_internal_err, pla
3434
use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
3535
use datafusion_expr::logical_plan::{JoinType, Subquery};
3636
use datafusion_expr::utils::conjunction;
37-
use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder, expr};
37+
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, lit, not, when};
3838

3939
/// Optimizer rule that rewrites correlated scalar subquery filters to joins and
4040
/// places an additional projection on top of the filter, to preserve the
@@ -107,18 +107,17 @@ impl OptimizerRule for ScalarSubqueryToJoin {
107107
// iterate through all subqueries in predicate, turning each into a left join
108108
let mut cur_input = filter.input.as_ref().clone();
109109
for (subquery, alias) in subqueries {
110-
if let Some((optimized_subquery, expr_check_map)) =
110+
if let Some((optimized_subquery, compensation_exprs)) =
111111
build_join(&subquery, &cur_input, &alias)?
112112
{
113-
if !expr_check_map.is_empty() {
113+
if !compensation_exprs.is_empty() {
114114
rewrite_expr = rewrite_expr
115115
.transform_up(|expr| {
116-
// replace column references with entry in map, if it exists
117-
if let Some(map_expr) = expr
116+
if let Some(compensation_expr) = expr
118117
.try_as_col()
119-
.and_then(|col| expr_check_map.get(col))
118+
.and_then(|col| compensation_exprs.get(col))
120119
{
121-
Ok(Transformed::yes(map_expr.clone()))
120+
Ok(Transformed::yes(compensation_expr.clone()))
122121
} else {
123122
Ok(Transformed::no(expr))
124123
}
@@ -172,22 +171,21 @@ impl OptimizerRule for ScalarSubqueryToJoin {
172171
// iterate through all subqueries in predicate, turning each into a left join
173172
let mut cur_input = projection.input.as_ref().clone();
174173
for (subquery, alias) in all_subqueries {
175-
if let Some((optimized_subquery, expr_check_map)) =
174+
if let Some((optimized_subquery, compensation_exprs)) =
176175
build_join(&subquery, &cur_input, &alias)?
177176
{
178177
cur_input = optimized_subquery;
179-
if !expr_check_map.is_empty()
178+
if !compensation_exprs.is_empty()
180179
&& let Some(&idx) = alias_to_index.get(&alias)
181180
{
182181
let new_expr = rewrite_exprs[idx]
183182
.clone()
184183
.transform_up(|expr| {
185-
// replace column references with entry in map, if it exists
186-
if let Some(map_expr) = expr
184+
if let Some(compensation_expr) = expr
187185
.try_as_col()
188-
.and_then(|col| expr_check_map.get(col))
186+
.and_then(|col| compensation_exprs.get(col))
189187
{
190-
Ok(Transformed::yes(map_expr.clone()))
188+
Ok(Transformed::yes(compensation_expr.clone()))
191189
} else {
192190
Ok(Transformed::no(expr))
193191
}
@@ -285,133 +283,117 @@ impl TreeNodeRewriter for ExtractScalarSubQuery<'_> {
285283
///
286284
/// ```text
287285
/// select c.id from customers c
288-
/// left join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.c_id
289-
/// where c.balance > o.val
290-
/// ```
291-
///
292-
/// Or a query like:
293-
///
294-
/// ```text
295-
/// select id from customers where balance >
296-
/// (select avg(total) from orders)
297-
/// ```
298-
///
299-
/// and optimizes it into:
300-
///
301-
/// ```text
302-
/// select c.id from customers c
303-
/// left join (select avg(total) as val from orders) a
304-
/// where c.balance > a.val
286+
/// left join (select c_id, avg(total) from orders group by c_id) o
287+
/// on o.c_id = c.id
288+
/// where c.balance > o."avg(total)"
305289
/// ```
306290
///
307291
/// # Arguments
308292
///
309-
/// * `query_info` - The subquery portion of the `where` (select avg(total) from orders)
310-
/// * `filter_input` - The non-subquery portion (from customers)
311-
/// * `outer_others` - Any additional parts to the `where` expression (and c.x = y)
312-
/// * `subquery_alias` - Subquery aliases
293+
/// * `subquery` - The correlated scalar subquery to decorrelate.
294+
/// * `outer_input` - The outer plan that the decorrelated subquery is
295+
/// left-joined onto — the input of the `Filter` or `Projection` node
296+
/// that contained the subquery.
297+
/// * `subquery_alias` - The unique alias assigned to the decorrelated
298+
/// subquery; used both to qualify the join condition and to produce
299+
/// column references for the caller to substitute.
300+
///
301+
/// Returns `Ok(None)` if the subquery cannot be decorrelated. On success,
302+
/// returns the rewritten outer plan and a map from each count-bug-affected
303+
/// column to its `CASE WHEN __always_true IS NULL THEN ... END` compensation
304+
/// expression, which the caller must substitute into any expression that
305+
/// references those columns.
313306
fn build_join(
314307
subquery: &Subquery,
315-
filter_input: &LogicalPlan,
308+
outer_input: &LogicalPlan,
316309
subquery_alias: &str,
317310
) -> Result<Option<(LogicalPlan, HashMap<Column, Expr>)>> {
311+
assert_or_internal_err!(
312+
!subquery.outer_ref_columns.is_empty(),
313+
"build_join should only be called for correlated subqueries"
314+
);
318315
let subquery_plan = subquery.subquery.as_ref();
319316
let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true);
320-
let new_plan = subquery_plan.clone().rewrite(&mut pull_up).data()?;
317+
let decorrelated_subquery = subquery_plan.clone().rewrite(&mut pull_up).data()?;
321318
if !pull_up.can_pull_up {
322319
return Ok(None);
323320
}
324321

325-
let collected_count_expr_map =
326-
pull_up.collected_count_expr_map.get(&new_plan).cloned();
327-
let sub_query_alias = LogicalPlanBuilder::from(new_plan)
322+
let collected_count_expr_map = pull_up
323+
.collected_count_expr_map
324+
.get(&decorrelated_subquery)
325+
.cloned();
326+
let aliased_subquery = LogicalPlanBuilder::from(decorrelated_subquery)
328327
.alias(subquery_alias.to_string())?
329328
.build()?;
330329

331-
let mut all_correlated_cols = BTreeSet::new();
332-
pull_up
330+
let all_correlated_cols: BTreeSet<Column> = pull_up
333331
.correlated_subquery_cols_map
334332
.values()
335-
.for_each(|cols| all_correlated_cols.extend(cols.clone()));
333+
.flatten()
334+
.cloned()
335+
.collect();
336336

337-
// alias the join filter
337+
// Correlated columns now live in the decorrelated subquery's output,
338+
// so re-qualify them with the subquery alias.
338339
let join_filter_opt =
339340
conjunction(pull_up.join_filters).map_or(Ok(None), |filter| {
340341
replace_qualified_name(filter, &all_correlated_cols, subquery_alias).map(Some)
341342
})?;
342343

343-
// join our sub query into the main plan
344-
let new_plan = if join_filter_opt.is_none() {
345-
match filter_input {
346-
LogicalPlan::EmptyRelation(EmptyRelation {
347-
produce_one_row: true,
348-
schema: _,
349-
}) => sub_query_alias,
350-
_ => {
351-
// if not correlated, group down to 1 row and left join on that (preserving row count)
352-
LogicalPlanBuilder::from(filter_input.clone())
353-
.join_on(
354-
sub_query_alias,
355-
JoinType::Left,
356-
vec![Expr::Literal(ScalarValue::Boolean(Some(true)), None)],
357-
)?
358-
.build()?
359-
}
360-
}
361-
} else {
362-
// left join if correlated, grouping by the join keys so we don't change row count
363-
LogicalPlanBuilder::from(filter_input.clone())
364-
.join_on(sub_query_alias, JoinType::Left, join_filter_opt)?
365-
.build()?
366-
};
367-
let mut computation_project_expr = HashMap::new();
344+
// When pull-up did not extract any usable join keys (a correlated subquery
345+
// whose predicate references only outer columns), fall back to `ON true`:
346+
// the decorrelated subquery still yields at most one row per outer row
347+
// because its aggregate is grouped by the (empty) set of correlated inner
348+
// columns.
349+
let join_filter = join_filter_opt.or_else(|| Some(lit(true)));
350+
351+
let new_plan = LogicalPlanBuilder::from(outer_input.clone())
352+
.join_on(aliased_subquery, JoinType::Left, join_filter)?
353+
.build()?;
354+
355+
// Add count-bug compensation for each of the subquery's projected
356+
// expressions that yield non-NULL values on empty input. We wrap each
357+
// such expression in a CASE that substitutes the empty-input value
358+
// when the LEFT JOIN produced synthetic right-side NULLs (no inner
359+
// row matched), and uses the actual right-side value (which may
360+
// itself be NULL) otherwise.
361+
let mut compensation_exprs = HashMap::new();
368362
if let Some(expr_map) = collected_count_expr_map {
363+
let mut expr_rewrite = TypeCoercionRewriter {
364+
schema: new_plan.schema(),
365+
};
366+
let having_arm = pull_up
367+
.pull_up_having_expr
368+
.as_ref()
369+
.map(|f| (not(f.clone()), lit(ScalarValue::Null)));
369370
for (name, result) in expr_map {
370371
if evaluates_to_null(result.clone(), result.column_refs())? {
371-
// If expr always returns null when column is null, skip processing
372+
// Aggregates whose empty-input value is NULL (max/min/sum/…)
373+
// need no compensation: the LEFT JOIN already produces NULL
374+
// for unmatched outer rows.
372375
continue;
373376
}
374377

375378
let indicator_col =
376379
Column::new(Some(subquery_alias), UN_MATCHED_ROW_INDICATOR);
377380
// Qualify with the subquery alias to avoid ambiguity when the
378381
// outer table has a column with the same name as the aggregate.
379-
let value_col = Column::new(Some(subquery_alias), name.clone());
380-
381-
let computer_expr = if let Some(filter) = &pull_up.pull_up_having_expr {
382-
Expr::Case(expr::Case {
383-
expr: None,
384-
when_then_expr: vec![
385-
(
386-
Box::new(Expr::IsNull(Box::new(Expr::Column(indicator_col)))),
387-
Box::new(result),
388-
),
389-
(
390-
Box::new(Expr::Not(Box::new(filter.clone()))),
391-
Box::new(Expr::Literal(ScalarValue::Null, None)),
392-
),
393-
],
394-
else_expr: Some(Box::new(Expr::Column(value_col.clone()))),
395-
})
396-
} else {
397-
Expr::Case(expr::Case {
398-
expr: None,
399-
when_then_expr: vec![(
400-
Box::new(Expr::IsNull(Box::new(Expr::Column(indicator_col)))),
401-
Box::new(result),
402-
)],
403-
else_expr: Some(Box::new(Expr::Column(value_col.clone()))),
404-
})
405-
};
406-
let mut expr_rewrite = TypeCoercionRewriter {
407-
schema: new_plan.schema(),
408-
};
409-
computation_project_expr
410-
.insert(value_col, computer_expr.rewrite(&mut expr_rewrite).data()?);
382+
let value_col = Column::new(Some(subquery_alias), name);
383+
384+
let mut builder = when(Expr::Column(indicator_col).is_null(), result);
385+
if let Some((when_expr, then_expr)) = &having_arm {
386+
builder = builder.when(when_expr.clone(), then_expr.clone());
387+
}
388+
let compensation_expr = builder.otherwise(Expr::Column(value_col.clone()))?;
389+
compensation_exprs.insert(
390+
value_col,
391+
compensation_expr.rewrite(&mut expr_rewrite).data()?,
392+
);
411393
}
412394
}
413395

414-
Ok(Some((new_plan, computation_project_expr)))
396+
Ok(Some((new_plan, compensation_exprs)))
415397
}
416398

417399
#[cfg(test)]
@@ -425,7 +407,7 @@ mod tests {
425407
use datafusion_expr::test::function_stub::sum;
426408

427409
use crate::assert_optimized_plan_eq_display_indent_snapshot;
428-
use datafusion_expr::{Between, col, lit, out_ref_col, scalar_subquery};
410+
use datafusion_expr::{Between, col, expr, out_ref_col, scalar_subquery};
429411
use datafusion_functions_aggregate::min_max::{max, min};
430412

431413
macro_rules! assert_optimized_plan_equal {

0 commit comments

Comments
 (0)