Skip to content

Commit 8923551

Browse files
SUM aggregation enhancement on operations with literal (opensearch-project#3971)
* SUM aggregation enhancement on operations with literal Signed-off-by: Heng Qian <qianheng@amazon.com> * Fix CI Signed-off-by: Heng Qian <qianheng@amazon.com> * Keep ignoring q30 for Calcite Signed-off-by: Heng Qian <qianheng@amazon.com> * Add UT for PPLAggregateConvertRule Signed-off-by: Heng Qian <qianheng@amazon.com> * Add UT for PPLAggregateConvertRule Signed-off-by: Heng Qian <qianheng@amazon.com> * Spotless check Signed-off-by: Heng Qian <qianheng@amazon.com> --------- Signed-off-by: Heng Qian <qianheng@amazon.com> Co-authored-by: Mitchell Gale <mitchell.gale@improving.com>
1 parent fcdb788 commit 8923551

11 files changed

Lines changed: 561 additions & 5 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/plan/OpenSearchRules.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@
77

88
import com.google.common.collect.ImmutableList;
99
import java.util.List;
10-
import org.apache.calcite.rel.convert.ConverterRule;
10+
import org.apache.calcite.plan.RelOptRule;
1111

1212
public class OpenSearchRules {
13-
public static final List<ConverterRule> OPEN_SEARCH_OPT_RULES = ImmutableList.of();
13+
private static final PPLAggregateConvertRule AGGREGATE_CONVERT_RULE =
14+
PPLAggregateConvertRule.Config.SUM_CONVERTER.toRule();
15+
16+
public static final List<RelOptRule> OPEN_SEARCH_OPT_RULES =
17+
ImmutableList.of(AGGREGATE_CONVERT_RULE);
1418

1519
// prevent instantiation
1620
private OpenSearchRules() {}
Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.sql.calcite.plan;
6+
7+
import com.google.common.collect.ImmutableList;
8+
import java.util.ArrayList;
9+
import java.util.HashMap;
10+
import java.util.List;
11+
import java.util.Map;
12+
import java.util.Set;
13+
import java.util.function.Function;
14+
import java.util.stream.IntStream;
15+
import org.apache.calcite.plan.RelOptRuleCall;
16+
import org.apache.calcite.plan.RelOptUtil;
17+
import org.apache.calcite.plan.RelRule;
18+
import org.apache.calcite.rel.RelNode;
19+
import org.apache.calcite.rel.core.AggregateCall;
20+
import org.apache.calcite.rel.core.Project;
21+
import org.apache.calcite.rel.logical.LogicalAggregate;
22+
import org.apache.calcite.rel.logical.LogicalProject;
23+
import org.apache.calcite.rex.RexBuilder;
24+
import org.apache.calcite.rex.RexCall;
25+
import org.apache.calcite.rex.RexInputRef;
26+
import org.apache.calcite.rex.RexLiteral;
27+
import org.apache.calcite.rex.RexNode;
28+
import org.apache.calcite.runtime.PairList;
29+
import org.apache.calcite.sql.SqlKind;
30+
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
31+
import org.apache.calcite.tools.RelBuilder;
32+
import org.apache.calcite.util.ImmutableBitSet;
33+
import org.apache.calcite.util.mapping.Mappings;
34+
import org.apache.commons.lang3.tuple.Pair;
35+
import org.immutables.value.Value;
36+
37+
/**
38+
* Planner rule that converts specific aggCall to a more efficient expressions, which includes:
39+
*
40+
* <p>- SUM(FIELD + NUMBER) -> SUM(FIELD) + NUMBER * COUNT()
41+
*
42+
* <p>- SUM(FIELD - NUMBER) -> SUM(FIELD) - NUMBER * COUNT()
43+
*
44+
* <p>- SUM(FIELD * NUMBER) -> SUM(FIELD) * NUMBER
45+
*
46+
* <p>- SUM(FIELD / NUMBER) -> SUM(FIELD) / NUMBER, Don't support this because of precision issue
47+
*
48+
* <p>TODO:
49+
*
50+
* <p>- AVG/MAX/MIN(FIELD [+|-|*|+|/] NUMBER) -> AVG/MAX/MIN(FIELD) [+|-|*|+|/] NUMBER
51+
*/
52+
@Value.Enclosing
53+
public class PPLAggregateConvertRule extends RelRule<PPLAggregateConvertRule.Config> {
54+
55+
/** Creates a OpenSearchAggregateConvertRule. */
56+
protected PPLAggregateConvertRule(Config config) {
57+
super(config);
58+
}
59+
60+
@Override
61+
public void onMatch(RelOptRuleCall call) {
62+
if (call.rels.length == 2) {
63+
final LogicalAggregate aggregate = call.rel(0);
64+
final LogicalProject project = call.rel(1);
65+
apply(call, aggregate, project);
66+
} else {
67+
throw new AssertionError(
68+
String.format(
69+
"The length of rels should be %s but got %s",
70+
this.operands.size(), call.rels.length));
71+
}
72+
}
73+
74+
public void apply(RelOptRuleCall call, LogicalAggregate aggregate, LogicalProject project) {
75+
76+
final RelBuilder relBuilder = call.builder();
77+
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
78+
relBuilder.push(project.getInput());
79+
80+
/*
81+
Build new projects with fields to be used in the converted agg call.
82+
Need to build this project in advance since building converted agg call has dependency on it.
83+
*/
84+
List<AggregateCall> aggCalls = aggregate.getAggCallList();
85+
final List<RexNode> newChildProjects = new ArrayList<>(project.getProjects());
86+
List<Integer> convertedAggCallArgs =
87+
aggCalls.stream()
88+
.filter(aggCall -> isConvertableAggCall(aggCall, project))
89+
.map(
90+
aggCall -> {
91+
RexInputRef rexRef =
92+
getFieldAndLiteral(project.getProjects().get(aggCall.getArgList().getFirst()))
93+
.getLeft();
94+
// Don't remove elements in the child project since we don't know if it will be
95+
// used by
96+
// other aggCall, will handle unused projects later
97+
int ref = newChildProjects.indexOf(rexRef);
98+
if (ref == -1) {
99+
ref = newChildProjects.size();
100+
newChildProjects.add(rexRef);
101+
}
102+
return ref;
103+
})
104+
.toList();
105+
relBuilder.project(newChildProjects);
106+
RelNode newInput = relBuilder.peek();
107+
108+
/* Build converted agg call and its parent projects */
109+
int convertedAggCallCnt = 0;
110+
final int groupSetOffset = aggregate.getGroupSet().cardinality();
111+
final List<AggregateCall> distinctAggregateCalls = new ArrayList<>();
112+
final PairList<OperatorConstructor, String> newExprOnAggCall = PairList.of();
113+
for (int i = 0; i < aggregate.getAggCallList().size(); i++) {
114+
AggregateCall aggCall = aggregate.getAggCallList().get(i);
115+
if (isConvertableAggCall(aggCall, project)) {
116+
// The arg ref of convertable aggCall starts at the end of the project
117+
int argRef = convertedAggCallArgs.get(convertedAggCallCnt++);
118+
AggregateCall sumCall =
119+
AggregateCall.create(
120+
aggCall.getParserPosition(),
121+
aggCall.getAggregation(),
122+
aggCall.isDistinct(),
123+
aggCall.isApproximate(),
124+
aggCall.ignoreNulls(),
125+
aggCall.rexList,
126+
ImmutableList.of(argRef),
127+
aggCall.filterArg,
128+
aggCall.distinctKeys,
129+
aggCall.collation,
130+
aggregate.getGroupCount(),
131+
newInput, // Note: must be the new Project
132+
null, // The type will be inferred.
133+
aggCall.getName() + "_SUM");
134+
int sumCallRef = putToDistinctAggregateCalls(distinctAggregateCalls, sumCall);
135+
136+
final Function<RelNode, Function<RexNode, RexNode>> literalConverterProvider;
137+
RexCall rexCall = (RexCall) project.getProjects().get(aggCall.getArgList().getFirst());
138+
if (rexCall.getOperator().kind == SqlKind.PLUS
139+
|| rexCall.getOperator().kind == SqlKind.MINUS) {
140+
AggregateCall countCall =
141+
AggregateCall.create(
142+
aggCall.getParserPosition(),
143+
SqlStdOperatorTable.COUNT,
144+
aggCall.isDistinct(),
145+
aggCall.isApproximate(),
146+
aggCall.ignoreNulls(),
147+
aggCall.rexList,
148+
ImmutableList.of(argRef),
149+
aggCall.filterArg,
150+
aggCall.distinctKeys,
151+
aggCall.collation,
152+
aggregate.getGroupCount(),
153+
newInput,
154+
null, // The type will be inferred.
155+
aggCall.getName() + "_COUNT");
156+
int countCallRef = putToDistinctAggregateCalls(distinctAggregateCalls, countCall);
157+
literalConverterProvider =
158+
input ->
159+
literal ->
160+
rexBuilder.makeCall(
161+
aggCall.getType(),
162+
SqlStdOperatorTable.MULTIPLY,
163+
List.of(
164+
rexBuilder.makeInputRef(input, groupSetOffset + countCallRef),
165+
literal));
166+
} else {
167+
literalConverterProvider = input -> literal -> literal;
168+
}
169+
newExprOnAggCall.add(
170+
input -> {
171+
Function<RexNode, RexNode> fieldConverter =
172+
field -> rexBuilder.makeInputRef(input, groupSetOffset + sumCallRef);
173+
Function<RexNode, RexNode> literalConverter = literalConverterProvider.apply(input);
174+
List<RexNode> operands =
175+
List.of(
176+
convertToNewOperand(
177+
rexCall.getOperands().getFirst(), fieldConverter, literalConverter),
178+
convertToNewOperand(
179+
rexCall.getOperands().getLast(), fieldConverter, literalConverter));
180+
return rexBuilder.makeCall(aggCall.getType(), rexCall.getOperator(), operands);
181+
},
182+
aggCall.getName());
183+
} else {
184+
int callRef = putToDistinctAggregateCalls(distinctAggregateCalls, aggCall);
185+
newExprOnAggCall.add(
186+
input -> rexBuilder.makeInputRef(input, groupSetOffset + callRef), aggCall.getName());
187+
}
188+
}
189+
190+
/* Eliminate unused fields in the child project */
191+
ImmutableBitSet newGroupSet = aggregate.getGroupSet();
192+
;
193+
ImmutableList<ImmutableBitSet> newGroupSets = aggregate.getGroupSets();
194+
;
195+
final Set<Integer> fieldsUsed =
196+
RelOptUtil.getAllFields2(aggregate.getGroupSet(), distinctAggregateCalls);
197+
if (fieldsUsed.size() < newChildProjects.size()) {
198+
// Some fields are computed but not used. Prune them.
199+
final Map<Integer, Integer> sourceFieldToTargetFieldMap = new HashMap<>();
200+
for (int source : fieldsUsed) {
201+
sourceFieldToTargetFieldMap.put(source, sourceFieldToTargetFieldMap.size());
202+
}
203+
newGroupSet = aggregate.getGroupSet().permute(sourceFieldToTargetFieldMap);
204+
newGroupSets =
205+
ImmutableBitSet.ORDERING.immutableSortedCopy(
206+
ImmutableBitSet.permute(aggregate.getGroupSets(), sourceFieldToTargetFieldMap));
207+
final Mappings.TargetMapping targetMapping =
208+
Mappings.target(sourceFieldToTargetFieldMap, newChildProjects.size(), fieldsUsed.size());
209+
final List<AggregateCall> oldAggregateCalls = new ArrayList<>(distinctAggregateCalls);
210+
distinctAggregateCalls.clear();
211+
for (AggregateCall aggregateCall : oldAggregateCalls) {
212+
distinctAggregateCalls.add(aggregateCall.transform(targetMapping));
213+
}
214+
// Project the used fields
215+
relBuilder.project(relBuilder.fields(fieldsUsed.stream().toList()));
216+
}
217+
218+
/* Build the final project-aggregate-project after eliminating unused fields */
219+
relBuilder.aggregate(relBuilder.groupKey(newGroupSet, newGroupSets), distinctAggregateCalls);
220+
List<RexNode> parentProjects =
221+
new ArrayList<>(relBuilder.fields(IntStream.range(0, groupSetOffset).boxed().toList()));
222+
parentProjects.addAll(
223+
newExprOnAggCall.transform(
224+
(constructor, name) ->
225+
aliasMaybe(relBuilder, constructor.apply(relBuilder.peek()), name)));
226+
relBuilder.project(parentProjects);
227+
call.transformTo(relBuilder.build());
228+
}
229+
230+
interface OperatorConstructor {
231+
RexNode apply(RelNode input);
232+
}
233+
234+
private int putToDistinctAggregateCalls(
235+
List<AggregateCall> distinctAggregateCalls, AggregateCall aggCall) {
236+
int i = distinctAggregateCalls.indexOf(aggCall);
237+
if (i < 0) {
238+
i = distinctAggregateCalls.size();
239+
distinctAggregateCalls.add(aggCall);
240+
}
241+
return i;
242+
}
243+
244+
private boolean isConvertableAggCall(AggregateCall aggCall, Project project) {
245+
return aggCall.getAggregation().getKind() == SqlKind.SUM
246+
&& Config.isCallWithLiteral(project.getProjects().get(aggCall.getArgList().getFirst()));
247+
}
248+
249+
private static Pair<RexInputRef, RexLiteral> getFieldAndLiteral(RexNode node) {
250+
RexCall call = (RexCall) node;
251+
RexNode arg1 = call.getOperands().getFirst();
252+
RexNode arg2 = call.getOperands().getLast();
253+
return arg1.getKind() == SqlKind.INPUT_REF
254+
? Pair.of((RexInputRef) arg1, (RexLiteral) arg2)
255+
: Pair.of((RexInputRef) arg2, (RexLiteral) arg1);
256+
}
257+
258+
private static RexNode convertToNewOperand(
259+
RexNode operand,
260+
Function<RexNode, RexNode> fieldConverter,
261+
Function<RexNode, RexNode> literalConverter) {
262+
if (operand.getKind() == SqlKind.INPUT_REF) {
263+
return fieldConverter.apply(operand);
264+
} else {
265+
return literalConverter.apply(operand);
266+
}
267+
}
268+
269+
private RexNode aliasMaybe(RelBuilder builder, RexNode node, String alias) {
270+
return alias == null ? node : builder.alias(node, alias);
271+
}
272+
273+
/** Rule configuration. */
274+
@Value.Immutable
275+
public interface Config extends RelRule.Config {
276+
Config SUM_CONVERTER =
277+
ImmutablePPLAggregateConvertRule.Config.builder()
278+
.build()
279+
.withOperandSupplier(
280+
b0 ->
281+
b0.operand(LogicalAggregate.class)
282+
.predicate(Config::containsSumAggCall)
283+
.oneInput(
284+
b1 ->
285+
b1.operand(LogicalProject.class)
286+
.predicate(Config::containsCallWithNumber)
287+
.anyInputs()));
288+
289+
static boolean containsSumAggCall(LogicalAggregate aggregate) {
290+
return aggregate.getAggCallList().stream()
291+
.anyMatch(aggCall -> aggCall.getAggregation().getKind() == SqlKind.SUM);
292+
}
293+
294+
static boolean containsCallWithNumber(LogicalProject project) {
295+
return project.getProjects().stream().anyMatch(Config::isCallWithLiteral);
296+
}
297+
298+
private static boolean isCallWithLiteral(RexNode node) {
299+
if (CONVERTABLE_FUNCTIONS.contains(node.getKind()) && node instanceof RexCall call) {
300+
RexNode arg1 = call.getOperands().getFirst();
301+
RexNode arg2 = call.getOperands().getLast();
302+
return (arg1.getKind() == SqlKind.INPUT_REF && arg2.getKind() == SqlKind.LITERAL)
303+
|| (arg1.getKind() == SqlKind.LITERAL && arg2.getKind() == SqlKind.INPUT_REF);
304+
}
305+
return false;
306+
}
307+
308+
List<SqlKind> CONVERTABLE_FUNCTIONS =
309+
List.of(
310+
SqlKind.PLUS, SqlKind.MINUS, SqlKind.TIMES
311+
// Don't support division because of the issue of integer division
312+
// e.g. (2000 / 3) * 3 = 1998 while 2000 * 3 / 3 = 2000
313+
// SqlKind.DIVIDE
314+
);
315+
316+
@Override
317+
default PPLAggregateConvertRule toRule() {
318+
return new PPLAggregateConvertRule(this);
319+
}
320+
}
321+
}

core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
import org.apache.calcite.util.Holder;
8989
import org.apache.calcite.util.Util;
9090
import org.opensearch.sql.calcite.CalcitePlanContext;
91+
import org.opensearch.sql.calcite.plan.OpenSearchRules;
9192
import org.opensearch.sql.calcite.plan.Scannable;
9293
import org.opensearch.sql.calcite.udf.udaf.NullableSqlAvgAggFunction;
9394

@@ -231,10 +232,15 @@ public <R> R perform(
231232
final RelOptPlanner planner =
232233
createPlanner(
233234
prepareContext, Contexts.of(prepareContext.config()), config.getCostFactory());
235+
registerCustomizedRules(planner);
234236
final RelOptCluster cluster = createCluster(planner, rexBuilder);
235237
return action.apply(cluster, catalogReader, prepareContext.getRootSchema().plus(), statement);
236238
}
237239

240+
private void registerCustomizedRules(RelOptPlanner planner) {
241+
OpenSearchRules.OPEN_SEARCH_OPT_RULES.forEach(planner::addRule);
242+
}
243+
238244
/**
239245
* Customize CalcitePreparingStmt. Override {@link CalcitePrepareImpl#getPreparingStmt} and
240246
* return {@link OpenSearchCalcitePreparingStmt}

0 commit comments

Comments
 (0)