@@ -175,29 +175,25 @@ public Void visitLogicalProject(OptExpression optExpression, AggregatePushDownCo
175175 return processChild (optExpression , context );
176176 }
177177
178- // rewrite
179- ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter (project .getColumnRefMap ());
180- context .aggregations .replaceAll ((k , v ) -> (CallOperator ) rewriter .rewrite (v ));
181- context .groupBys .replaceAll ((k , v ) -> rewriter .rewrite (v ));
178+ ColumnRefSet aggUsedColumns = new ColumnRefSet ();
179+ context .aggregations .values ().forEach (v -> aggUsedColumns .union (v .getUsedColumns ()));
182180
183- if (project .getColumnRefMap ().values ().stream ().allMatch (ScalarOperator ::isColumnRef )) {
184- return processChild (optExpression , context );
185- }
181+ Map <ColumnRefOperator , ScalarOperator > columnRefMap = project .getColumnRefMap ();
182+ Map <ColumnRefOperator , ScalarOperator > aggRewriteMap = columnRefMap ;
186183
187184 // handle specials functions case-when/if
188185 // split to groupBys and mock new aggregations by values, don't need to save
189186 // origin predicate, we just do check in collect phase
190- for (Map .Entry <ColumnRefOperator , CallOperator > entry : context . aggregations .entrySet ()) {
191- CallOperator aggFn = entry .getValue ();
192- ScalarOperator aggInput = aggFn . getChild ( 0 );
187+ for (Map .Entry <ColumnRefOperator , ScalarOperator > entry : columnRefMap .entrySet ()) {
188+ ColumnRefOperator key = entry .getKey ();
189+ ScalarOperator value = entry . getValue ( );
193190
194- if (!( aggInput instanceof CallOperator )) {
191+ if (!aggUsedColumns . contains ( key ) || !( value instanceof CallOperator call )) {
195192 continue ;
196193 }
197194
198- CallOperator callInput = (CallOperator ) aggInput ;
199- if (aggInput instanceof CaseWhenOperator ) {
200- CaseWhenOperator caseWhen = (CaseWhenOperator ) aggInput ;
195+ if (call instanceof CaseWhenOperator ) {
196+ CaseWhenOperator caseWhen = (CaseWhenOperator ) value ;
201197 for (ScalarOperator condition : caseWhen .getAllConditionClause ()) {
202198 condition .getUsedColumns ().getStream ().map (factory ::getColumnRef )
203199 .forEach (v -> context .groupBys .put (v , v ));
@@ -217,21 +213,40 @@ public Void visitLogicalProject(OptExpression optExpression, AggregatePushDownCo
217213 CaseWhenOperator newCaseWhen = new CaseWhenOperator (caseWhen .getType (), null ,
218214 caseWhen .hasElse () ? caseWhen .getElseClause () : null , newWhenThen );
219215
220- // replace origin
221- aggFn .setChild (0 , newCaseWhen );
222- } else if (callInput .getFunction () != null &&
223- FunctionSet .IF .equals (callInput .getFunction ().getFunctionName ().getFunction ())) {
224- if (aggInput .getChildren ().stream ().skip (1 ).anyMatch (c -> c .isConstant () && !c .isConstantNull ())) {
216+ if (aggRewriteMap == columnRefMap ) {
217+ aggRewriteMap = Maps .newHashMap (columnRefMap );
218+ }
219+ aggRewriteMap .put (key , newCaseWhen );
220+ } else if (call .getFunction () != null &&
221+ FunctionSet .IF .equals (call .getFunction ().getFunctionName ().getFunction ())) {
222+ if (call .getChildren ().stream ().skip (1 ).anyMatch (c -> c .isConstant () && !c .isConstantNull ())) {
225223 // forbidden push down
226224 return visit (optExpression , context );
227225 }
228226
229- aggInput .getChild (0 ).getUsedColumns ().getStream ().map (factory ::getColumnRef )
227+ call .getChild (0 ).getUsedColumns ().getStream ().map (factory ::getColumnRef )
230228 .forEach (v -> context .groupBys .put (v , v ));
231- aggInput .setChild (0 , ConstantOperator .createBoolean (false ));
229+
230+ CallOperator newIf = new CallOperator (call .getFnName (), call .getType (), Lists .newArrayList (call .getArguments ()),
231+ call .getFunction ());
232+ newIf .setChild (0 , ConstantOperator .createBoolean (false ));
233+
234+ if (aggRewriteMap == columnRefMap ) {
235+ aggRewriteMap = Maps .newHashMap (columnRefMap );
236+ }
237+ aggRewriteMap .put (key , newIf );
232238 }
233239 }
234240
241+ ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter (aggRewriteMap );
242+ context .aggregations .replaceAll ((k , v ) -> (CallOperator ) rewriter .rewrite (v ));
243+ if (aggRewriteMap != columnRefMap ) {
244+ ReplaceColumnRefRewriter originalRewriter = new ReplaceColumnRefRewriter (columnRefMap );
245+ context .groupBys .replaceAll ((k , v ) -> originalRewriter .rewrite (v ));
246+ } else {
247+ context .groupBys .replaceAll ((k , v ) -> rewriter .rewrite (v ));
248+ }
249+
235250 // check has constant aggregate, forbidden
236251 if (!context .aggregations .isEmpty () &&
237252 context .aggregations .values ().stream ().allMatch (ScalarOperator ::isConstant )) {
0 commit comments