Skip to content

Commit 1e07c71

Browse files
committed
Rename fields to intended ones after aggregation
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent cba8d02 commit 1e07c71

2 files changed

Lines changed: 119 additions & 2 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -997,13 +997,49 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
997997
// because that Mapping only works for RexNode, but we need both AggCall and RexNode list.
998998
Pair<List<RexNode>, List<AggCall>> reResolved =
999999
resolveAttributesForAggregation(groupExprList, aggExprList, context);
1000-
1000+
List<String> names = getGroupKeyNamesAfterAggregation(reResolved.getLeft());
10011001
context.relBuilder.aggregate(
10021002
context.relBuilder.groupKey(reResolved.getLeft()), reResolved.getRight());
1003+
// During aggregation, Calcite projects both input dependencies and output group-by fields.
1004+
// When names conflict, Calcite adds numeric suffixes (e.g., "value0").
1005+
// Apply explicit renaming to restore the intended aliases.
1006+
context.relBuilder.rename(names);
10031007

10041008
return Pair.of(reResolved.getLeft(), reResolved.getRight());
10051009
}
10061010

1011+
private List<String> getGroupKeyNamesAfterAggregation(List<RexNode> nodes) {
1012+
List<RexNode> reordered = new ArrayList<>();
1013+
List<RexNode> left = new ArrayList<>();
1014+
for (RexNode n : nodes) {
1015+
if (isInputRef(n)) {
1016+
reordered.add(n);
1017+
} else {
1018+
left.add(n);
1019+
}
1020+
}
1021+
reordered.addAll(left);
1022+
return reordered.stream()
1023+
.map(this::extractAliasLiteral)
1024+
.flatMap(Optional::stream)
1025+
.map(RexLiteral::stringValue)
1026+
.toList();
1027+
}
1028+
1029+
/**
1030+
* Immitates registerExpression of {@link RelBuilder.Registrar} to derive the output order of
1031+
* group-by keys after aggregation
1032+
*/
1033+
private boolean isInputRef(RexNode node) {
1034+
return switch (node.getKind()) {
1035+
case AS, DESCENDING, NULLS_FIRST, NULLS_LAST -> {
1036+
final List<RexNode> operands = ((RexCall) node).operands;
1037+
yield isInputRef(operands.getFirst());
1038+
}
1039+
default -> node instanceof RexInputRef;
1040+
};
1041+
}
1042+
10071043
/**
10081044
* Resolve attributes for aggregation.
10091045
*
@@ -1102,7 +1138,7 @@ public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
11021138
aggregationAttributes.getLeft().stream()
11031139
.map(this::extractAliasLiteral)
11041140
.flatMap(Optional::stream)
1105-
.map(ref -> ((RexLiteral) ref).getValueAs(String.class))
1141+
.map(ref -> ref.getValueAs(String.class))
11061142
.map(context.relBuilder::field)
11071143
.map(f -> (RexNode) f)
11081144
.toList();
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
setup:
2+
- do:
3+
indices.create:
4+
index: time_test
5+
- do:
6+
query.settings:
7+
body:
8+
transient:
9+
plugins.calcite.enabled : true
10+
11+
- do:
12+
bulk:
13+
refresh: true
14+
body:
15+
- '{"index": {"_index": "time_test"}}'
16+
- '{"category":"A","value":1000,"@timestamp":"2024-01-01T00:00:00Z"}'
17+
- '{"index": {"_index": "time_test"}}'
18+
- '{"category":"B","value":2000,"@timestamp":"2024-01-01T00:05:00Z"}'
19+
- '{"index": {"_index": "time_test"}}'
20+
- '{"category":"A","value":1500,"@timestamp":"2024-01-01T00:10:00Z"}'
21+
- '{"index": {"_index": "time_test"}}'
22+
- '{"category":"C","value":3000,"@timestamp":"2024-01-01T00:20:00Z"}'
23+
24+
---
25+
teardown:
26+
- do:
27+
query.settings:
28+
body:
29+
transient:
30+
plugins.calcite.enabled : false
31+
32+
---
33+
"Test span aggregation with field name collision - basic case":
34+
- skip:
35+
features:
36+
- headers
37+
- allowed_warnings
38+
- do:
39+
headers:
40+
Content-Type: 'application/json'
41+
ppl:
42+
body:
43+
query: source=time_test | stats count() by span(value, 1000) as value
44+
45+
- match: { total: 3 }
46+
- match: { schema: [{"name": "count()", "type": "bigint"}, {"name": "value", "type": "bigint"}] }
47+
- match: { datarows: [[2, 1000], [1, 2000], [1, 3000]] }
48+
49+
---
50+
"Test span aggregation with field name collision - multiple aggregations":
51+
- skip:
52+
features:
53+
- headers
54+
- allowed_warnings
55+
- do:
56+
headers:
57+
Content-Type: 'application/json'
58+
ppl:
59+
body:
60+
query: source=time_test | stats count(), avg(value) by span(value, 1000) as value
61+
62+
- match: { total: 3 }
63+
- match: { schema: [{"name": "count()", "type": "bigint"}, {"name": "avg(value)", "type": "double"}, {"name": "value", "type": "bigint"}] }
64+
- match: { datarows: [[2, 1250.0, 1000], [1, 2000.0, 2000], [1, 3000.0, 3000]] }
65+
66+
---
67+
"Test span aggregation without name collision - multiple group-by":
68+
- skip:
69+
features:
70+
- headers
71+
- allowed_warnings
72+
- do:
73+
headers:
74+
Content-Type: 'application/json'
75+
ppl:
76+
body:
77+
query: source=time_test | stats count() by span(@timestamp, 10min) as @timestamp, category, value
78+
79+
- match: { total: 4 }
80+
- match: { schema: [{"name": "count()", "type": "bigint"}, {"name": "@timestamp", "type": "timestamp"}, {"name": "category", "type": "string"}, {"name": "value", "type": "bigint"}] }
81+
- match: { datarows: [[1, "2024-01-01 00:00:00", "A", 1000], [1, "2024-01-01 00:10:00", "A", 1500], [1, "2024-01-01 00:00:00", "B", 2000], [1, "2024-01-01 00:20:00", "C", 3000]] }

0 commit comments

Comments
 (0)