Skip to content

Commit 7f7495e

Browse files
authored
Support nested aggregation when calcite enabled (#4979) (#5012)
* refactor: throw exception if pushdown cannot be applied * fix tests * fix IT * Support top/dedup/aggregate by nested sub-fields * fix typo * address comments * minor fixing --------- (cherry picked from commit 77633ef) Signed-off-by: Lantao Jin <ltjin@amazon.com>
1 parent a18e348 commit 7f7495e

43 files changed

Lines changed: 1182 additions & 211 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@
152152
import org.opensearch.sql.calcite.plan.OpenSearchConstants;
153153
import org.opensearch.sql.calcite.utils.BinUtils;
154154
import org.opensearch.sql.calcite.utils.JoinAndLookupUtils;
155+
import org.opensearch.sql.calcite.utils.PPLHintUtils;
155156
import org.opensearch.sql.calcite.utils.PlanUtils;
156157
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
157158
import org.opensearch.sql.calcite.utils.WildcardUtils;
@@ -948,13 +949,14 @@ private boolean isCountField(RexCall call) {
948949
* @param groupExprList group by expression list
949950
* @param aggExprList aggregate expression list
950951
* @param context CalcitePlanContext
952+
* @param hintIgnoreNullBucket true if bucket_nullable=false
951953
* @return Pair of (group-by list, field list, aggregate list)
952954
*/
953955
private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
954956
List<UnresolvedExpression> groupExprList,
955957
List<UnresolvedExpression> aggExprList,
956958
CalcitePlanContext context,
957-
boolean hintBucketNonNull) {
959+
boolean hintIgnoreNullBucket) {
958960
Pair<List<RexNode>, List<AggCall>> resolved =
959961
resolveAttributesForAggregation(groupExprList, aggExprList, context);
960962
List<RexNode> resolvedGroupByList = resolved.getLeft();
@@ -1048,7 +1050,9 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
10481050
// \- Scan t
10491051
List<RexInputRef> trimmedRefs = new ArrayList<>();
10501052
trimmedRefs.addAll(PlanUtils.getInputRefs(resolvedGroupByList)); // group-by keys first
1051-
trimmedRefs.addAll(PlanUtils.getInputRefsFromAggCall(resolvedAggCallList));
1053+
List<RexInputRef> aggCallRefs = PlanUtils.getInputRefsFromAggCall(resolvedAggCallList);
1054+
boolean hintNestedAgg = containsNestedAggregator(context.relBuilder, aggCallRefs);
1055+
trimmedRefs.addAll(aggCallRefs);
10521056
context.relBuilder.project(trimmedRefs);
10531057

10541058
// Re-resolve all attributes based on adding trimmed Project.
@@ -1060,7 +1064,8 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
10601064
List<String> intendedGroupKeyAliases = getGroupKeyNamesAfterAggregation(reResolved.getLeft());
10611065
context.relBuilder.aggregate(
10621066
context.relBuilder.groupKey(reResolved.getLeft()), reResolved.getRight());
1063-
if (hintBucketNonNull) PlanUtils.addIgnoreNullBucketHintToAggregate(context.relBuilder);
1067+
if (hintIgnoreNullBucket) PPLHintUtils.addIgnoreNullBucketHintToAggregate(context.relBuilder);
1068+
if (hintNestedAgg) PPLHintUtils.addNestedAggCallHintToAggregate(context.relBuilder);
10641069
// During aggregation, Calcite projects both input dependencies and output group-by fields.
10651070
// When names conflict, Calcite adds numeric suffixes (e.g., "value0").
10661071
// Apply explicit renaming to restore the intended aliases.
@@ -1069,6 +1074,17 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
10691074
return Pair.of(reResolved.getLeft(), reResolved.getRight());
10701075
}
10711076

1077+
/**
1078+
* Return true if the aggCalls contains a nested field. For example: aggCalls: [count(),
1079+
* count(a.b)] returns true.
1080+
*/
1081+
private boolean containsNestedAggregator(RelBuilder relBuilder, List<RexInputRef> aggCallRefs) {
1082+
return aggCallRefs.stream()
1083+
.map(r -> relBuilder.peek().getRowType().getFieldNames().get(r.getIndex()))
1084+
.map(name -> org.apache.commons.lang3.StringUtils.substringBefore(name, "."))
1085+
.anyMatch(root -> relBuilder.field(root).getType().getSqlTypeName() == SqlTypeName.ARRAY);
1086+
}
1087+
10721088
/**
10731089
* Imitates {@code Registrar.registerExpression} of {@link RelBuilder} to derive the output order
10741090
* of group-by keys after aggregation.
@@ -1178,8 +1194,8 @@ private void visitAggregation(
11781194
}
11791195
groupExprList.addAll(node.getGroupExprList());
11801196

1181-
// Add stats hint to LogicalAggregation.
1182-
boolean toAddHintsOnAggregate =
1197+
// Add a hint to LogicalAggregation when bucket_nullable=false.
1198+
boolean hintIgnoreNullBucket =
11831199
!groupExprList.isEmpty()
11841200
// This checks if all group-bys should be nonnull
11851201
&& nonNullGroupMask.nextClearBit(0) >= groupExprList.size();
@@ -1199,14 +1215,16 @@ private void visitAggregation(
11991215
.filter(nonNullGroupMask::get)
12001216
.mapToObj(nonNullCandidates::get)
12011217
.collect(Collectors.toList());
1202-
context.relBuilder.filter(
1203-
PlanUtils.getSelectColumns(nonNullFields).stream()
1204-
.map(context.relBuilder::field)
1205-
.map(context.relBuilder::isNotNull)
1206-
.collect(Collectors.toList()));
1218+
if (!nonNullFields.isEmpty()) {
1219+
context.relBuilder.filter(
1220+
PlanUtils.getSelectColumns(nonNullFields).stream()
1221+
.map(context.relBuilder::field)
1222+
.map(context.relBuilder::isNotNull)
1223+
.collect(Collectors.toList()));
1224+
}
12071225

12081226
Pair<List<RexNode>, List<AggCall>> aggregationAttributes =
1209-
aggregateWithTrimming(groupExprList, aggExprList, context, toAddHintsOnAggregate);
1227+
aggregateWithTrimming(groupExprList, aggExprList, context, hintIgnoreNullBucket);
12101228

12111229
// schema reordering
12121230
List<RexNode> outputFields = context.relBuilder.fields();
@@ -2342,9 +2360,9 @@ public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) {
23422360

23432361
// if usenull=false, add a isNotNull before Aggregate and the hint to this Aggregate
23442362
Boolean bucketNullable = (Boolean) argumentMap.get(RareTopN.Option.useNull.name()).getValue();
2345-
boolean toAddHintsOnAggregate = false;
2363+
boolean hintIgnoreNullBucket = false;
23462364
if (!bucketNullable && !groupExprList.isEmpty()) {
2347-
toAddHintsOnAggregate = true;
2365+
hintIgnoreNullBucket = true;
23482366
// add isNotNull filter before aggregation to filter out null bucket
23492367
List<RexNode> groupByList =
23502368
groupExprList.stream()
@@ -2356,7 +2374,7 @@ public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) {
23562374
.map(context.relBuilder::isNotNull)
23572375
.collect(Collectors.toList()));
23582376
}
2359-
aggregateWithTrimming(groupExprList, aggExprList, context, toAddHintsOnAggregate);
2377+
aggregateWithTrimming(groupExprList, aggExprList, context, hintIgnoreNullBucket);
23602378

23612379
// 2. add count() column with sort direction
23622380
List<RexNode> partitionKeys = rexVisitor.analyze(node.getGroupExprList(), context);

core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
import org.apache.calcite.jdbc.CalcitePrepare;
5353
import org.apache.calcite.jdbc.CalciteSchema;
5454
import org.apache.calcite.jdbc.Driver;
55-
import org.apache.calcite.linq4j.function.Function0;
5655
import org.apache.calcite.plan.Context;
5756
import org.apache.calcite.plan.Contexts;
5857
import org.apache.calcite.plan.Convention;
@@ -175,8 +174,11 @@ public Connection connect(
175174
}
176175

177176
@Override
178-
protected Function0<CalcitePrepare> createPrepareFactory() {
179-
return OpenSearchPrepareImpl::new;
177+
public CalcitePrepare createPrepare() {
178+
if (prepareFactory != null) {
179+
return prepareFactory.get();
180+
}
181+
return new OpenSearchPrepareImpl();
180182
}
181183
}
182184

@@ -298,10 +300,10 @@ public OpenSearchCalcitePreparingStmt(
298300

299301
@Override
300302
protected PreparedResult implement(RelRoot root) {
301-
Hook.PLAN_BEFORE_IMPLEMENTATION.run(root);
302-
RelDataType resultType = root.rel.getRowType();
303-
boolean isDml = root.kind.belongsTo(SqlKind.DML);
304303
if (root.rel instanceof Scannable) {
304+
Hook.PLAN_BEFORE_IMPLEMENTATION.run(root);
305+
RelDataType resultType = root.rel.getRowType();
306+
boolean isDml = root.kind.belongsTo(SqlKind.DML);
305307
final Bindable bindable = dataContext -> ((Scannable) root.rel).scan();
306308

307309
return new PreparedResultImpl(

core/src/main/java/org/opensearch/sql/calcite/utils/PPLHintStrategyTable.java

Lines changed: 0 additions & 33 deletions
This file was deleted.
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite.utils;
7+
8+
import com.google.common.base.Suppliers;
9+
import java.util.function.Supplier;
10+
import lombok.experimental.UtilityClass;
11+
import org.apache.calcite.rel.core.Aggregate;
12+
import org.apache.calcite.rel.hint.HintStrategyTable;
13+
import org.apache.calcite.rel.hint.RelHint;
14+
import org.apache.calcite.rel.logical.LogicalAggregate;
15+
import org.apache.calcite.tools.RelBuilder;
16+
17+
@UtilityClass
18+
public class PPLHintUtils {
19+
private static final String HINT_AGG_ARGUMENTS = "AGG_ARGS";
20+
private static final String KEY_IGNORE_NULL_BUCKET = "ignoreNullBucket";
21+
private static final String KEY_HAS_NESTED_AGG_CALL = "hasNestedAggCall";
22+
23+
private static final Supplier<HintStrategyTable> HINT_STRATEGY_TABLE =
24+
Suppliers.memoize(
25+
() ->
26+
HintStrategyTable.builder()
27+
.hintStrategy(
28+
HINT_AGG_ARGUMENTS,
29+
(hint, rel) -> {
30+
return rel instanceof LogicalAggregate;
31+
})
32+
// add more here
33+
.build());
34+
35+
/**
36+
* Add hint to aggregate to indicate that the aggregate will ignore null value bucket. Notice, the
37+
* current peek of relBuilder is expected to be LogicalAggregate.
38+
*/
39+
public static void addIgnoreNullBucketHintToAggregate(RelBuilder relBuilder) {
40+
assert relBuilder.peek() instanceof LogicalAggregate
41+
: "Hint HINT_AGG_ARGUMENTS can be added to LogicalAggregate only";
42+
final RelHint statHint =
43+
RelHint.builder(HINT_AGG_ARGUMENTS).hintOption(KEY_IGNORE_NULL_BUCKET, "true").build();
44+
relBuilder.hints(statHint);
45+
if (relBuilder.getCluster().getHintStrategies() == HintStrategyTable.EMPTY) {
46+
relBuilder.getCluster().setHintStrategies(HINT_STRATEGY_TABLE.get());
47+
}
48+
}
49+
50+
/**
51+
* Add hint to aggregate to indicate that the aggregate has nested agg call. Notice, the current
52+
* peek of relBuilder is expected to be LogicalAggregate.
53+
*/
54+
public static void addNestedAggCallHintToAggregate(RelBuilder relBuilder) {
55+
assert relBuilder.peek() instanceof LogicalAggregate
56+
: "Hint HINT_AGG_ARGUMENTS can be added to LogicalAggregate only";
57+
final RelHint statHint =
58+
RelHint.builder(HINT_AGG_ARGUMENTS).hintOption(KEY_HAS_NESTED_AGG_CALL, "true").build();
59+
relBuilder.hints(statHint);
60+
if (relBuilder.getCluster().getHintStrategies() == HintStrategyTable.EMPTY) {
61+
relBuilder.getCluster().setHintStrategies(HINT_STRATEGY_TABLE.get());
62+
}
63+
}
64+
65+
/** Return true if the aggregate will ignore null value bucket. */
66+
public static boolean ignoreNullBucket(Aggregate aggregate) {
67+
return aggregate.getHints().stream()
68+
.anyMatch(
69+
hint ->
70+
hint.hintName.equals(PPLHintUtils.HINT_AGG_ARGUMENTS)
71+
&& hint.kvOptions.getOrDefault(KEY_IGNORE_NULL_BUCKET, "false").equals("true"));
72+
}
73+
74+
/** Return true if the aggregate has any nested agg call. */
75+
public static boolean hasNestedAggCall(Aggregate aggregate) {
76+
return aggregate.getHints().stream()
77+
.anyMatch(
78+
hint ->
79+
hint.hintName.equals(PPLHintUtils.HINT_AGG_ARGUMENTS)
80+
&& hint.kvOptions
81+
.getOrDefault(KEY_HAS_NESTED_AGG_CALL, "false")
82+
.equals("true"));
83+
}
84+
}

core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@
3636
import org.apache.calcite.rel.core.Project;
3737
import org.apache.calcite.rel.core.Sort;
3838
import org.apache.calcite.rel.core.TableScan;
39-
import org.apache.calcite.rel.hint.RelHint;
40-
import org.apache.calcite.rel.logical.LogicalAggregate;
4139
import org.apache.calcite.rel.logical.LogicalFilter;
4240
import org.apache.calcite.rel.logical.LogicalProject;
4341
import org.apache.calcite.rel.logical.LogicalSort;
@@ -62,7 +60,6 @@
6260
import org.apache.calcite.util.mapping.Mappings;
6361
import org.opensearch.sql.ast.AbstractNodeVisitor;
6462
import org.opensearch.sql.ast.Node;
65-
import org.opensearch.sql.ast.expression.Argument;
6663
import org.opensearch.sql.ast.expression.IntervalUnit;
6764
import org.opensearch.sql.ast.expression.SpanUnit;
6865
import org.opensearch.sql.ast.expression.WindowBound;
@@ -638,15 +635,6 @@ static void replaceTop(RelBuilder relBuilder, RelNode relNode) {
638635
}
639636
}
640637

641-
static void addIgnoreNullBucketHintToAggregate(RelBuilder relBuilder) {
642-
final RelHint statHits =
643-
RelHint.builder("stats_args").hintOption(Argument.BUCKET_NULLABLE, "false").build();
644-
assert relBuilder.peek() instanceof LogicalAggregate
645-
: "Stats hits should be added to LogicalAggregate";
646-
relBuilder.hints(statHits);
647-
relBuilder.getCluster().setHintStrategies(PPLHintStrategyTable.getHintStrategyTable());
648-
}
649-
650638
/** Extract the RexLiteral from the aggregate call if the aggregate call is a LITERAL_AGG. */
651639
static @Nullable RexLiteral getObjectFromLiteralAgg(AggregateCall aggCall) {
652640
if (aggCall.getAggregation().kind == SqlKind.LITERAL_AGG) {
@@ -683,13 +671,7 @@ private static boolean isNotNullOnRef(RexNode rex) {
683671
&& ((RexCall) rex).getOperands().get(0) instanceof RexInputRef;
684672
}
685673

686-
Predicate<Aggregate> aggIgnoreNullBucket =
687-
agg ->
688-
agg.getHints().stream()
689-
.anyMatch(
690-
hint ->
691-
hint.hintName.equals("stats_args")
692-
&& hint.kvOptions.get(Argument.BUCKET_NULLABLE).equals("false"));
674+
Predicate<Aggregate> aggIgnoreNullBucket = PPLHintUtils::ignoreNullBucket;
693675

694676
Predicate<Aggregate> maybeTimeSpanAgg =
695677
agg ->
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.utils;
7+
8+
import java.util.Iterator;
9+
import java.util.LinkedList;
10+
import java.util.List;
11+
import java.util.Map;
12+
import javax.annotation.Nullable;
13+
import org.apache.commons.lang3.StringUtils;
14+
import org.apache.commons.lang3.tuple.Pair;
15+
import org.opensearch.sql.data.type.ExprCoreType;
16+
import org.opensearch.sql.data.type.ExprType;
17+
18+
public interface Utils {
19+
static <I> List<Pair<I, Integer>> zipWithIndex(List<I> input) {
20+
LinkedList<Pair<I, Integer>> result = new LinkedList<>();
21+
Iterator<I> iter = input.iterator();
22+
int index = 0;
23+
while (iter.hasNext()) {
24+
result.add(Pair.of(iter.next(), index++));
25+
}
26+
return result;
27+
}
28+
29+
/**
30+
* Resolve the nested path from the field name.
31+
*
32+
* @param path the field name
33+
* @param fieldTypes the field types
34+
* @return the nested path if exists, otherwise null
35+
*/
36+
static @Nullable String resolveNestedPath(String path, Map<String, ExprType> fieldTypes) {
37+
if (path == null || fieldTypes == null || fieldTypes.isEmpty()) {
38+
return null;
39+
}
40+
boolean found = false;
41+
String current = path;
42+
String parent = StringUtils.substringBeforeLast(current, ".");
43+
while (parent != null && !parent.equals(current)) {
44+
ExprType pathType = fieldTypes.get(parent);
45+
// Nested is mapped to ExprCoreType.ARRAY
46+
if (pathType == ExprCoreType.ARRAY) {
47+
found = true;
48+
break;
49+
}
50+
current = parent;
51+
parent = StringUtils.substringBeforeLast(current, ".");
52+
}
53+
if (found) {
54+
return parent;
55+
} else {
56+
return null;
57+
}
58+
}
59+
}

0 commit comments

Comments
 (0)