@@ -34,7 +34,7 @@ use datafusion_common::{Column, Result, ScalarValue, assert_or_internal_err, pla
3434use datafusion_expr:: expr_rewriter:: create_col_from_scalar_expr;
3535use datafusion_expr:: logical_plan:: { JoinType , Subquery } ;
3636use datafusion_expr:: utils:: conjunction;
37- use datafusion_expr:: { EmptyRelation , Expr , LogicalPlan , LogicalPlanBuilder , expr } ;
37+ use datafusion_expr:: { Expr , LogicalPlan , LogicalPlanBuilder , lit , not , when } ;
3838
3939/// Optimizer rule that rewrites correlated scalar subquery filters to joins and
4040/// places an additional projection on top of the filter, to preserve the
@@ -107,18 +107,17 @@ impl OptimizerRule for ScalarSubqueryToJoin {
107107 // iterate through all subqueries in predicate, turning each into a left join
108108 let mut cur_input = filter. input . as_ref ( ) . clone ( ) ;
109109 for ( subquery, alias) in subqueries {
110- if let Some ( ( optimized_subquery, expr_check_map ) ) =
110+ if let Some ( ( optimized_subquery, compensation_exprs ) ) =
111111 build_join ( & subquery, & cur_input, & alias) ?
112112 {
113- if !expr_check_map . is_empty ( ) {
113+ if !compensation_exprs . is_empty ( ) {
114114 rewrite_expr = rewrite_expr
115115 . transform_up ( |expr| {
116- // replace column references with entry in map, if it exists
117- if let Some ( map_expr) = expr
116+ if let Some ( compensation_expr) = expr
118117 . try_as_col ( )
119- . and_then ( |col| expr_check_map . get ( col) )
118+ . and_then ( |col| compensation_exprs . get ( col) )
120119 {
121- Ok ( Transformed :: yes ( map_expr . clone ( ) ) )
120+ Ok ( Transformed :: yes ( compensation_expr . clone ( ) ) )
122121 } else {
123122 Ok ( Transformed :: no ( expr) )
124123 }
@@ -172,22 +171,21 @@ impl OptimizerRule for ScalarSubqueryToJoin {
172171 // iterate through all subqueries in predicate, turning each into a left join
173172 let mut cur_input = projection. input . as_ref ( ) . clone ( ) ;
174173 for ( subquery, alias) in all_subqueries {
175- if let Some ( ( optimized_subquery, expr_check_map ) ) =
174+ if let Some ( ( optimized_subquery, compensation_exprs ) ) =
176175 build_join ( & subquery, & cur_input, & alias) ?
177176 {
178177 cur_input = optimized_subquery;
179- if !expr_check_map . is_empty ( )
178+ if !compensation_exprs . is_empty ( )
180179 && let Some ( & idx) = alias_to_index. get ( & alias)
181180 {
182181 let new_expr = rewrite_exprs[ idx]
183182 . clone ( )
184183 . transform_up ( |expr| {
185- // replace column references with entry in map, if it exists
186- if let Some ( map_expr) = expr
184+ if let Some ( compensation_expr) = expr
187185 . try_as_col ( )
188- . and_then ( |col| expr_check_map . get ( col) )
186+ . and_then ( |col| compensation_exprs . get ( col) )
189187 {
190- Ok ( Transformed :: yes ( map_expr . clone ( ) ) )
188+ Ok ( Transformed :: yes ( compensation_expr . clone ( ) ) )
191189 } else {
192190 Ok ( Transformed :: no ( expr) )
193191 }
@@ -285,133 +283,117 @@ impl TreeNodeRewriter for ExtractScalarSubQuery<'_> {
285283///
286284/// ```text
287285/// select c.id from customers c
288- /// left join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.c_id
289- /// where c.balance > o.val
290- /// ```
291- ///
292- /// Or a query like:
293- ///
294- /// ```text
295- /// select id from customers where balance >
296- /// (select avg(total) from orders)
297- /// ```
298- ///
299- /// and optimizes it into:
300- ///
301- /// ```text
302- /// select c.id from customers c
303- /// left join (select avg(total) as val from orders) a
304- /// where c.balance > a.val
286+ /// left join (select c_id, avg(total) from orders group by c_id) o
287+ /// on o.c_id = c.id
288+ /// where c.balance > o."avg(total)"
305289/// ```
306290///
307291/// # Arguments
308292///
309- /// * `query_info` - The subquery portion of the `where` (select avg(total) from orders)
310- /// * `filter_input` - The non-subquery portion (from customers)
311- /// * `outer_others` - Any additional parts to the `where` expression (and c.x = y)
312- /// * `subquery_alias` - Subquery aliases
293+ /// * `subquery` - The correlated scalar subquery to decorrelate.
294+ /// * `outer_input` - The outer plan that the decorrelated subquery is
295+ /// left-joined onto — the input of the `Filter` or `Projection` node
296+ /// that contained the subquery.
297+ /// * `subquery_alias` - The unique alias assigned to the decorrelated
298+ /// subquery; used both to qualify the join condition and to produce
299+ /// column references for the caller to substitute.
300+ ///
301+ /// Returns `Ok(None)` if the subquery cannot be decorrelated. On success,
302+ /// returns the rewritten outer plan and a map from each count-bug-affected
303+ /// column to its `CASE WHEN __always_true IS NULL THEN ... END` compensation
304+ /// expression, which the caller must substitute into any expression that
305+ /// references those columns.
313306fn build_join (
314307 subquery : & Subquery ,
315- filter_input : & LogicalPlan ,
308+ outer_input : & LogicalPlan ,
316309 subquery_alias : & str ,
317310) -> Result < Option < ( LogicalPlan , HashMap < Column , Expr > ) > > {
311+ assert_or_internal_err ! (
312+ !subquery. outer_ref_columns. is_empty( ) ,
313+ "build_join should only be called for correlated subqueries"
314+ ) ;
318315 let subquery_plan = subquery. subquery . as_ref ( ) ;
319316 let mut pull_up = PullUpCorrelatedExpr :: new ( ) . with_need_handle_count_bug ( true ) ;
320- let new_plan = subquery_plan. clone ( ) . rewrite ( & mut pull_up) . data ( ) ?;
317+ let decorrelated_subquery = subquery_plan. clone ( ) . rewrite ( & mut pull_up) . data ( ) ?;
321318 if !pull_up. can_pull_up {
322319 return Ok ( None ) ;
323320 }
324321
325- let collected_count_expr_map =
326- pull_up. collected_count_expr_map . get ( & new_plan) . cloned ( ) ;
327- let sub_query_alias = LogicalPlanBuilder :: from ( new_plan)
322+ let collected_count_expr_map = pull_up
323+ . collected_count_expr_map
324+ . get ( & decorrelated_subquery)
325+ . cloned ( ) ;
326+ let aliased_subquery = LogicalPlanBuilder :: from ( decorrelated_subquery)
328327 . alias ( subquery_alias. to_string ( ) ) ?
329328 . build ( ) ?;
330329
331- let mut all_correlated_cols = BTreeSet :: new ( ) ;
332- pull_up
330+ let all_correlated_cols: BTreeSet < Column > = pull_up
333331 . correlated_subquery_cols_map
334332 . values ( )
335- . for_each ( |cols| all_correlated_cols. extend ( cols. clone ( ) ) ) ;
333+ . flatten ( )
334+ . cloned ( )
335+ . collect ( ) ;
336336
337- // alias the join filter
337+ // Correlated columns now live in the decorrelated subquery's output,
338+ // so re-qualify them with the subquery alias.
338339 let join_filter_opt =
339340 conjunction ( pull_up. join_filters ) . map_or ( Ok ( None ) , |filter| {
340341 replace_qualified_name ( filter, & all_correlated_cols, subquery_alias) . map ( Some )
341342 } ) ?;
342343
343- // join our sub query into the main plan
344- let new_plan = if join_filter_opt. is_none ( ) {
345- match filter_input {
346- LogicalPlan :: EmptyRelation ( EmptyRelation {
347- produce_one_row : true ,
348- schema : _,
349- } ) => sub_query_alias,
350- _ => {
351- // if not correlated, group down to 1 row and left join on that (preserving row count)
352- LogicalPlanBuilder :: from ( filter_input. clone ( ) )
353- . join_on (
354- sub_query_alias,
355- JoinType :: Left ,
356- vec ! [ Expr :: Literal ( ScalarValue :: Boolean ( Some ( true ) ) , None ) ] ,
357- ) ?
358- . build ( ) ?
359- }
360- }
361- } else {
362- // left join if correlated, grouping by the join keys so we don't change row count
363- LogicalPlanBuilder :: from ( filter_input. clone ( ) )
364- . join_on ( sub_query_alias, JoinType :: Left , join_filter_opt) ?
365- . build ( ) ?
366- } ;
367- let mut computation_project_expr = HashMap :: new ( ) ;
344+ // When pull-up did not extract any usable join keys (a correlated subquery
345+ // whose predicate references only outer columns), fall back to `ON true`:
346+ // the decorrelated subquery still yields at most one row per outer row
347+ // because its aggregate is grouped by the (empty) set of correlated inner
348+ // columns.
349+ let join_filter = join_filter_opt. or_else ( || Some ( lit ( true ) ) ) ;
350+
351+ let new_plan = LogicalPlanBuilder :: from ( outer_input. clone ( ) )
352+ . join_on ( aliased_subquery, JoinType :: Left , join_filter) ?
353+ . build ( ) ?;
354+
355+ // Add count-bug compensation for each of the subquery's projected
356+ // expressions that yield non-NULL values on empty input. We wrap each
357+ // such expression in a CASE that substitutes the empty-input value
358+ // when the LEFT JOIN produced synthetic right-side NULLs (no inner
359+ // row matched), and uses the actual right-side value (which may
360+ // itself be NULL) otherwise.
361+ let mut compensation_exprs = HashMap :: new ( ) ;
368362 if let Some ( expr_map) = collected_count_expr_map {
363+ let mut expr_rewrite = TypeCoercionRewriter {
364+ schema : new_plan. schema ( ) ,
365+ } ;
366+ let having_arm = pull_up
367+ . pull_up_having_expr
368+ . as_ref ( )
369+ . map ( |f| ( not ( f. clone ( ) ) , lit ( ScalarValue :: Null ) ) ) ;
369370 for ( name, result) in expr_map {
370371 if evaluates_to_null ( result. clone ( ) , result. column_refs ( ) ) ? {
371- // If expr always returns null when column is null, skip processing
372+ // Aggregates whose empty-input value is NULL (max/min/sum/…)
373+ // need no compensation: the LEFT JOIN already produces NULL
374+ // for unmatched outer rows.
372375 continue ;
373376 }
374377
375378 let indicator_col =
376379 Column :: new ( Some ( subquery_alias) , UN_MATCHED_ROW_INDICATOR ) ;
377380 // Qualify with the subquery alias to avoid ambiguity when the
378381 // outer table has a column with the same name as the aggregate.
379- let value_col = Column :: new ( Some ( subquery_alias) , name. clone ( ) ) ;
380-
381- let computer_expr = if let Some ( filter) = & pull_up. pull_up_having_expr {
382- Expr :: Case ( expr:: Case {
383- expr : None ,
384- when_then_expr : vec ! [
385- (
386- Box :: new( Expr :: IsNull ( Box :: new( Expr :: Column ( indicator_col) ) ) ) ,
387- Box :: new( result) ,
388- ) ,
389- (
390- Box :: new( Expr :: Not ( Box :: new( filter. clone( ) ) ) ) ,
391- Box :: new( Expr :: Literal ( ScalarValue :: Null , None ) ) ,
392- ) ,
393- ] ,
394- else_expr : Some ( Box :: new ( Expr :: Column ( value_col. clone ( ) ) ) ) ,
395- } )
396- } else {
397- Expr :: Case ( expr:: Case {
398- expr : None ,
399- when_then_expr : vec ! [ (
400- Box :: new( Expr :: IsNull ( Box :: new( Expr :: Column ( indicator_col) ) ) ) ,
401- Box :: new( result) ,
402- ) ] ,
403- else_expr : Some ( Box :: new ( Expr :: Column ( value_col. clone ( ) ) ) ) ,
404- } )
405- } ;
406- let mut expr_rewrite = TypeCoercionRewriter {
407- schema : new_plan. schema ( ) ,
408- } ;
409- computation_project_expr
410- . insert ( value_col, computer_expr. rewrite ( & mut expr_rewrite) . data ( ) ?) ;
382+ let value_col = Column :: new ( Some ( subquery_alias) , name) ;
383+
384+ let mut builder = when ( Expr :: Column ( indicator_col) . is_null ( ) , result) ;
385+ if let Some ( ( when_expr, then_expr) ) = & having_arm {
386+ builder = builder. when ( when_expr. clone ( ) , then_expr. clone ( ) ) ;
387+ }
388+ let compensation_expr = builder. otherwise ( Expr :: Column ( value_col. clone ( ) ) ) ?;
389+ compensation_exprs. insert (
390+ value_col,
391+ compensation_expr. rewrite ( & mut expr_rewrite) . data ( ) ?,
392+ ) ;
411393 }
412394 }
413395
414- Ok ( Some ( ( new_plan, computation_project_expr ) ) )
396+ Ok ( Some ( ( new_plan, compensation_exprs ) ) )
415397}
416398
417399#[ cfg( test) ]
@@ -425,7 +407,7 @@ mod tests {
425407 use datafusion_expr:: test:: function_stub:: sum;
426408
427409 use crate :: assert_optimized_plan_eq_display_indent_snapshot;
428- use datafusion_expr:: { Between , col, lit , out_ref_col, scalar_subquery} ;
410+ use datafusion_expr:: { Between , col, expr , out_ref_col, scalar_subquery} ;
429411 use datafusion_functions_aggregate:: min_max:: { max, min} ;
430412
431413 macro_rules! assert_optimized_plan_equal {
0 commit comments