Skip to content

Commit 9c67fef

Browse files
[Backport 2.19-dev] SUM aggregation enhancement on operations with literal (opensearch-project#3971) (opensearch-project#4104)
* SUM aggregation enhancement on operations with literal (opensearch-project#3971) * SUM aggregation enhancement on operations with literal Signed-off-by: Heng Qian <qianheng@amazon.com> * Fix CI Signed-off-by: Heng Qian <qianheng@amazon.com> * Keep ignoring q30 for Calcite Signed-off-by: Heng Qian <qianheng@amazon.com> * Add UT for PPLAggregateConvertRule Signed-off-by: Heng Qian <qianheng@amazon.com> * Add UT for PPLAggregateConvertRule Signed-off-by: Heng Qian <qianheng@amazon.com> * Spotless check Signed-off-by: Heng Qian <qianheng@amazon.com> --------- Signed-off-by: Heng Qian <qianheng@amazon.com> Co-authored-by: Mitchell Gale <mitchell.gale@improving.com> (cherry picked from commit 8923551) * Fix Compiling Signed-off-by: Heng Qian <qianheng@amazon.com> * Fix Compiling Signed-off-by: Heng Qian <qianheng@amazon.com> * Fix Compiling Signed-off-by: Heng Qian <qianheng@amazon.com> --------- Signed-off-by: Heng Qian <qianheng@amazon.com> Co-authored-by: Mitchell Gale <mitchell.gale@improving.com> Co-authored-by: MitchellGale-BitQuill <104795536+mitchellgale-bitquill@users.noreply.github.com>
1 parent f10ffd5 commit 9c67fef

11 files changed

Lines changed: 564 additions & 5 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/plan/OpenSearchRules.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@
77

88
import com.google.common.collect.ImmutableList;
99
import java.util.List;
10-
import org.apache.calcite.rel.convert.ConverterRule;
10+
import org.apache.calcite.plan.RelOptRule;
1111

1212
public class OpenSearchRules {
13-
public static final List<ConverterRule> OPEN_SEARCH_OPT_RULES = ImmutableList.of();
13+
private static final PPLAggregateConvertRule AGGREGATE_CONVERT_RULE =
14+
PPLAggregateConvertRule.Config.SUM_CONVERTER.toRule();
15+
16+
public static final List<RelOptRule> OPEN_SEARCH_OPT_RULES =
17+
ImmutableList.of(AGGREGATE_CONVERT_RULE);
1418

1519
// prevent instantiation
1620
private OpenSearchRules() {}
Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.sql.calcite.plan;
6+
7+
import com.google.common.collect.ImmutableList;
8+
import java.util.ArrayList;
9+
import java.util.HashMap;
10+
import java.util.List;
11+
import java.util.Map;
12+
import java.util.Set;
13+
import java.util.function.Function;
14+
import java.util.stream.Collectors;
15+
import java.util.stream.IntStream;
16+
import org.apache.calcite.plan.RelOptRuleCall;
17+
import org.apache.calcite.plan.RelOptUtil;
18+
import org.apache.calcite.plan.RelRule;
19+
import org.apache.calcite.rel.RelNode;
20+
import org.apache.calcite.rel.core.AggregateCall;
21+
import org.apache.calcite.rel.core.Project;
22+
import org.apache.calcite.rel.logical.LogicalAggregate;
23+
import org.apache.calcite.rel.logical.LogicalProject;
24+
import org.apache.calcite.rex.RexBuilder;
25+
import org.apache.calcite.rex.RexCall;
26+
import org.apache.calcite.rex.RexInputRef;
27+
import org.apache.calcite.rex.RexLiteral;
28+
import org.apache.calcite.rex.RexNode;
29+
import org.apache.calcite.runtime.PairList;
30+
import org.apache.calcite.sql.SqlKind;
31+
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
32+
import org.apache.calcite.tools.RelBuilder;
33+
import org.apache.calcite.util.ImmutableBitSet;
34+
import org.apache.calcite.util.mapping.Mappings;
35+
import org.apache.commons.lang3.tuple.Pair;
36+
import org.immutables.value.Value;
37+
38+
/**
39+
* Planner rule that converts specific aggCall to a more efficient expressions, which includes:
40+
*
41+
* <p>- SUM(FIELD + NUMBER) -> SUM(FIELD) + NUMBER * COUNT()
42+
*
43+
* <p>- SUM(FIELD - NUMBER) -> SUM(FIELD) - NUMBER * COUNT()
44+
*
45+
* <p>- SUM(FIELD * NUMBER) -> SUM(FIELD) * NUMBER
46+
*
47+
* <p>- SUM(FIELD / NUMBER) -> SUM(FIELD) / NUMBER, Don't support this because of precision issue
48+
*
49+
* <p>TODO:
50+
*
51+
* <p>- AVG/MAX/MIN(FIELD [+|-|*|+|/] NUMBER) -> AVG/MAX/MIN(FIELD) [+|-|*|+|/] NUMBER
52+
*/
53+
@Value.Enclosing
54+
public class PPLAggregateConvertRule extends RelRule<PPLAggregateConvertRule.Config> {
55+
56+
/** Creates a OpenSearchAggregateConvertRule. */
57+
protected PPLAggregateConvertRule(Config config) {
58+
super(config);
59+
}
60+
61+
@Override
62+
public void onMatch(RelOptRuleCall call) {
63+
if (call.rels.length == 2) {
64+
final LogicalAggregate aggregate = call.rel(0);
65+
final LogicalProject project = call.rel(1);
66+
apply(call, aggregate, project);
67+
} else {
68+
throw new AssertionError(
69+
String.format(
70+
"The length of rels should be %s but got %s",
71+
this.operands.size(), call.rels.length));
72+
}
73+
}
74+
75+
public void apply(RelOptRuleCall call, LogicalAggregate aggregate, LogicalProject project) {
76+
77+
final RelBuilder relBuilder = call.builder();
78+
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
79+
relBuilder.push(project.getInput());
80+
81+
/*
82+
Build new projects with fields to be used in the converted agg call.
83+
Need to build this project in advance since building converted agg call has dependency on it.
84+
*/
85+
List<AggregateCall> aggCalls = aggregate.getAggCallList();
86+
final List<RexNode> newChildProjects = new ArrayList<>(project.getProjects());
87+
List<Integer> convertedAggCallArgs =
88+
aggCalls.stream()
89+
.filter(aggCall -> isConvertableAggCall(aggCall, project))
90+
.map(
91+
aggCall -> {
92+
RexInputRef rexRef =
93+
getFieldAndLiteral(project.getProjects().get(aggCall.getArgList().get(0)))
94+
.getLeft();
95+
// Don't remove elements in the child project since we don't know if it will be
96+
// used by
97+
// other aggCall, will handle unused projects later
98+
int ref = newChildProjects.indexOf(rexRef);
99+
if (ref == -1) {
100+
ref = newChildProjects.size();
101+
newChildProjects.add(rexRef);
102+
}
103+
return ref;
104+
})
105+
.collect(Collectors.toList());
106+
relBuilder.project(newChildProjects);
107+
RelNode newInput = relBuilder.peek();
108+
109+
/* Build converted agg call and its parent projects */
110+
int convertedAggCallCnt = 0;
111+
final int groupSetOffset = aggregate.getGroupSet().cardinality();
112+
final List<AggregateCall> distinctAggregateCalls = new ArrayList<>();
113+
final PairList<OperatorConstructor, String> newExprOnAggCall = PairList.of();
114+
for (int i = 0; i < aggregate.getAggCallList().size(); i++) {
115+
AggregateCall aggCall = aggregate.getAggCallList().get(i);
116+
if (isConvertableAggCall(aggCall, project)) {
117+
// The arg ref of convertable aggCall starts at the end of the project
118+
int argRef = convertedAggCallArgs.get(convertedAggCallCnt++);
119+
AggregateCall sumCall =
120+
AggregateCall.create(
121+
aggCall.getParserPosition(),
122+
aggCall.getAggregation(),
123+
aggCall.isDistinct(),
124+
aggCall.isApproximate(),
125+
aggCall.ignoreNulls(),
126+
aggCall.rexList,
127+
ImmutableList.of(argRef),
128+
aggCall.filterArg,
129+
aggCall.distinctKeys,
130+
aggCall.collation,
131+
aggregate.getGroupCount(),
132+
newInput, // Note: must be the new Project
133+
null, // The type will be inferred.
134+
aggCall.getName() + "_SUM");
135+
int sumCallRef = putToDistinctAggregateCalls(distinctAggregateCalls, sumCall);
136+
137+
final Function<RelNode, Function<RexNode, RexNode>> literalConverterProvider;
138+
RexCall rexCall = (RexCall) project.getProjects().get(aggCall.getArgList().get(0));
139+
if (rexCall.getOperator().kind == SqlKind.PLUS
140+
|| rexCall.getOperator().kind == SqlKind.MINUS) {
141+
AggregateCall countCall =
142+
AggregateCall.create(
143+
aggCall.getParserPosition(),
144+
SqlStdOperatorTable.COUNT,
145+
aggCall.isDistinct(),
146+
aggCall.isApproximate(),
147+
aggCall.ignoreNulls(),
148+
aggCall.rexList,
149+
ImmutableList.of(argRef),
150+
aggCall.filterArg,
151+
aggCall.distinctKeys,
152+
aggCall.collation,
153+
aggregate.getGroupCount(),
154+
newInput,
155+
null, // The type will be inferred.
156+
aggCall.getName() + "_COUNT");
157+
int countCallRef = putToDistinctAggregateCalls(distinctAggregateCalls, countCall);
158+
literalConverterProvider =
159+
input ->
160+
literal ->
161+
rexBuilder.makeCall(
162+
aggCall.getType(),
163+
SqlStdOperatorTable.MULTIPLY,
164+
List.of(
165+
rexBuilder.makeInputRef(input, groupSetOffset + countCallRef),
166+
literal));
167+
} else {
168+
literalConverterProvider = input -> literal -> literal;
169+
}
170+
newExprOnAggCall.add(
171+
input -> {
172+
Function<RexNode, RexNode> fieldConverter =
173+
field -> rexBuilder.makeInputRef(input, groupSetOffset + sumCallRef);
174+
Function<RexNode, RexNode> literalConverter = literalConverterProvider.apply(input);
175+
List<RexNode> operands =
176+
List.of(
177+
convertToNewOperand(
178+
rexCall.getOperands().get(0), fieldConverter, literalConverter),
179+
convertToNewOperand(
180+
rexCall.getOperands().get(1), fieldConverter, literalConverter));
181+
return rexBuilder.makeCall(aggCall.getType(), rexCall.getOperator(), operands);
182+
},
183+
aggCall.getName());
184+
} else {
185+
int callRef = putToDistinctAggregateCalls(distinctAggregateCalls, aggCall);
186+
newExprOnAggCall.add(
187+
input -> rexBuilder.makeInputRef(input, groupSetOffset + callRef), aggCall.getName());
188+
}
189+
}
190+
191+
/* Eliminate unused fields in the child project */
192+
ImmutableBitSet newGroupSet = aggregate.getGroupSet();
193+
;
194+
ImmutableList<ImmutableBitSet> newGroupSets = aggregate.getGroupSets();
195+
;
196+
final Set<Integer> fieldsUsed =
197+
RelOptUtil.getAllFields2(aggregate.getGroupSet(), distinctAggregateCalls);
198+
if (fieldsUsed.size() < newChildProjects.size()) {
199+
// Some fields are computed but not used. Prune them.
200+
final Map<Integer, Integer> sourceFieldToTargetFieldMap = new HashMap<>();
201+
for (int source : fieldsUsed) {
202+
sourceFieldToTargetFieldMap.put(source, sourceFieldToTargetFieldMap.size());
203+
}
204+
newGroupSet = aggregate.getGroupSet().permute(sourceFieldToTargetFieldMap);
205+
newGroupSets =
206+
ImmutableBitSet.ORDERING.immutableSortedCopy(
207+
ImmutableBitSet.permute(aggregate.getGroupSets(), sourceFieldToTargetFieldMap));
208+
final Mappings.TargetMapping targetMapping =
209+
Mappings.target(sourceFieldToTargetFieldMap, newChildProjects.size(), fieldsUsed.size());
210+
final List<AggregateCall> oldAggregateCalls = new ArrayList<>(distinctAggregateCalls);
211+
distinctAggregateCalls.clear();
212+
for (AggregateCall aggregateCall : oldAggregateCalls) {
213+
distinctAggregateCalls.add(aggregateCall.transform(targetMapping));
214+
}
215+
// Project the used fields
216+
relBuilder.project(relBuilder.fields(new ArrayList<>(fieldsUsed)));
217+
}
218+
219+
/* Build the final project-aggregate-project after eliminating unused fields */
220+
relBuilder.aggregate(relBuilder.groupKey(newGroupSet, newGroupSets), distinctAggregateCalls);
221+
List<RexNode> parentProjects =
222+
new ArrayList<>(relBuilder.fields(IntStream.range(0, groupSetOffset).boxed().collect(
223+
Collectors.toList())));
224+
parentProjects.addAll(
225+
newExprOnAggCall.transform(
226+
(constructor, name) ->
227+
aliasMaybe(relBuilder, constructor.apply(relBuilder.peek()), name)));
228+
relBuilder.project(parentProjects);
229+
call.transformTo(relBuilder.build());
230+
}
231+
232+
interface OperatorConstructor {
233+
RexNode apply(RelNode input);
234+
}
235+
236+
private int putToDistinctAggregateCalls(
237+
List<AggregateCall> distinctAggregateCalls, AggregateCall aggCall) {
238+
int i = distinctAggregateCalls.indexOf(aggCall);
239+
if (i < 0) {
240+
i = distinctAggregateCalls.size();
241+
distinctAggregateCalls.add(aggCall);
242+
}
243+
return i;
244+
}
245+
246+
private boolean isConvertableAggCall(AggregateCall aggCall, Project project) {
247+
return aggCall.getAggregation().getKind() == SqlKind.SUM
248+
&& Config.isCallWithLiteral(project.getProjects().get(aggCall.getArgList().get(0)));
249+
}
250+
251+
private static Pair<RexInputRef, RexLiteral> getFieldAndLiteral(RexNode node) {
252+
RexCall call = (RexCall) node;
253+
RexNode arg1 = call.getOperands().get(0);
254+
RexNode arg2 = call.getOperands().get(1);
255+
return arg1.getKind() == SqlKind.INPUT_REF
256+
? Pair.of((RexInputRef) arg1, (RexLiteral) arg2)
257+
: Pair.of((RexInputRef) arg2, (RexLiteral) arg1);
258+
}
259+
260+
private static RexNode convertToNewOperand(
261+
RexNode operand,
262+
Function<RexNode, RexNode> fieldConverter,
263+
Function<RexNode, RexNode> literalConverter) {
264+
if (operand.getKind() == SqlKind.INPUT_REF) {
265+
return fieldConverter.apply(operand);
266+
} else {
267+
return literalConverter.apply(operand);
268+
}
269+
}
270+
271+
private RexNode aliasMaybe(RelBuilder builder, RexNode node, String alias) {
272+
return alias == null ? node : builder.alias(node, alias);
273+
}
274+
275+
/** Rule configuration. */
276+
@Value.Immutable
277+
public interface Config extends RelRule.Config {
278+
Config SUM_CONVERTER =
279+
ImmutablePPLAggregateConvertRule.Config.builder()
280+
.build()
281+
.withOperandSupplier(
282+
b0 ->
283+
b0.operand(LogicalAggregate.class)
284+
.predicate(Config::containsSumAggCall)
285+
.oneInput(
286+
b1 ->
287+
b1.operand(LogicalProject.class)
288+
.predicate(Config::containsCallWithNumber)
289+
.anyInputs()));
290+
291+
static boolean containsSumAggCall(LogicalAggregate aggregate) {
292+
return aggregate.getAggCallList().stream()
293+
.anyMatch(aggCall -> aggCall.getAggregation().getKind() == SqlKind.SUM);
294+
}
295+
296+
static boolean containsCallWithNumber(LogicalProject project) {
297+
return project.getProjects().stream().anyMatch(Config::isCallWithLiteral);
298+
}
299+
300+
private static boolean isCallWithLiteral(RexNode node) {
301+
if (CONVERTABLE_FUNCTIONS.contains(node.getKind()) && node instanceof RexCall) {
302+
RexCall call = (RexCall) node;
303+
RexNode arg1 = call.getOperands().get(0);
304+
RexNode arg2 = call.getOperands().get(1);
305+
return (arg1.getKind() == SqlKind.INPUT_REF && arg2.getKind() == SqlKind.LITERAL)
306+
|| (arg1.getKind() == SqlKind.LITERAL && arg2.getKind() == SqlKind.INPUT_REF);
307+
}
308+
return false;
309+
}
310+
311+
List<SqlKind> CONVERTABLE_FUNCTIONS =
312+
List.of(
313+
SqlKind.PLUS, SqlKind.MINUS, SqlKind.TIMES
314+
// Don't support division because of the issue of integer division
315+
// e.g. (2000 / 3) * 3 = 1998 while 2000 * 3 / 3 = 2000
316+
// SqlKind.DIVIDE
317+
);
318+
319+
@Override
320+
default PPLAggregateConvertRule toRule() {
321+
return new PPLAggregateConvertRule(this);
322+
}
323+
}
324+
}

core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
import org.apache.calcite.util.Holder;
8989
import org.apache.calcite.util.Util;
9090
import org.opensearch.sql.calcite.CalcitePlanContext;
91+
import org.opensearch.sql.calcite.plan.OpenSearchRules;
9192
import org.opensearch.sql.calcite.plan.Scannable;
9293
import org.opensearch.sql.calcite.udf.udaf.NullableSqlAvgAggFunction;
9394

@@ -231,10 +232,15 @@ public <R> R perform(
231232
final RelOptPlanner planner =
232233
createPlanner(
233234
prepareContext, Contexts.of(prepareContext.config()), config.getCostFactory());
235+
registerCustomizedRules(planner);
234236
final RelOptCluster cluster = createCluster(planner, rexBuilder);
235237
return action.apply(cluster, catalogReader, prepareContext.getRootSchema().plus(), statement);
236238
}
237239

240+
private void registerCustomizedRules(RelOptPlanner planner) {
241+
OpenSearchRules.OPEN_SEARCH_OPT_RULES.forEach(planner::addRule);
242+
}
243+
238244
/**
239245
* Customize CalcitePreparingStmt. Override {@link CalcitePrepareImpl#getPreparingStmt} and
240246
* return {@link OpenSearchCalcitePreparingStmt}

0 commit comments

Comments
 (0)