@@ -176,29 +176,25 @@ public Void visitLogicalProject(OptExpression optExpression, AggregatePushDownCo
176176 return processChild (optExpression , context );
177177 }
178178
179- // rewrite
180- ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter (project .getColumnRefMap ());
181- context .aggregations .replaceAll ((k , v ) -> (CallOperator ) rewriter .rewrite (v ));
182- context .groupBys .replaceAll ((k , v ) -> rewriter .rewrite (v ));
179+ ColumnRefSet aggUsedColumns = new ColumnRefSet ();
180+ context .aggregations .values ().forEach (v -> aggUsedColumns .union (v .getUsedColumns ()));
183181
184- if (project .getColumnRefMap ().values ().stream ().allMatch (ScalarOperator ::isColumnRef )) {
185- return processChild (optExpression , context );
186- }
182+ Map <ColumnRefOperator , ScalarOperator > columnRefMap = project .getColumnRefMap ();
183+ Map <ColumnRefOperator , ScalarOperator > aggRewriteMap = columnRefMap ;
187184
188185 // handle specials functions case-when/if
189186 // split to groupBys and mock new aggregations by values, don't need to save
190187 // origin predicate, we just do check in collect phase
191- for (Map .Entry <ColumnRefOperator , CallOperator > entry : context . aggregations .entrySet ()) {
192- CallOperator aggFn = entry .getValue ();
193- ScalarOperator aggInput = aggFn . getChild ( 0 );
188+ for (Map .Entry <ColumnRefOperator , ScalarOperator > entry : columnRefMap .entrySet ()) {
189+ ColumnRefOperator key = entry .getKey ();
190+ ScalarOperator value = entry . getValue ( );
194191
195- if (!( aggInput instanceof CallOperator )) {
192+ if (!aggUsedColumns . contains ( key ) || !( value instanceof CallOperator call )) {
196193 continue ;
197194 }
198195
199- CallOperator callInput = (CallOperator ) aggInput ;
200- if (aggInput instanceof CaseWhenOperator ) {
201- CaseWhenOperator caseWhen = (CaseWhenOperator ) aggInput ;
196+ if (call instanceof CaseWhenOperator ) {
197+ CaseWhenOperator caseWhen = (CaseWhenOperator ) value ;
202198 for (ScalarOperator condition : caseWhen .getAllConditionClause ()) {
203199 condition .getUsedColumns ().getStream ().map (factory ::getColumnRef )
204200 .forEach (v -> context .groupBys .put (v , v ));
@@ -218,21 +214,40 @@ public Void visitLogicalProject(OptExpression optExpression, AggregatePushDownCo
218214 CaseWhenOperator newCaseWhen = new CaseWhenOperator (caseWhen .getType (), null ,
219215 caseWhen .hasElse () ? caseWhen .getElseClause () : null , newWhenThen );
220216
221- // replace origin
222- aggFn .setChild (0 , newCaseWhen );
223- } else if (callInput .getFunction () != null &&
224- FunctionSet .IF .equals (callInput .getFunction ().getFunctionName ().getFunction ())) {
225- if (aggInput .getChildren ().stream ().skip (1 ).anyMatch (c -> c .isConstant () && !c .isConstantNull ())) {
217+ if (aggRewriteMap == columnRefMap ) {
218+ aggRewriteMap = Maps .newHashMap (columnRefMap );
219+ }
220+ aggRewriteMap .put (key , newCaseWhen );
221+ } else if (call .getFunction () != null &&
222+ FunctionSet .IF .equals (call .getFunction ().getFunctionName ().getFunction ())) {
223+ if (call .getChildren ().stream ().skip (1 ).anyMatch (c -> c .isConstant () && !c .isConstantNull ())) {
226224 // forbidden push down
227225 return visit (optExpression , context );
228226 }
229227
230- aggInput .getChild (0 ).getUsedColumns ().getStream ().map (factory ::getColumnRef )
228+ call .getChild (0 ).getUsedColumns ().getStream ().map (factory ::getColumnRef )
231229 .forEach (v -> context .groupBys .put (v , v ));
232- aggInput .setChild (0 , ConstantOperator .createBoolean (false ));
230+
231+ CallOperator newIf = new CallOperator (call .getFnName (), call .getType (), Lists .newArrayList (call .getArguments ()),
232+ call .getFunction ());
233+ newIf .setChild (0 , ConstantOperator .createBoolean (false ));
234+
235+ if (aggRewriteMap == columnRefMap ) {
236+ aggRewriteMap = Maps .newHashMap (columnRefMap );
237+ }
238+ aggRewriteMap .put (key , newIf );
233239 }
234240 }
235241
242+ ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter (aggRewriteMap );
243+ context .aggregations .replaceAll ((k , v ) -> (CallOperator ) rewriter .rewrite (v ));
244+ if (aggRewriteMap != columnRefMap ) {
245+ ReplaceColumnRefRewriter originalRewriter = new ReplaceColumnRefRewriter (columnRefMap );
246+ context .groupBys .replaceAll ((k , v ) -> originalRewriter .rewrite (v ));
247+ } else {
248+ context .groupBys .replaceAll ((k , v ) -> rewriter .rewrite (v ));
249+ }
250+
236251 // check has constant aggregate, forbidden
237252 if (!context .aggregations .isEmpty () &&
238253 context .aggregations .values ().stream ().allMatch (ScalarOperator ::isConstant )) {
0 commit comments