Skip to content

Commit 4d3875d

Browse files
committed
Handle common agg functions for OTHER category for timechart
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 0257aa5 commit 4d3875d

7 files changed

Lines changed: 266 additions & 88 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 82 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1926,7 +1926,7 @@ public RelNode visitFlatten(Flatten node, CalcitePlanContext context) {
19261926
}
19271927

19281928
/** Helper method to get the function name for proper column naming */
1929-
private String getValueFunctionName(UnresolvedExpression aggregateFunction) {
1929+
private String getAggFieldAlias(UnresolvedExpression aggregateFunction) {
19301930
if (aggregateFunction instanceof Alias) {
19311931
return ((Alias) aggregateFunction).getName();
19321932
}
@@ -1976,15 +1976,15 @@ public RelNode visitTimechart(
19761976

19771977
// Handle no by field case
19781978
if (node.getByField() == null) {
1979-
String valueFunctionName = getValueFunctionName(node.getAggregateFunction());
1979+
String aggFieldAlias = getAggFieldAlias(node.getAggregateFunction());
19801980

19811981
// Create group expression list with just the timestamp span but use a different alias
19821982
// to avoid @timestamp naming conflict
19831983
List<UnresolvedExpression> simpleGroupExprList = new ArrayList<>();
19841984
simpleGroupExprList.add(new Alias("timestamp", spanExpr));
19851985
// Create agg expression list with the aggregate function
19861986
List<UnresolvedExpression> simpleAggExprList =
1987-
List.of(new Alias(valueFunctionName, node.getAggregateFunction()));
1987+
List.of(new Alias(aggFieldAlias, node.getAggregateFunction()));
19881988
// Create an Aggregation object
19891989
Aggregation aggregation =
19901990
new Aggregation(
@@ -1999,9 +1999,9 @@ public RelNode visitTimechart(
19991999
context.relBuilder.push(result);
20002000
// Reorder fields: timestamp first, then the aggregated value
20012001
context.relBuilder.project(
2002-
context.relBuilder.field("timestamp"), context.relBuilder.field(valueFunctionName));
2002+
context.relBuilder.field("timestamp"), context.relBuilder.field(aggFieldAlias));
20032003
// Rename timestamp to @timestamp
2004-
context.relBuilder.rename(List.of("@timestamp", valueFunctionName));
2004+
context.relBuilder.rename(List.of("@timestamp", aggFieldAlias));
20052005

20062006
context.relBuilder.sort(context.relBuilder.field(0));
20072007
return context.relBuilder.peek();
@@ -2010,7 +2010,7 @@ public RelNode visitTimechart(
20102010
// Extract parameters for byField case
20112011
UnresolvedExpression byField = node.getByField();
20122012
String byFieldName = ((Field) byField).getField().toString();
2013-
String valueFunctionName = getValueFunctionName(node.getAggregateFunction());
2013+
String aggFieldAlias = getAggFieldAlias(node.getAggregateFunction());
20142014

20152015
int limit = Optional.ofNullable(node.getLimit()).orElse(10);
20162016
boolean useOther = Optional.ofNullable(node.getUseOther()).orElse(true);
@@ -2037,11 +2037,11 @@ public RelNode visitTimechart(
20372037

20382038
// Handle no limit case - just sort and return with proper field aliases
20392039
if (limit == 0) {
2040-
// Add final projection with proper aliases: [@timestamp, byField, valueFunctionName]
2040+
// Add final projection with proper aliases: [@timestamp, byField, aggFieldAlias]
20412041
context.relBuilder.project(
20422042
context.relBuilder.alias(context.relBuilder.field(0), "@timestamp"),
20432043
context.relBuilder.alias(context.relBuilder.field(1), byFieldName),
2044-
context.relBuilder.alias(context.relBuilder.field(2), valueFunctionName));
2044+
context.relBuilder.alias(context.relBuilder.field(2), aggFieldAlias));
20452045
context.relBuilder.sort(context.relBuilder.field(0), context.relBuilder.field(1));
20462046
return context.relBuilder.peek();
20472047
}
@@ -2051,32 +2051,61 @@ public RelNode visitTimechart(
20512051

20522052
// Step 2: Find top N categories by aggregating each category's values, then sorting and
20532053
// limiting the result
2054-
RelNode topCategories = buildTopCategoriesQuery(completeResults, limit, context);
2054+
String aggFunctionName = getAggFunctionName(node.getAggregateFunction());
2055+
Optional<BuiltinFunctionName> aggFuncNameOptional = BuiltinFunctionName.of(aggFunctionName);
2056+
if (aggFuncNameOptional.isEmpty()) {
2057+
throw new IllegalArgumentException(
2058+
StringUtils.format("Unrecognized aggregation function: %s", aggFunctionName));
2059+
}
2060+
BuiltinFunctionName aggFunction = aggFuncNameOptional.get();
2061+
RelNode topCategories = buildTopCategoriesQuery(completeResults, limit, aggFunction, context);
20552062

20562063
// Step 3: Apply OTHER logic with single pass
20572064
return buildFinalResultWithOther(
2058-
completeResults, topCategories, byFieldName, valueFunctionName, useOther, limit, context);
2065+
completeResults,
2066+
topCategories,
2067+
byFieldName,
2068+
aggFunction,
2069+
aggFieldAlias,
2070+
useOther,
2071+
limit,
2072+
context);
20592073

20602074
} catch (Exception e) {
20612075
throw new RuntimeException("Error in visitTimechart: " + e.getMessage(), e);
20622076
}
20632077
}
20642078

2079+
private String getAggFunctionName(UnresolvedExpression aggregateFunction) {
2080+
if (aggregateFunction instanceof Alias alias) {
2081+
return getAggFunctionName(alias.getDelegated());
2082+
}
2083+
return ((AggregateFunction) aggregateFunction).getFuncName();
2084+
}
2085+
20652086
/** Build top categories query - simpler approach that works better with OTHER handling */
20662087
private RelNode buildTopCategoriesQuery(
2067-
RelNode completeResults, int limit, CalcitePlanContext context) {
2088+
RelNode completeResults,
2089+
int limit,
2090+
BuiltinFunctionName aggFunction,
2091+
CalcitePlanContext context) {
20682092
context.relBuilder.push(completeResults);
20692093

20702094
// Filter out null values when determining top categories - null should not count towards limit
20712095
context.relBuilder.filter(context.relBuilder.isNotNull(context.relBuilder.field(1)));
20722096

20732097
// Get totals for non-null categories - field positions: 0=@timestamp, 1=byField, 2=value
2098+
RexInputRef valueField = context.relBuilder.field(2);
2099+
AggCall call = buildAggCall(context.relBuilder, aggFunction, valueField);
2100+
20742101
context.relBuilder.aggregate(
2075-
context.relBuilder.groupKey(context.relBuilder.field(1)),
2076-
context.relBuilder.sum(context.relBuilder.field(2)).as("grand_total"));
2102+
context.relBuilder.groupKey(context.relBuilder.field(1)), call.as("grand_total"));
20772103

20782104
// Apply sorting and limit to non-null categories only
2079-
context.relBuilder.sort(context.relBuilder.desc(context.relBuilder.field("grand_total")));
2105+
RexNode sortField = context.relBuilder.field("grand_total");
2106+
sortField =
2107+
aggFunction == BuiltinFunctionName.MIN ? sortField : context.relBuilder.desc(sortField);
2108+
context.relBuilder.sort(sortField);
20802109
if (limit > 0) {
20812110
context.relBuilder.limit(0, limit);
20822111
}
@@ -2089,18 +2118,25 @@ private RelNode buildFinalResultWithOther(
20892118
RelNode completeResults,
20902119
RelNode topCategories,
20912120
String byFieldName,
2092-
String valueFunctionName,
2121+
BuiltinFunctionName aggFunction,
2122+
String aggFieldAlias,
20932123
boolean useOther,
20942124
int limit,
20952125
CalcitePlanContext context) {
20962126

20972127
// Use zero-filling for count aggregations, standard result for others
2098-
if (valueFunctionName.equals("count")) {
2128+
if (aggFieldAlias.equals("count")) {
20992129
return buildZeroFilledResult(
2100-
completeResults, topCategories, byFieldName, valueFunctionName, useOther, limit, context);
2130+
completeResults, topCategories, byFieldName, aggFieldAlias, useOther, limit, context);
21012131
} else {
21022132
return buildStandardResult(
2103-
completeResults, topCategories, byFieldName, valueFunctionName, useOther, context);
2133+
completeResults,
2134+
topCategories,
2135+
byFieldName,
2136+
aggFunction,
2137+
aggFieldAlias,
2138+
useOther,
2139+
context);
21042140
}
21052141
}
21062142

@@ -2109,7 +2145,8 @@ private RelNode buildStandardResult(
21092145
RelNode completeResults,
21102146
RelNode topCategories,
21112147
String byFieldName,
2112-
String valueFunctionName,
2148+
BuiltinFunctionName aggFunctionName,
2149+
String aggFieldAlias,
21132150
boolean useOther,
21142151
CalcitePlanContext context) {
21152152

@@ -2132,11 +2169,13 @@ private RelNode buildStandardResult(
21322169
context.relBuilder.project(
21332170
context.relBuilder.alias(context.relBuilder.field(0), "@timestamp"),
21342171
context.relBuilder.alias(categoryExpr, byFieldName),
2135-
context.relBuilder.alias(context.relBuilder.field(2), valueFunctionName));
2172+
context.relBuilder.alias(context.relBuilder.field(2), aggFieldAlias));
21362173

2174+
RexInputRef valueField = context.relBuilder.field(2);
2175+
AggCall aggCall = buildAggCall(context.relBuilder, aggFunctionName, valueField);
21372176
context.relBuilder.aggregate(
21382177
context.relBuilder.groupKey(context.relBuilder.field(0), context.relBuilder.field(1)),
2139-
context.relBuilder.sum(context.relBuilder.field(2)).as(valueFunctionName));
2178+
aggCall.as(aggFieldAlias));
21402179

21412180
applyFiltersAndSort(useOther, context);
21422181
return context.relBuilder.peek();
@@ -2171,7 +2210,7 @@ private RelNode buildZeroFilledResult(
21712210
RelNode completeResults,
21722211
RelNode topCategories,
21732212
String byFieldName,
2174-
String valueFunctionName,
2213+
String aggFieldAlias,
21752214
boolean useOther,
21762215
int limit,
21772216
CalcitePlanContext context) {
@@ -2210,7 +2249,7 @@ private RelNode buildZeroFilledResult(
22102249
context.relBuilder.cast(context.relBuilder.field(0), SqlTypeName.TIMESTAMP),
22112250
"@timestamp"),
22122251
context.relBuilder.alias(context.relBuilder.field(1), byFieldName),
2213-
context.relBuilder.alias(context.relBuilder.literal(0), valueFunctionName));
2252+
context.relBuilder.alias(context.relBuilder.literal(0), aggFieldAlias));
22142253
RelNode zeroFilledCombinations = context.relBuilder.build();
22152254

22162255
// Get actual results with OTHER logic applied
@@ -2232,7 +2271,7 @@ private RelNode buildZeroFilledResult(
22322271
context.relBuilder.cast(context.relBuilder.field(0), SqlTypeName.TIMESTAMP),
22332272
"@timestamp"),
22342273
context.relBuilder.alias(actualCategoryExpr, byFieldName),
2235-
context.relBuilder.alias(context.relBuilder.field(2), valueFunctionName));
2274+
context.relBuilder.alias(context.relBuilder.field(2), aggFieldAlias));
22362275

22372276
context.relBuilder.aggregate(
22382277
context.relBuilder.groupKey(context.relBuilder.field(0), context.relBuilder.field(1)),
@@ -2247,12 +2286,30 @@ private RelNode buildZeroFilledResult(
22472286
// Aggregate to combine actual and zero-filled data
22482287
context.relBuilder.aggregate(
22492288
context.relBuilder.groupKey(context.relBuilder.field(0), context.relBuilder.field(1)),
2250-
context.relBuilder.sum(context.relBuilder.field(2)).as(valueFunctionName));
2289+
context.relBuilder.sum(context.relBuilder.field(2)).as(aggFieldAlias));
22512290

22522291
applyFiltersAndSort(useOther, context);
22532292
return context.relBuilder.peek();
22542293
}
22552294

2295+
/**
2296+
* Aggregate a field based on a given built-in aggregation function name.
2297+
*
2298+
* <p>It is intended for secondary aggregations in timechart and chart commands. Using it
2299+
* elsewhere may lead to unintended results. It explicitly handles only MIN, MAX, AVG,
2300+
* EARLIEST, and LATEST. All remaining aggregation types, including COUNT and DISTINCT_COUNT,
2301+
* are summed on the assumption that they are accumulative.
2302+
*/
2303+
private AggCall buildAggCall(
2304+
RelBuilder relBuilder, BuiltinFunctionName aggFunction, RexNode node) {
2305+
return switch (aggFunction) {
2306+
case MIN, EARLIEST -> relBuilder.min(node);
2307+
case MAX, LATEST -> relBuilder.max(node);
2308+
case AVG -> relBuilder.avg(node);
2309+
default -> relBuilder.sum(node);
2310+
};
2311+
}
2312+
22562313
@Override
22572314
public RelNode visitTrendline(Trendline node, CalcitePlanContext context) {
22582315
visitChildren(node, context);

integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -422,10 +422,7 @@ public void testExplainWithReverse() throws IOException {
422422
@Test
423423
public void testExplainWithTimechartAvg() throws IOException {
424424
var result = explainQueryToString("source=events | timechart span=1m avg(cpu_usage) by host");
425-
String expected =
426-
!isPushdownDisabled()
427-
? loadFromFile("expectedOutput/calcite/explain_timechart.yaml")
428-
: loadFromFile("expectedOutput/calcite/explain_timechart_no_pushdown.yaml");
425+
String expected = loadExpectedPlan("explain_timechart.yaml");
429426
assertYamlEqualsJsonIgnoreId(expected, result);
430427
}
431428

integ-test/src/test/resources/expectedOutput/calcite/explain_timechart.yaml

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@ calcite:
22
logical: |
33
LogicalSystemLimit(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[ASC], fetch=[10000], type=[QUERY_SIZE_LIMIT])
44
LogicalSort(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[ASC])
5-
LogicalAggregate(group=[{0, 1}], avg(cpu_usage)=[SUM($2)])
5+
LogicalAggregate(group=[{0, 1}], avg(cpu_usage)=[AVG($2)])
66
LogicalProject(@timestamp=[$0], host=[CASE(IS NOT NULL($3), $1, CASE(IS NULL($1), null:NULL, 'OTHER'))], avg(cpu_usage)=[$2])
77
LogicalJoin(condition=[=($1, $3)], joinType=[left])
88
LogicalProject(@timestamp=[$1], host=[$0], $f2=[$2])
99
LogicalAggregate(group=[{0, 2}], agg#0=[AVG($1)])
1010
LogicalProject(host=[$4], cpu_usage=[$7], $f3=[SPAN($1, 1, 'm')])
1111
CalciteLogicalIndexScan(table=[[OpenSearch, events]])
1212
LogicalSort(sort0=[$1], dir0=[DESC], fetch=[10])
13-
LogicalAggregate(group=[{1}], grand_total=[SUM($2)])
13+
LogicalAggregate(group=[{1}], grand_total=[AVG($2)])
1414
LogicalFilter(condition=[IS NOT NULL($1)])
1515
LogicalProject(@timestamp=[$1], host=[$0], $f2=[$2])
1616
LogicalAggregate(group=[{0, 2}], agg#0=[AVG($1)])
@@ -19,19 +19,21 @@ calcite:
1919
physical: |
2020
EnumerableLimit(fetch=[10000])
2121
EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[ASC])
22-
EnumerableAggregate(group=[{0, 1}], avg(cpu_usage)=[SUM($2)])
23-
EnumerableCalc(expr#0..4=[{inputs}], expr#5=[IS NOT NULL($t3)], expr#6=[IS NULL($t1)], expr#7=[null:NULL], expr#8=['OTHER'], expr#9=[CASE($t6, $t7, $t8)], expr#10=[CASE($t5, $t1, $t9)], @timestamp=[$t0], host=[$t10], avg(cpu_usage)=[$t2])
24-
EnumerableMergeJoin(condition=[=($1, $3)], joinType=[left])
25-
EnumerableSort(sort0=[$1], dir0=[ASC])
26-
EnumerableCalc(expr#0..3=[{inputs}], expr#4=[0], expr#5=[=($t3, $t4)], expr#6=[null:DOUBLE], expr#7=[CASE($t5, $t6, $t2)], expr#8=[/($t7, $t3)], @timestamp=[$t1], host=[$t0], $f2=[$t8])
27-
EnumerableAggregate(group=[{0, 2}], agg#0=[$SUM0($1)], agg#1=[COUNT($1)])
28-
EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1], expr#4=['m'], expr#5=[SPAN($t2, $t3, $t4)], proj#0..1=[{exprs}], $f2=[$t5])
29-
CalciteEnumerableIndexScan(table=[[OpenSearch, events]], PushDownContext=[[PROJECT->[host, cpu_usage, @timestamp]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"timeout":"1m","_source":{"includes":["host","cpu_usage","@timestamp"],"excludes":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
30-
EnumerableSort(sort0=[$0], dir0=[ASC])
31-
EnumerableLimit(fetch=[10])
32-
EnumerableSort(sort0=[$1], dir0=[DESC])
33-
EnumerableAggregate(group=[{0}], grand_total=[SUM($1)])
34-
EnumerableCalc(expr#0..3=[{inputs}], expr#4=[0], expr#5=[=($t3, $t4)], expr#6=[null:DOUBLE], expr#7=[CASE($t5, $t6, $t2)], expr#8=[/($t7, $t3)], host=[$t0], $f2=[$t8])
35-
EnumerableAggregate(group=[{0, 2}], agg#0=[$SUM0($1)], agg#1=[COUNT($1)])
36-
EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1], expr#4=['m'], expr#5=[SPAN($t2, $t3, $t4)], proj#0..1=[{exprs}], $f2=[$t5])
37-
CalciteEnumerableIndexScan(table=[[OpenSearch, events]], PushDownContext=[[PROJECT->[host, cpu_usage, @timestamp], FILTER->IS NOT NULL($0)], OpenSearchRequestBuilder(sourceBuilder={"from":0,"timeout":"1m","query":{"exists":{"field":"host","boost":1.0}},"_source":{"includes":["host","cpu_usage","@timestamp"],"excludes":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
22+
EnumerableCalc(expr#0..3=[{inputs}], expr#4=[0], expr#5=[=($t3, $t4)], expr#6=[null:DOUBLE], expr#7=[CASE($t5, $t6, $t2)], expr#8=[/($t7, $t3)], proj#0..1=[{exprs}], avg(cpu_usage)=[$t8])
23+
EnumerableAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], agg#1=[COUNT($2)])
24+
EnumerableCalc(expr#0..4=[{inputs}], expr#5=[IS NOT NULL($t3)], expr#6=[IS NULL($t1)], expr#7=[null:NULL], expr#8=['OTHER'], expr#9=[CASE($t6, $t7, $t8)], expr#10=[CASE($t5, $t1, $t9)], @timestamp=[$t0], host=[$t10], avg(cpu_usage)=[$t2])
25+
EnumerableMergeJoin(condition=[=($1, $3)], joinType=[left])
26+
EnumerableSort(sort0=[$1], dir0=[ASC])
27+
EnumerableCalc(expr#0..3=[{inputs}], expr#4=[0], expr#5=[=($t3, $t4)], expr#6=[null:DOUBLE], expr#7=[CASE($t5, $t6, $t2)], expr#8=[/($t7, $t3)], @timestamp=[$t1], host=[$t0], $f2=[$t8])
28+
EnumerableAggregate(group=[{0, 2}], agg#0=[$SUM0($1)], agg#1=[COUNT($1)])
29+
EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1], expr#4=['m'], expr#5=[SPAN($t2, $t3, $t4)], proj#0..1=[{exprs}], $f2=[$t5])
30+
CalciteEnumerableIndexScan(table=[[OpenSearch, events]], PushDownContext=[[PROJECT->[host, cpu_usage, @timestamp]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"timeout":"1m","_source":{"includes":["host","cpu_usage","@timestamp"],"excludes":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
31+
EnumerableSort(sort0=[$0], dir0=[ASC])
32+
EnumerableLimit(fetch=[10])
33+
EnumerableSort(sort0=[$1], dir0=[DESC])
34+
EnumerableCalc(expr#0..2=[{inputs}], expr#3=[0], expr#4=[=($t2, $t3)], expr#5=[null:DOUBLE], expr#6=[CASE($t4, $t5, $t1)], expr#7=[/($t6, $t2)], host=[$t0], grand_total=[$t7])
35+
EnumerableAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($1)])
36+
EnumerableCalc(expr#0..3=[{inputs}], expr#4=[0], expr#5=[=($t3, $t4)], expr#6=[null:DOUBLE], expr#7=[CASE($t5, $t6, $t2)], expr#8=[/($t7, $t3)], host=[$t0], $f2=[$t8])
37+
EnumerableAggregate(group=[{0, 2}], agg#0=[$SUM0($1)], agg#1=[COUNT($1)])
38+
EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1], expr#4=['m'], expr#5=[SPAN($t2, $t3, $t4)], proj#0..1=[{exprs}], $f2=[$t5])
39+
CalciteEnumerableIndexScan(table=[[OpenSearch, events]], PushDownContext=[[PROJECT->[host, cpu_usage, @timestamp], FILTER->IS NOT NULL($0)], OpenSearchRequestBuilder(sourceBuilder={"from":0,"timeout":"1m","query":{"exists":{"field":"host","boost":1.0}},"_source":{"includes":["host","cpu_usage","@timestamp"],"excludes":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])

0 commit comments

Comments
 (0)