Skip to content

Commit a57fee6

Browse files
committed
Fix min/earliest order & fix non-accumulative agg for chart
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 02a37ec commit a57fee6

2 files changed

Lines changed: 91 additions & 33 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2013,37 +2013,35 @@ private String getAggFieldAlias(UnresolvedExpression aggregateFunction) {
20132013
public RelNode visitChart(Chart node, CalcitePlanContext context) {
20142014
visitChildren(node, context);
20152015
ArgumentMap argMap = ArgumentMap.of(node.getArguments());
2016-
List<UnresolvedExpression> groupExprList = new ArrayList<>();
2017-
UnresolvedExpression span;
2018-
if (node.getColumnSplit() instanceof Span && node.getRowSplit() instanceof Span) {
2019-
throw new UnsupportedOperationException("It is not supported to have two span splits");
2020-
} else if (node.getRowSplit() instanceof Span) {
2021-
if (node.getColumnSplit() != null) {
2022-
groupExprList.add(node.getColumnSplit());
2023-
}
2024-
span = node.getRowSplit();
2025-
} else if (node.getColumnSplit() instanceof Span) {
2026-
if (node.getRowSplit() != null) {
2027-
groupExprList.add(node.getRowSplit());
2028-
}
2029-
span = node.getColumnSplit();
2030-
} else {
2031-
groupExprList.addAll(
2032-
Stream.of(node.getRowSplit(), node.getColumnSplit()).filter(Objects::nonNull).toList());
2033-
span = null;
2034-
}
2016+
List<UnresolvedExpression> groupExprList =
2017+
Stream.of(node.getRowSplit(), node.getColumnSplit()).filter(Objects::nonNull).toList();
20352018
Boolean useNull = (Boolean) argMap.getOrDefault("usenull", Chart.DEFAULT_USE_NULL).getValue();
20362019
Aggregation aggregation =
20372020
new Aggregation(
20382021
node.getAggregationFunctions(),
20392022
List.of(),
20402023
groupExprList,
2041-
span,
2024+
null,
20422025
List.of(new Argument(Argument.BUCKET_NULLABLE, AstDSL.booleanLiteral(useNull))));
2043-
visitAggregation(aggregation, context);
2026+
RelNode aggregated = visitAggregation(aggregation, context);
2027+
2028+
// If row or column split does not present or limit equals 0, this is the same as `stats agg
2029+
// [group by col]`
2030+
Integer limit = (Integer) argMap.getOrDefault("limit", Chart.DEFAULT_LIMIT).getValue();
2031+
if (node.getRowSplit() == null || node.getColumnSplit() == null || Objects.equals(limit, 0)) {
2032+
return aggregated;
2033+
}
2034+
2035+
String aggFunctionName = getAggFunctionName(node.getAggregationFunctions().getFirst());
2036+
Optional<BuiltinFunctionName> aggFuncNameOptional = BuiltinFunctionName.of(aggFunctionName);
2037+
if (aggFuncNameOptional.isEmpty()) {
2038+
throw new IllegalArgumentException(
2039+
StringUtils.format("Unrecognized aggregation function: %s", aggFunctionName));
2040+
}
2041+
BuiltinFunctionName aggFunction = aggFuncNameOptional.get();
20442042

20452043
// Convert the column split to string if necessary: column split was supposed to be pivoted to
2046-
// column names. This guarantees that its type being compatible with useother and usenull
2044+
// column names. This guarantees that its type compatibility with useother and usenull
20472045
RelBuilder relBuilder = context.relBuilder;
20482046
RexNode colSplit = relBuilder.field(2);
20492047
String columSplitName = relBuilder.peek().getRowType().getFieldNames().getLast();
@@ -2055,14 +2053,7 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
20552053
columSplitName);
20562054
}
20572055
relBuilder.project(relBuilder.field(0), relBuilder.field(1), colSplit);
2058-
RelNode aggregated = relBuilder.peek();
2059-
2060-
// If row or column split does not present or limit equals 0, this is the same as `stats agg
2061-
// [group by col]`
2062-
Integer limit = (Integer) argMap.getOrDefault("limit", Chart.DEFAULT_LIMIT).getValue();
2063-
if (node.getRowSplit() == null || node.getColumnSplit() == null || Objects.equals(limit, 0)) {
2064-
return aggregated;
2065-
}
2056+
aggregated = relBuilder.peek();
20662057

20672058
Boolean top = (Boolean) argMap.getOrDefault("top", Chart.DEFAULT_TOP).getValue();
20682059
Boolean useOther =
@@ -2075,11 +2066,16 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
20752066
// 1: column split; 0: agg
20762067
relBuilder.aggregate(
20772068
relBuilder.groupKey(relBuilder.field(1)),
2078-
relBuilder.sum(relBuilder.field(0)).as("__grand_total__")); // results: group key, agg calls
2069+
buildAggCall(context.relBuilder, aggFunction, relBuilder.field(0))
2070+
.as("__grand_total__")); // results: group key, agg calls
20792071
RexNode grandTotal = relBuilder.field("__grand_total__");
2080-
if (top) {
2072+
// Apply sorting: for MIN/EARLIEST, reverse the top/bottom logic
2073+
boolean smallestFirst =
2074+
aggFunction == BuiltinFunctionName.MIN || aggFunction == BuiltinFunctionName.EARLIEST;
2075+
if (top != smallestFirst) {
20812076
grandTotal = relBuilder.desc(grandTotal);
20822077
}
2078+
20832079
// Always set it to null last so that it does not interfere with top / bottom calculation
20842080
grandTotal = relBuilder.nullsLast(grandTotal);
20852081
RexNode rowNum =
@@ -2138,7 +2134,7 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
21382134
relBuilder.alias(columnSplitExpr, columSplitName));
21392135
relBuilder.aggregate(
21402136
relBuilder.groupKey(relBuilder.field(1), relBuilder.field(2)),
2141-
relBuilder.sum(relBuilder.field(0)).as(aggFieldName));
2137+
buildAggCall(context.relBuilder, aggFunction, relBuilder.field(0)).as(aggFieldName));
21422138
return relBuilder.peek();
21432139
}
21442140

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite.remote;
7+
8+
import org.json.JSONObject;
9+
import org.junit.jupiter.api.Test;
10+
import org.opensearch.sql.ppl.PPLIntegTestCase;
11+
12+
import java.io.IOException;
13+
14+
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK;
15+
import static org.opensearch.sql.util.MatcherUtils.assertJsonEquals;
16+
import static org.opensearch.sql.util.MatcherUtils.rows;
17+
import static org.opensearch.sql.util.MatcherUtils.schema;
18+
import static org.opensearch.sql.util.MatcherUtils.verifyDataRows;
19+
import static org.opensearch.sql.util.MatcherUtils.verifySchema;
20+
21+
public class CalciteChartCommandIT extends PPLIntegTestCase {
22+
@Override
23+
protected void init() throws Exception {
24+
super.init();
25+
enableCalcite();
26+
loadIndex(Index.BANK);
27+
loadIndex(Index.BANK_WITH_NULL_VALUES);
28+
loadIndex(Index.OTELLOGS);
29+
}
30+
31+
@Test
32+
public void testChartWithSingleGroupKey() throws IOException {
33+
JSONObject result1 = executeQuery(String.format("source=%s | chart avg(balance) by gender", TEST_INDEX_BANK));
34+
verifySchema(
35+
result1,
36+
schema("avg(balance)", "double"),
37+
schema("gender", "string"));
38+
verifyDataRows(result1, rows(40488, "F"), rows(16377.25, "M"));
39+
JSONObject result2 = executeQuery(String.format("source=%s | chart avg(balance) over gender", TEST_INDEX_BANK));
40+
assertJsonEquals(result1.toString(), result2.toString());
41+
}
42+
43+
@Test
44+
public void testChartWithMultipleGroupKeys() throws IOException {
45+
JSONObject result1 = executeQuery(String.format("source=%s | chart avg(balance) by gender, age", TEST_INDEX_BANK));
46+
verifySchema(
47+
result1,
48+
schema("avg(balance)", "double"),
49+
schema("gender", "string"),
50+
schema("age", "string"));
51+
verifyDataRows(result1, rows(40488, "F", "36"), rows(16377.25, "M", 36));
52+
JSONObject result2 = executeQuery(String.format("source=%s | chart avg(balance) over gender, age", TEST_INDEX_BANK));
53+
assertJsonEquals(result1.toString(), result2.toString());
54+
}
55+
56+
// TODOs:
57+
// Param nullstr: source=opensearch-sql_test_index_bank_with_null_values | eval age = cast(age as string) | chart nullstr='nil' max(account_number) over gender by age
58+
// Param usenull: source=opensearch-sql_test_index_bank_with_null_values | eval age = cast(age as string) | chart usenull=false nullstr='nil' max(account_number) over gender by age
59+
// Param limit = 0: source=bank | chart limit=0 avg(balance) over state by gender
60+
// SPAN:
61+
62+
}

0 commit comments

Comments
 (0)