Skip to content

Commit d0e6271

Browse files
authored
[fix](eagerAgg) Fix eager aggregation incorrectly pushing agg(literal) to nullable side of outer joins (#62107)
### What problem does this PR solve? Issue Number: close #xxx Problem Summary: The EagerAggRewriter only blocked count(*)/count(literal) from being pushed to the nullable side of outer joins (via instanceof Count check). But the same logic applies to ALL aggregate functions whose input slots do not reference columns from the target side, e.g. sum(2), min(1), max(3). For unmatched rows in outer joins, these aggregates should produce their literal-based result (e.g. sum(2) adds 2 per unmatched row). After incorrect pushdown to the nullable side, the pre-aggregated slot becomes NULL for unmatched rows, and sum(NULL)/min(NULL)/max(NULL) loses the contribution entirely. The fix generalizes the existing Count-only guard to all aggregate functions: for any agg function with no input slots from the target side, block pushdown to the nullable side of outer joins. agg(nullable_side_col) is still safe to push because NULL values are naturally handled by aggregates.
1 parent e8b4c75 commit d0e6271

4 files changed

Lines changed: 250 additions & 25 deletions

File tree

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/EagerAggRewriter.java

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import org.apache.doris.nereids.trees.expressions.Slot;
2929
import org.apache.doris.nereids.trees.expressions.SlotReference;
3030
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
31-
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
3231
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
3332
import org.apache.doris.nereids.trees.plans.JoinType;
3433
import org.apache.doris.nereids.trees.plans.Plan;
@@ -133,34 +132,32 @@ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, P
133132
}
134133
}
135134

136-
// Do not push count(*)/count(literal)/count(preserved_side_col) to the nullable side of outer joins.
137-
// count(*) counts all physical rows, including null-extended rows from the outer join.
138-
// After pushdown to the nullable side, unmatched rows produce NULL for the pre-aggregated count,
139-
// and ifnull(sum(NULL), 0) = 0, which loses the count of unmatched rows.
140-
// However, count(nullable_side_col) is safe to push down because for unmatched rows,
141-
// nullable_side_col IS NULL, so the original count is 0, matching ifnull(sum(NULL), 0) = 0.
135+
// Do not push agg(literal) or agg(preserved_side_col) to the nullable side of outer joins.
136+
// Aggregates like count(*), sum(2), min(1) etc. aggregate over all physical rows,
137+
// including null-extended rows from the outer join.
138+
// After pushdown to the nullable side, unmatched rows produce NULL for the pre-aggregated value,
139+
// losing the contribution of those rows (e.g. sum(2) should add 2 per unmatched row,
140+
// but sum(NULL) skips them).
141+
// However, agg(nullable_side_col) is safe to push down because for unmatched rows,
142+
// nullable_side_col IS NULL, and the aggregate naturally handles NULL values correctly.
142143
if (!join.getJoinType().isInnerJoin() && !join.getJoinType().isCrossJoin()) {
143144
JoinType joinType = join.getJoinType();
145+
boolean leftIsNullable = joinType.isRightOuterJoin() || joinType.isFullOuterJoin();
146+
boolean rightIsNullable = joinType.isLeftOuterJoin() || joinType.isFullOuterJoin();
144147
for (AggregateFunction aggFunc : context.getAggFunctions()) {
145-
if (aggFunc instanceof Count) {
146-
Set<Slot> countInputSlots = aggFunc.getInputSlots();
147-
// Determine which side is nullable
148-
boolean leftIsNullable = joinType.isRightOuterJoin() || joinType.isFullOuterJoin();
149-
boolean rightIsNullable = joinType.isLeftOuterJoin() || joinType.isFullOuterJoin();
150-
// Check if we're pushing to a nullable side without referencing its columns
151-
if (toLeft && leftIsNullable) {
152-
boolean hasLeftInput = countInputSlots.stream()
153-
.anyMatch(slot -> join.left().getOutputSet().contains(slot));
154-
if (!hasLeftInput) {
155-
toLeft = false;
156-
}
148+
Set<Slot> inputSlots = aggFunc.getInputSlots();
149+
if (toLeft && leftIsNullable) {
150+
boolean hasLeftInput = inputSlots.stream()
151+
.anyMatch(slot -> join.left().getOutputSet().contains(slot));
152+
if (!hasLeftInput) {
153+
toLeft = false;
157154
}
158-
if (toRight && rightIsNullable) {
159-
boolean hasRightInput = countInputSlots.stream()
160-
.anyMatch(slot -> join.right().getOutputSet().contains(slot));
161-
if (!hasRightInput) {
162-
toRight = false;
163-
}
155+
}
156+
if (toRight && rightIsNullable) {
157+
boolean hasRightInput = inputSlots.stream()
158+
.anyMatch(slot -> join.right().getOutputSet().contains(slot));
159+
if (!hasRightInput) {
160+
toRight = false;
164161
}
165162
}
166163
}

fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/EagerAggRewriterTest.java

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,4 +311,59 @@ void testAsofJoinNotPushAgg() {
311311
connectContext.getSessionVariable().setDisableJoinReorder(false);
312312
}
313313
}
314+
315+
@Test
316+
void testNotPushAggLiteralToNullableSideOfOuterJoin() {
317+
// sum(literal), min(literal), max(literal) aggregate over all physical rows,
318+
// including null-extended rows from the outer join.
319+
// Pushing to the nullable side loses the contribution of unmatched rows:
320+
// original: sum(2) on unmatched row = 2
321+
// pushed: sum(NULL) skips the row (wrong!)
322+
// So agg(literal) must NOT be pushed to the nullable side.
323+
connectContext.getSessionVariable().setEagerAggregationMode(1);
324+
connectContext.getSessionVariable().setDisableJoinReorder(true);
325+
try {
326+
// RIGHT JOIN: t1 is the nullable side (left side of RIGHT JOIN)
327+
// sum(2) should NOT be pushed to t1
328+
String sql = "select sum(2), t2.id2 from t1 right join t2"
329+
+ " on t1.id1 = t2.id2 group by t2.id2";
330+
PlanChecker.from(connectContext)
331+
.analyze(sql)
332+
.rewrite()
333+
.nonMatch(logicalJoin(logicalAggregate(), any()))
334+
.printlnTree();
335+
336+
// LEFT JOIN: t2 is the nullable side (right side of LEFT JOIN)
337+
// min(1) should NOT be pushed to t2
338+
sql = "select min(1), t1.id1 from t1 left join t2"
339+
+ " on t1.id1 = t2.id2 group by t1.id1";
340+
PlanChecker.from(connectContext)
341+
.analyze(sql)
342+
.rewrite()
343+
.nonMatch(logicalJoin(any(), logicalAggregate()))
344+
.printlnTree();
345+
346+
// RIGHT JOIN: max(3) should NOT be pushed to nullable left side
347+
sql = "select max(3), t2.id2 from t1 right join t2"
348+
+ " on t1.id1 = t2.id2 group by t2.id2";
349+
PlanChecker.from(connectContext)
350+
.analyze(sql)
351+
.rewrite()
352+
.nonMatch(logicalJoin(logicalAggregate(), any()))
353+
.printlnTree();
354+
355+
// Verify agg(nullable_side_col) is still safe to push (no regression)
356+
// max(t1.name) references the left (nullable) side, so push is allowed
357+
sql = "select max(t1.name), t2.id2 from t1 right join t2"
358+
+ " on t1.id1 = t2.id2 group by t2.id2";
359+
PlanChecker.from(connectContext)
360+
.analyze(sql)
361+
.rewrite()
362+
.matches(logicalAggregate(logicalProject(logicalJoin(logicalAggregate(), any()))))
363+
.printlnTree();
364+
} finally {
365+
connectContext.getSessionVariable().setEagerAggregationMode(0);
366+
connectContext.getSessionVariable().setDisableJoinReorder(false);
367+
}
368+
}
314369
}

regression-test/data/query_p0/eager_agg/eager_agg.out

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,3 +311,59 @@ Used:
311311
UnUsed:
312312
SyntaxError: leading({ ss broadcast dt } broadcast ws) Msg:can not find table: ws
313313

314+
-- !check_sum_literal_right_join_not_push --
315+
PhysicalResultSink
316+
--hashAgg[GLOBAL]
317+
----hashAgg[LOCAL]
318+
------hashJoin[RIGHT_OUTER_JOIN] hashCondition=((a.val = c.val) and (b.id2 = c.id2)) otherCondition=()
319+
--------hashJoin[RIGHT_OUTER_JOIN] hashCondition=((a.id = b.id)) otherCondition=()
320+
----------PhysicalOlapScan[eager_agg_t1(a)]
321+
----------PhysicalOlapScan[eager_agg_t2(b)]
322+
--------PhysicalOlapScan[eager_agg_t3(c)]
323+
324+
-- !check_sum_literal_left_join_not_push --
325+
PhysicalResultSink
326+
--hashAgg[GLOBAL]
327+
----hashAgg[LOCAL]
328+
------hashJoin[LEFT_OUTER_JOIN] hashCondition=((date_dim.d_date_sk = store_sales.ss_sold_date_sk)) otherCondition=()
329+
--------hashAgg[GLOBAL]
330+
----------hashAgg[LOCAL]
331+
------------PhysicalOlapScan[store_sales]
332+
--------PhysicalOlapScan[date_dim]
333+
334+
-- !check_min_literal_right_join_not_push --
335+
PhysicalResultSink
336+
--hashAgg[GLOBAL]
337+
----hashAgg[LOCAL]
338+
------hashJoin[RIGHT_OUTER_JOIN] hashCondition=((a.val = c.val) and (b.id2 = c.id2)) otherCondition=()
339+
--------hashJoin[RIGHT_OUTER_JOIN] hashCondition=((a.id = b.id)) otherCondition=()
340+
----------PhysicalOlapScan[eager_agg_t1(a)]
341+
----------PhysicalOlapScan[eager_agg_t2(b)]
342+
--------PhysicalOlapScan[eager_agg_t3(c)]
343+
344+
-- !check_max_literal_left_join_not_push --
345+
PhysicalResultSink
346+
--hashAgg[GLOBAL]
347+
----hashAgg[LOCAL]
348+
------hashJoin[LEFT_OUTER_JOIN] hashCondition=((date_dim.d_date_sk = store_sales.ss_sold_date_sk)) otherCondition=()
349+
--------hashAgg[GLOBAL]
350+
----------hashAgg[LOCAL]
351+
------------PhysicalOlapScan[store_sales]
352+
--------PhysicalOlapScan[date_dim]
353+
354+
-- !sum_literal_right_join_eager_off --
355+
\N 4
356+
10 2
357+
358+
-- !sum_literal_right_join_eager_on --
359+
\N 4
360+
10 2
361+
362+
-- !min_literal_right_join_eager_on --
363+
\N 1
364+
10 1
365+
366+
-- !max_literal_right_join_eager_on --
367+
\N 3
368+
10 3
369+

regression-test/suites/query_p0/eager_agg/eager_agg.groovy

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,4 +411,121 @@ suite("eager_agg") {
411411
) t
412412
group by d_year;
413413
"""
414+
415+
// =========================================================================
416+
// Tests for agg(literal) on nullable side of outer joins
417+
// sum(literal), min(literal), max(literal) should NOT be pushed to the
418+
// nullable side of outer joins because unmatched rows lose their contribution.
419+
// =========================================================================
420+
421+
sql """
422+
drop table if exists eager_agg_t1;
423+
drop table if exists eager_agg_t2;
424+
drop table if exists eager_agg_t3;
425+
426+
CREATE TABLE eager_agg_t1 (
427+
id INT NOT NULL,
428+
val INT
429+
) DISTRIBUTED BY HASH(id) BUCKETS 1
430+
PROPERTIES ('replication_num' = '1');
431+
432+
CREATE TABLE eager_agg_t2 (
433+
id INT NOT NULL,
434+
id2 INT NOT NULL
435+
) DISTRIBUTED BY HASH(id) BUCKETS 1
436+
PROPERTIES ('replication_num' = '1');
437+
438+
CREATE TABLE eager_agg_t3 (
439+
id2 INT NOT NULL,
440+
val INT
441+
) DISTRIBUTED BY HASH(id2) BUCKETS 1
442+
PROPERTIES ('replication_num' = '1');
443+
444+
INSERT INTO eager_agg_t1 VALUES (1, 10);
445+
INSERT INTO eager_agg_t2 VALUES (1, 100), (2, 200);
446+
INSERT INTO eager_agg_t3 VALUES (100, 10), (200, 20), (300, 30);
447+
"""
448+
449+
// sum(literal) should NOT be pushed below RIGHT JOIN to the nullable left side
450+
qt_check_sum_literal_right_join_not_push """
451+
explain shape plan
452+
select /*+SET_VAR(eager_aggregation_mode=1, disable_join_reorder = true)*/
453+
a.val, sum(2) as s
454+
from eager_agg_t1 as a
455+
right join eager_agg_t2 as b on a.id = b.id
456+
right join eager_agg_t3 as c on b.id2 = c.id2 and a.val = c.val
457+
group by a.val;
458+
"""
459+
460+
// sum(literal) should NOT be pushed below LEFT JOIN to the nullable right side
461+
qt_check_sum_literal_left_join_not_push """
462+
explain shape plan
463+
select /*+SET_VAR(eager_aggregation_mode=1, disable_join_reorder = true)*/
464+
ss_sales_price, sum(2) as s
465+
from store_sales
466+
left join date_dim on d_date_sk = ss_sold_date_sk
467+
group by ss_sales_price;
468+
"""
469+
470+
// min(literal) should NOT be pushed to nullable side of RIGHT JOIN
471+
qt_check_min_literal_right_join_not_push """
472+
explain shape plan
473+
select /*+SET_VAR(eager_aggregation_mode=1, disable_join_reorder = true)*/
474+
a.val, min(1) as m
475+
from eager_agg_t1 as a
476+
right join eager_agg_t2 as b on a.id = b.id
477+
right join eager_agg_t3 as c on b.id2 = c.id2 and a.val = c.val
478+
group by a.val;
479+
"""
480+
481+
// max(literal) should NOT be pushed to nullable side of LEFT JOIN
482+
qt_check_max_literal_left_join_not_push """
483+
explain shape plan
484+
select /*+SET_VAR(eager_aggregation_mode=1, disable_join_reorder = true)*/
485+
ss_sales_price, max(3) as m
486+
from store_sales
487+
left join date_dim on d_date_sk = ss_sold_date_sk
488+
group by ss_sales_price;
489+
"""
490+
491+
// Execution tests: verify eager agg produces correct results for outer join + literal agg
492+
order_qt_sum_literal_right_join_eager_off """
493+
select /*+SET_VAR(eager_aggregation_mode=-1)*/ /*+ leading(a b c) */
494+
a.val, sum(2) as s
495+
from eager_agg_t1 as a
496+
right join eager_agg_t2 as b on a.id = b.id
497+
right join eager_agg_t3 as c on b.id2 = c.id2 and a.val = c.val
498+
group by a.val
499+
order by a.val;
500+
"""
501+
502+
order_qt_sum_literal_right_join_eager_on """
503+
select /*+SET_VAR(eager_aggregation_mode=1)*/ /*+ leading(a b c) */
504+
a.val, sum(2) as s
505+
from eager_agg_t1 as a
506+
right join eager_agg_t2 as b on a.id = b.id
507+
right join eager_agg_t3 as c on b.id2 = c.id2 and a.val = c.val
508+
group by a.val
509+
order by a.val;
510+
"""
511+
512+
order_qt_min_literal_right_join_eager_on """
513+
select /*+SET_VAR(eager_aggregation_mode=1)*/ /*+ leading(a b c) */
514+
a.val, min(1) as m
515+
from eager_agg_t1 as a
516+
right join eager_agg_t2 as b on a.id = b.id
517+
right join eager_agg_t3 as c on b.id2 = c.id2 and a.val = c.val
518+
group by a.val
519+
order by a.val;
520+
"""
521+
522+
order_qt_max_literal_right_join_eager_on """
523+
select /*+SET_VAR(eager_aggregation_mode=1)*/ /*+ leading(a b c) */
524+
a.val, max(3) as m
525+
from eager_agg_t1 as a
526+
right join eager_agg_t2 as b on a.id = b.id
527+
right join eager_agg_t3 as c on b.id2 = c.id2 and a.val = c.val
528+
group by a.val
529+
order by a.val;
530+
"""
414531
}

0 commit comments

Comments
 (0)