Skip to content

Commit d785276

Browse files
committed
Hint non-null in aggregateWithTrimming
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent a57fee6 commit d785276

1 file changed

Lines changed: 25 additions & 22 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -896,12 +896,14 @@ private boolean isCountField(RexCall call) {
896896
* @param groupExprList group by expression list
897897
* @param aggExprList aggregate expression list
898898
* @param context CalcitePlanContext
899+
* @param hintBucketNonNull adda bucket nullable hint on LogicalAggregate if set
899900
* @return Pair of (group-by list, field list, aggregate list)
900901
*/
901902
private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
902903
List<UnresolvedExpression> groupExprList,
903904
List<UnresolvedExpression> aggExprList,
904-
CalcitePlanContext context) {
905+
CalcitePlanContext context,
906+
boolean hintBucketNonNull) {
905907
Pair<List<RexNode>, List<AggCall>> resolved =
906908
resolveAttributesForAggregation(groupExprList, aggExprList, context);
907909
List<RexNode> resolvedGroupByList = resolved.getLeft();
@@ -1005,6 +1007,7 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
10051007
List<String> intendedGroupKeyAliases = getGroupKeyNamesAfterAggregation(reResolved.getLeft());
10061008
context.relBuilder.aggregate(
10071009
context.relBuilder.groupKey(reResolved.getLeft()), reResolved.getRight());
1010+
if (hintBucketNonNull) hintBucketNonNullOnAggregate(context.relBuilder);
10081011
// During aggregation, Calcite projects both input dependencies and output group-by fields.
10091012
// When names conflict, Calcite adds numeric suffixes (e.g., "value0").
10101013
// Apply explicit renaming to restore the intended aliases.
@@ -1013,6 +1016,24 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
10131016
return Pair.of(reResolved.getLeft(), reResolved.getRight());
10141017
}
10151018

1019+
private void hintBucketNonNullOnAggregate(RelBuilder relBuilder) {
1020+
final RelHint statHits =
1021+
RelHint.builder("stats_args").hintOption(Argument.BUCKET_NULLABLE, "false").build();
1022+
assert relBuilder.peek() instanceof LogicalAggregate
1023+
: "Stats hits should be added to LogicalAggregate";
1024+
relBuilder.hints(statHits);
1025+
relBuilder
1026+
.getCluster()
1027+
.setHintStrategies(
1028+
HintStrategyTable.builder()
1029+
.hintStrategy(
1030+
"stats_args",
1031+
(hint, rel) -> {
1032+
return rel instanceof LogicalAggregate;
1033+
})
1034+
.build());
1035+
}
1036+
10161037
/**
10171038
* Imitates {@code Registrar.registerExpression} of {@link RelBuilder} to derive the output order
10181039
* of group-by keys after aggregation.
@@ -1114,25 +1135,7 @@ public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
11141135
}
11151136

11161137
Pair<List<RexNode>, List<AggCall>> aggregationAttributes =
1117-
aggregateWithTrimming(groupExprList, aggExprList, context);
1118-
if (toAddHintsOnAggregate) {
1119-
final RelHint statHits =
1120-
RelHint.builder("stats_args").hintOption(Argument.BUCKET_NULLABLE, "false").build();
1121-
assert context.relBuilder.peek() instanceof LogicalAggregate
1122-
: "Stats hits should be added to LogicalAggregate";
1123-
context.relBuilder.hints(statHits);
1124-
context
1125-
.relBuilder
1126-
.getCluster()
1127-
.setHintStrategies(
1128-
HintStrategyTable.builder()
1129-
.hintStrategy(
1130-
"stats_args",
1131-
(hint, rel) -> {
1132-
return rel instanceof LogicalAggregate;
1133-
})
1134-
.build());
1135-
}
1138+
aggregateWithTrimming(groupExprList, aggExprList, context, toAddHintsOnAggregate);
11361139

11371140
// schema reordering
11381141
// As an example, in command `stats count() by colA, colB`,
@@ -1869,7 +1872,7 @@ public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) {
18691872
groupExprList.addAll(fieldList);
18701873
List<UnresolvedExpression> aggExprList =
18711874
List.of(AstDSL.alias(countFieldName, AstDSL.aggregate("count", null)));
1872-
aggregateWithTrimming(groupExprList, aggExprList, context);
1875+
aggregateWithTrimming(groupExprList, aggExprList, context, false);
18731876

18741877
// 2. add a window column
18751878
List<RexNode> partitionKeys = rexVisitor.analyze(node.getGroupExprList(), context);
@@ -2193,7 +2196,7 @@ public RelNode visitTimechart(
21932196
try {
21942197
// Step 1: Initial aggregation - IMPORTANT: order is [spanExpr, byField]
21952198
groupExprList = Arrays.asList(spanExpr, byField);
2196-
aggregateWithTrimming(groupExprList, List.of(node.getAggregateFunction()), context);
2199+
aggregateWithTrimming(groupExprList, List.of(node.getAggregateFunction()), context, false);
21972200

21982201
// First rename the timestamp field (2nd to last) to @timestamp
21992202
List<String> fieldNames = context.relBuilder.peek().getRowType().getFieldNames();

0 commit comments

Comments
 (0)