Skip to content

Commit 9db3100

Browse files
mergify[bot] and lxynov authored
[BugFix] Fix shared object mutation in PushDownAggregateRewriter for case-when/if (backport #71309) (#71344)
Co-authored-by: Xingyuan Lin <x.lin@celonis.com>
1 parent 368053b commit 9db3100

3 files changed

Lines changed: 129 additions & 24 deletions

File tree

fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pdagg/PushDownAggregateCollector.java

Lines changed: 36 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -176,29 +176,25 @@ public Void visitLogicalProject(OptExpression optExpression, AggregatePushDownCo
176176
return processChild(optExpression, context);
177177
}
178178

179-
// rewrite
180-
ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter(project.getColumnRefMap());
181-
context.aggregations.replaceAll((k, v) -> (CallOperator) rewriter.rewrite(v));
182-
context.groupBys.replaceAll((k, v) -> rewriter.rewrite(v));
179+
ColumnRefSet aggUsedColumns = new ColumnRefSet();
180+
context.aggregations.values().forEach(v -> aggUsedColumns.union(v.getUsedColumns()));
183181

184-
if (project.getColumnRefMap().values().stream().allMatch(ScalarOperator::isColumnRef)) {
185-
return processChild(optExpression, context);
186-
}
182+
Map<ColumnRefOperator, ScalarOperator> columnRefMap = project.getColumnRefMap();
183+
Map<ColumnRefOperator, ScalarOperator> aggRewriteMap = columnRefMap;
187184

188185
// handle specials functions case-when/if
189186
// split to groupBys and mock new aggregations by values, don't need to save
190187
// origin predicate, we just do check in collect phase
191-
for (Map.Entry<ColumnRefOperator, CallOperator> entry : context.aggregations.entrySet()) {
192-
CallOperator aggFn = entry.getValue();
193-
ScalarOperator aggInput = aggFn.getChild(0);
188+
for (Map.Entry<ColumnRefOperator, ScalarOperator> entry : columnRefMap.entrySet()) {
189+
ColumnRefOperator key = entry.getKey();
190+
ScalarOperator value = entry.getValue();
194191

195-
if (!(aggInput instanceof CallOperator)) {
192+
if (!aggUsedColumns.contains(key) || !(value instanceof CallOperator call)) {
196193
continue;
197194
}
198195

199-
CallOperator callInput = (CallOperator) aggInput;
200-
if (aggInput instanceof CaseWhenOperator) {
201-
CaseWhenOperator caseWhen = (CaseWhenOperator) aggInput;
196+
if (call instanceof CaseWhenOperator) {
197+
CaseWhenOperator caseWhen = (CaseWhenOperator) value;
202198
for (ScalarOperator condition : caseWhen.getAllConditionClause()) {
203199
condition.getUsedColumns().getStream().map(factory::getColumnRef)
204200
.forEach(v -> context.groupBys.put(v, v));
@@ -218,21 +214,40 @@ public Void visitLogicalProject(OptExpression optExpression, AggregatePushDownCo
218214
CaseWhenOperator newCaseWhen = new CaseWhenOperator(caseWhen.getType(), null,
219215
caseWhen.hasElse() ? caseWhen.getElseClause() : null, newWhenThen);
220216

221-
// replace origin
222-
aggFn.setChild(0, newCaseWhen);
223-
} else if (callInput.getFunction() != null &&
224-
FunctionSet.IF.equals(callInput.getFunction().getFunctionName().getFunction())) {
225-
if (aggInput.getChildren().stream().skip(1).anyMatch(c -> c.isConstant() && !c.isConstantNull())) {
217+
if (aggRewriteMap == columnRefMap) {
218+
aggRewriteMap = Maps.newHashMap(columnRefMap);
219+
}
220+
aggRewriteMap.put(key, newCaseWhen);
221+
} else if (call.getFunction() != null &&
222+
FunctionSet.IF.equals(call.getFunction().getFunctionName().getFunction())) {
223+
if (call.getChildren().stream().skip(1).anyMatch(c -> c.isConstant() && !c.isConstantNull())) {
226224
// forbidden push down
227225
return visit(optExpression, context);
228226
}
229227

230-
aggInput.getChild(0).getUsedColumns().getStream().map(factory::getColumnRef)
228+
call.getChild(0).getUsedColumns().getStream().map(factory::getColumnRef)
231229
.forEach(v -> context.groupBys.put(v, v));
232-
aggInput.setChild(0, ConstantOperator.createBoolean(false));
230+
231+
CallOperator newIf = new CallOperator(call.getFnName(), call.getType(), Lists.newArrayList(call.getArguments()),
232+
call.getFunction());
233+
newIf.setChild(0, ConstantOperator.createBoolean(false));
234+
235+
if (aggRewriteMap == columnRefMap) {
236+
aggRewriteMap = Maps.newHashMap(columnRefMap);
237+
}
238+
aggRewriteMap.put(key, newIf);
233239
}
234240
}
235241

242+
ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter(aggRewriteMap);
243+
context.aggregations.replaceAll((k, v) -> (CallOperator) rewriter.rewrite(v));
244+
if (aggRewriteMap != columnRefMap) {
245+
ReplaceColumnRefRewriter originalRewriter = new ReplaceColumnRefRewriter(columnRefMap);
246+
context.groupBys.replaceAll((k, v) -> originalRewriter.rewrite(v));
247+
} else {
248+
context.groupBys.replaceAll((k, v) -> rewriter.rewrite(v));
249+
}
250+
236251
// check has constant aggregate, forbidden
237252
if (!context.aggregations.isEmpty() &&
238253
context.aggregations.values().stream().allMatch(ScalarOperator::isConstant)) {

fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pdagg/PushDownAggregateRewriter.java

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -192,7 +192,10 @@ private void rewriteProject(AggregatePushDownContext context,
192192
&& FunctionSet.IF.equals(((CallOperator) aggExpr).getFunction().getFunctionName().getFunction());
193193

194194
if (isCaseWhen) {
195-
CaseWhenOperator caseWhen = (CaseWhenOperator) aggExpr;
195+
// Clone to avoid mutating the shared object in originProjectMap/project's columnRefMap.
196+
// Without clone, when multiple aggregations reference the same CASE WHEN column,
197+
// the first aggregation's setThenClause/setElseClause corrupts the shared operator.
198+
CaseWhenOperator caseWhen = (CaseWhenOperator) aggExpr.clone();
196199
for (ScalarOperator condition : caseWhen.getAllConditionClause()) {
197200
condition.getUsedColumns().getStream().map(factory::getColumnRef)
198201
.forEach(v -> context.groupBys.put(v, v));
@@ -221,7 +224,8 @@ private void rewriteProject(AggregatePushDownContext context,
221224
context.aggregations.remove(key);
222225
originProjectMap.put(key, new CaseWhenOperator(key.getType(), caseWhen));
223226
} else if (isIfFn) {
224-
CallOperator ifFn = (CallOperator) aggExpr;
227+
// Clone to avoid mutating the shared object (same reason as CaseWhen above).
228+
CallOperator ifFn = (CallOperator) aggExpr.clone();
225229
ifFn.getChild(0).getUsedColumns().getStream().map(factory::getColumnRef)
226230
.forEach(v -> context.groupBys.put(v, v));
227231

fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregatePushDownTest.java

Lines changed: 87 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,6 @@ public void testPushDownPreAggEnableOnBroadcastJoin() {
9393
}
9494
}
9595

96-
9796
@Test
9897
public void testPushDownDistinctAggBelowWindow()
9998
throws Exception {
@@ -294,4 +293,91 @@ public void testPruneDistinctWindow() throws Exception {
294293
" args: DECIMAL128; result: DECIMAL128(38,2); args nullable: true; result nullable: true], ]");
295294
assertContains(plan, "2:AGGREGATE (update finalize)");
296295
}
296+
297+
@Test
298+
public void testPushDownWithNestedCaseWhenIfs() throws Exception {
299+
String sql = """
300+
WITH cte1 AS (
301+
SELECT
302+
t.t1d AS fk,
303+
t.t1a AS cat,
304+
CASE WHEN t.t1b = 1 THEN t.t1e ELSE t.t1f END AS cval
305+
FROM test_all_type t
306+
),
307+
cte2 AS (
308+
SELECT a.cval, a.fk, a.cat
309+
FROM cte1 a
310+
LEFT JOIN t1 ON a.fk = t1.v4
311+
),
312+
cte3 AS (
313+
SELECT CASE WHEN c.cat THEN c.cval ELSE NULL END gval, c.fk
314+
FROM cte2 c
315+
)
316+
SELECT SUM(gval)
317+
FROM cte3
318+
GROUP BY fk;
319+
""";
320+
String plan = getVerboseExplain(sql);
321+
assertContains(plan, " 2:AGGREGATE (update finalize)\n" +
322+
" | aggregate: sum[([21: cast, DOUBLE, true]); args: DOUBLE; result: DOUBLE; args nullable: true; result" +
323+
" nullable: true], sum[([6: t1f, DOUBLE, true]); args: DOUBLE; result: DOUBLE; args nullable: true; result" +
324+
" nullable: true]\n" +
325+
" | group by: [1: t1a, VARCHAR, true], [2: t1b, SMALLINT, true], [4: t1d, BIGINT, true]\n" +
326+
" | cardinality: 1\n" +
327+
" | \n" +
328+
" 1:Project\n" +
329+
" | output columns:\n" +
330+
" | 1 <-> [1: t1a, VARCHAR, true]\n" +
331+
" | 2 <-> [2: t1b, SMALLINT, true]\n" +
332+
" | 4 <-> [4: t1d, BIGINT, true]\n" +
333+
" | 6 <-> [6: t1f, DOUBLE, true]\n" +
334+
" | 21 <-> cast([5: t1e, FLOAT, true] as DOUBLE)\n" +
335+
" | cardinality: 1\n" +
336+
" | \n" +
337+
" 0:OlapScanNode\n" +
338+
" table: test_all_type, rollup: test_all_type\n" +
339+
" preAggregation: on\n" +
340+
" partitionsRatio=1/1, tabletsRatio=3/3\n" +
341+
" tabletList=10140,10142,10144\n" +
342+
" actualRows=0, avgRowSize=6.0\n" +
343+
" cardinality: 1");
344+
345+
}
346+
347+
@Test
348+
public void testRewriterSharedMutationWithCaseWhen() throws Exception {
349+
// Bug: PushDownAggregateRewriter.rewriteProject() mutates shared CaseWhenOperator
350+
// in-place via setThenClause(). When two aggregations (SUM + MIN) reference the same
351+
// CASE WHEN column, the first aggregation's processing corrupts the CaseWhenOperator,
352+
// causing the second aggregation to see pushed-down column refs instead of original columns.
353+
String sql = "SELECT SUM(sub.cval), MIN(sub.cval), sub.fk " +
354+
"FROM ( " +
355+
" SELECT t1d AS fk, " +
356+
" CASE WHEN t1b = 1 THEN t1e ELSE NULL END AS cval " +
357+
" FROM test_all_type " +
358+
") sub " +
359+
"JOIN t0 ON sub.fk = t0.v1 " +
360+
"GROUP BY sub.fk";
361+
String plan = getVerboseExplain(sql);
362+
363+
assertContains(plan, "sum");
364+
assertContains(plan, "min");
365+
}
366+
367+
@Test
368+
public void testRewriterSharedMutationWithIf() throws Exception {
369+
// Bug: PushDownAggregateRewriter.rewriteProject() mutates shared CallOperator (IF)
370+
// in-place via setChild(). Same root cause as the CaseWhen bug but on the IF path.
371+
String sql = "SELECT SUM(sub.cval), MIN(sub.cval), sub.fk " +
372+
"FROM ( " +
373+
" SELECT t1d AS fk, " +
374+
" IF(t1b = 1, t1e, NULL) AS cval " +
375+
" FROM test_all_type " +
376+
") sub " +
377+
"JOIN t0 ON sub.fk = t0.v1 " +
378+
"GROUP BY sub.fk";
379+
String plan = getVerboseExplain(sql);
380+
assertContains(plan, "sum");
381+
assertContains(plan, "min");
382+
}
297383
}

0 commit comments

Comments
 (0)