Skip to content

Commit 990346a

Browse files
authored
Support aggregation/window commands with dynamic fields (#4743)
* Fix aggregation for dynamic fields Signed-off-by: Tomoyuki Morita <moritato@amazon.com> * Address comments Signed-off-by: Tomoyuki Morita <moritato@amazon.com> * Utilize coercion Signed-off-by: Tomoyuki Morita <moritato@amazon.com> * Fix timechart and trendline Signed-off-by: Tomoyuki Morita <moritato@amazon.com> * minor fix Signed-off-by: Tomoyuki Morita <moritato@amazon.com> * Minor refactoring Signed-off-by: Tomoyuki Morita <moritato@amazon.com> * Add tests for spark sql verification Signed-off-by: Tomoyuki Morita <moritato@amazon.com> * Add comment Signed-off-by: Tomoyuki Morita <moritato@amazon.com> * Add explain verification Signed-off-by: Tomoyuki Morita <moritato@amazon.com> --------- Signed-off-by: Tomoyuki Morita <moritato@amazon.com>
1 parent 49cc4b3 commit 990346a

16 files changed

Lines changed: 1484 additions & 103 deletions

File tree

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.common.utils;
7+
8+
import java.util.Collection;
9+
import java.util.Map;
10+
import java.util.stream.Collectors;
11+
12+
/**
 * Utility class for debugging operations.
 *
 * <p>Each {@code debug} overload prints a readable rendering of a value to standard output,
 * together with the location of the calling code, and returns the value unchanged so that the
 * call can be wrapped around any expression.
 */
public final class DebugUtils {

  private DebugUtils() {
    // Static helpers only; not meant to be instantiated.
  }

  /** Formats the message and writes it to standard output followed by a line break. */
  private static void print(String format, Object... args) {
    System.out.println(String.format(format, args));
  }

  /**
   * Prints {@code obj} with a label and the caller's location, then returns it.
   *
   * @param obj value to print; may be null
   * @param message label printed in front of the value
   * @return {@code obj} itself, unchanged
   */
  public static <T> T debug(T obj, String message) {
    print("### %s: %s (at %s)", message, stringify(obj), getCalledFrom(1));
    return obj;
  }

  /**
   * Prints {@code obj} and the caller's location, then returns it.
   *
   * @param obj value to print; may be null
   * @return {@code obj} itself, unchanged
   */
  public static <T> T debug(T obj) {
    print("### %s (at %s)", stringify(obj), getCalledFrom(1));
    return obj;
  }

  /**
   * Describes a frame of the current call stack as {@code class.method:line}.
   *
   * @param pos number of frames above the method that called this helper (1 = that method's
   *     caller)
   * @return the frame description, or {@code "unknown"} when the stack trace has no such frame
   */
  private static String getCalledFrom(int pos) {
    StackTraceElement[] frames = new RuntimeException().getStackTrace();
    // frames[0] is this method itself, so the requested frame sits one slot further down.
    int index = pos + 1;
    if (index < 0 || index >= frames.length) {
      return "unknown";
    }
    StackTraceElement item = frames[index];
    return item.getClassName() + "." + item.getMethodName() + ":" + item.getLineNumber();
  }

  /** Renders a collection as {@code (e1,e2,...)}, rendering each element recursively. */
  private static String stringify(Collection<?> items) {
    if (items == null) {
      return "null";
    }

    if (items.isEmpty()) {
      return "()";
    }

    // The cast pins the Object overload so nested collections/maps/arrays are dispatched properly.
    String result =
        items.stream().map(item -> stringify((Object) item)).collect(Collectors.joining(","));

    return "(" + result + ")";
  }

  /** Renders a map as {@code {k1: v1,k2: v2}}, rendering each value recursively. */
  private static String stringify(Map<?, ?> map) {
    if (map == null) {
      return "[[null]]";
    }

    if (map.isEmpty()) {
      return "[[EMPTY]]";
    }

    String result =
        map.entrySet().stream()
            .map(entry -> entry.getKey() + ": " + stringify((Object) entry.getValue()))
            .collect(Collectors.joining(","));
    return "{" + result + "}";
  }

  /** Renders an object or primitive array in the same {@code (e1,e2,...)} form as a collection. */
  private static String stringifyArray(Object array) {
    int length = java.lang.reflect.Array.getLength(array);
    StringBuilder rendered = new StringBuilder("(");
    for (int i = 0; i < length; i++) {
      if (i > 0) {
        rendered.append(',');
      }
      rendered.append(stringify(java.lang.reflect.Array.get(array, i)));
    }
    return rendered.append(')').toString();
  }

  /** Dispatches to the collection, map or array renderer; anything else uses String.valueOf. */
  private static String stringify(Object obj) {
    if (obj instanceof Collection) {
      return stringify((Collection<?>) obj);
    } else if (obj instanceof Map) {
      return stringify((Map<?, ?>) obj);
    } else if (obj != null && obj.getClass().isArray()) {
      // Arrays would otherwise print as a type tag plus identity hash, which is useless in logs.
      return stringifyArray(obj);
    }
    return String.valueOf(obj);
  }
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.common.utils;
7+
8+
import lombok.experimental.UtilityClass;
9+
10+
@UtilityClass
11+
public class JsonUtils {
  /**
   * Builds a JSON string from lines written with single quotes, which keeps JSON literals in tests
   * easy to read and maintain. Every single quote becomes a double quote and every line, including
   * the last one, is terminated with a newline. For example {@code sjson("{", "'key': 'name'",
   * "}")} returns {@code "{\n\"key\": \"name\"\n}\n"}.
   *
   * @param lines lines that use single quotes in place of double quotes
   * @return the converted lines, each followed by a newline
   */
  public static String sjson(String... lines) {
    StringBuilder json = new StringBuilder();
    for (int i = 0; i < lines.length; i++) {
      json.append(replaceQuote(lines[i])).append('\n');
    }
    return json.toString();
  }

  /** Swaps every single quote in the line for a double quote. */
  private static String replaceQuote(String line) {
    return line.replace('\'', '"');
  }

  /**
   * Builds a multi-line string from the given lines. Every line, including the last one, is
   * terminated with a newline.
   *
   * @param lines input lines
   * @return the lines, each followed by a newline
   */
  public static String lines(String... lines) {
    StringBuilder text = new StringBuilder();
    for (int i = 0; i < lines.length; i++) {
      text.append(lines[i]).append('\n');
    }
    return text.toString();
  }
}
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.common.utils;
7+
8+
import static org.junit.Assert.assertEquals;
9+
10+
import org.junit.Test;
11+
12+
public class JsonUtilsTest {
13+
14+
@Test
15+
public void testSjsonWithSingleLine() {
16+
String result = JsonUtils.sjson("{'key': 'value'}");
17+
assertEquals("{\"key\": \"value\"}\n", result);
18+
}
19+
20+
@Test
21+
public void testSjsonWithMultipleLines() {
22+
String result = JsonUtils.sjson("{", " 'name': 'John',", " 'age': 30", "}");
23+
assertEquals("{\n \"name\": \"John\",\n \"age\": 30\n}\n", result);
24+
}
25+
26+
@Test
27+
public void testSjsonWithEmptyString() {
28+
String result = JsonUtils.sjson("");
29+
assertEquals("\n", result);
30+
}
31+
32+
@Test
33+
public void testSjsonWithNoQuotes() {
34+
String result = JsonUtils.sjson("no quotes here");
35+
assertEquals("no quotes here\n", result);
36+
}
37+
38+
@Test
39+
public void testSjsonWithMixedQuotes() {
40+
String result = JsonUtils.sjson("'single' and \"double\" quotes");
41+
assertEquals("\"single\" and \"double\" quotes\n", result);
42+
}
43+
44+
@Test
45+
public void testSjsonWithMultipleSingleQuotes() {
46+
String result = JsonUtils.sjson("'key1': 'value1', 'key2': 'value2'");
47+
assertEquals("\"key1\": \"value1\", \"key2\": \"value2\"\n", result);
48+
}
49+
50+
@Test
51+
public void testSjsonWithNestedJson() {
52+
String result = JsonUtils.sjson("{", " 'outer': {", " 'inner': 'value'", " }", "}");
53+
assertEquals("{\n \"outer\": {\n \"inner\": \"value\"\n }\n}\n", result);
54+
}
55+
56+
@Test
57+
public void testSjsonWithArrays() {
58+
String result = JsonUtils.sjson("{", " 'items': ['item1', 'item2', 'item3']", "}");
59+
assertEquals("{\n \"items\": [\"item1\", \"item2\", \"item3\"]\n}\n", result);
60+
}
61+
62+
@Test
63+
public void testSjsonWithSpecialCharacters() {
64+
String result = JsonUtils.sjson("{'key': 'value with \\'escaped\\' quotes'}");
65+
assertEquals("{\"key\": \"value with \\\"escaped\\\" quotes\"}\n", result);
66+
}
67+
68+
@Test
69+
public void testSjsonWithEmptyArray() {
70+
String result = JsonUtils.sjson();
71+
assertEquals("", result);
72+
}
73+
74+
@Test
75+
public void testSjsonWithNullValues() {
76+
String result = JsonUtils.sjson("{'key': null}");
77+
assertEquals("{\"key\": null}\n", result);
78+
}
79+
80+
@Test
81+
public void testSjsonWithNumbers() {
82+
String result =
83+
JsonUtils.sjson("{", " 'integer': 42,", " 'float': 3.14,", " 'negative': -10", "}");
84+
assertEquals("{\n \"integer\": 42,\n \"float\": 3.14,\n \"negative\": -10\n}\n", result);
85+
}
86+
87+
@Test
88+
public void testSjsonWithBooleans() {
89+
String result = JsonUtils.sjson("{'active': true, 'deleted': false}");
90+
assertEquals("{\"active\": true, \"deleted\": false}\n", result);
91+
}
92+
93+
@Test
94+
public void testSjsonPreservesWhitespace() {
95+
String result = JsonUtils.sjson(" {'key': 'value'} ");
96+
assertEquals(" {\"key\": \"value\"} \n", result);
97+
}
98+
99+
@Test
100+
public void testSjsonWithComplexJson() {
101+
String result =
102+
JsonUtils.sjson(
103+
"{",
104+
" 'user': {",
105+
" 'name': 'Alice',",
106+
" 'email': 'alice@example.com',",
107+
" 'roles': ['admin', 'user'],",
108+
" 'active': true,",
109+
" 'loginCount': 42",
110+
" }",
111+
"}");
112+
String expected =
113+
"{\n"
114+
+ " \"user\": {\n"
115+
+ " \"name\": \"Alice\",\n"
116+
+ " \"email\": \"alice@example.com\",\n"
117+
+ " \"roles\": [\"admin\", \"user\"],\n"
118+
+ " \"active\": true,\n"
119+
+ " \"loginCount\": 42\n"
120+
+ " }\n"
121+
+ "}\n";
122+
assertEquals(expected, result);
123+
}
124+
}

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import java.util.Objects;
3939
import java.util.Optional;
4040
import java.util.Set;
41+
import java.util.function.UnaryOperator;
4142
import java.util.stream.Collectors;
4243
import java.util.stream.Stream;
4344
import org.apache.calcite.plan.RelOptTable;
@@ -1897,10 +1898,10 @@ public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) {
18971898
}
18981899

18991900
// 1. group the group-by list + field list and add a count() aggregation
1900-
List<UnresolvedExpression> groupExprList = new ArrayList<>(node.getGroupExprList());
1901-
List<UnresolvedExpression> fieldList =
1902-
node.getFields().stream().map(f -> (UnresolvedExpression) f).toList();
1903-
groupExprList.addAll(fieldList);
1901+
List<UnresolvedExpression> groupExprList = new ArrayList<>();
1902+
node.getGroupExprList().forEach(exp -> groupExprList.add(exp));
1903+
node.getFields().forEach(field -> groupExprList.add(field));
1904+
groupExprList.forEach(expr -> projectDynamicField(expr, context));
19041905
List<UnresolvedExpression> aggExprList =
19051906
List.of(AstDSL.alias(countFieldName, AstDSL.aggregate("count", null)));
19061907
aggregateWithTrimming(groupExprList, aggExprList, context);
@@ -2048,6 +2049,9 @@ public RelNode visitTimechart(
20482049
org.opensearch.sql.ast.tree.Timechart node, CalcitePlanContext context) {
20492050
visitChildren(node, context);
20502051

2052+
projectDynamicFieldAsString(node.getBinExpression(), context);
2053+
projectDynamicFieldAsString(node.getByField(), context);
2054+
20512055
// Extract parameters
20522056
UnresolvedExpression spanExpr = node.getBinExpression();
20532057

@@ -2111,9 +2115,7 @@ public RelNode visitTimechart(
21112115
List<RexNode> outputFields = context.fieldBuilder.staticFields();
21122116
List<RexNode> reordered = new ArrayList<>();
21132117
reordered.add(context.fieldBuilder.staticField("@timestamp")); // timestamp first
2114-
reordered.add(
2115-
context.fieldBuilder.staticField(
2116-
byFieldName)); // byField second. TODO: allow dynamic fields
2118+
reordered.add(context.fieldBuilder.staticField(byFieldName)); // byField second.
21172119
reordered.add(outputFields.get(outputFields.size() - 1)); // value function last
21182120
context.relBuilder.project(reordered);
21192121

@@ -2145,6 +2147,43 @@ public RelNode visitTimechart(
21452147
}
21462148
}
21472149

2150+
/**
2151+
* Project dynamic field to static field and cast to string to make it easier to handle. It does
2152+
* nothing if exp does not refer dynamic field.
2153+
*/
2154+
private void projectDynamicFieldAsString(UnresolvedExpression exp, CalcitePlanContext context) {
2155+
projectDynamicField(exp, context, node -> context.rexBuilder.castToString(node));
2156+
}
2157+
2158+
/**
2159+
* Project dynamic field to static field to make it easier to handle. It does nothing if exp does
2160+
* not refer dynamic field.
2161+
*/
2162+
private void projectDynamicField(UnresolvedExpression exp, CalcitePlanContext context) {
2163+
UnaryOperator<RexNode> noWrap = node -> node;
2164+
projectDynamicField(exp, context, noWrap);
2165+
}
2166+
2167+
private void projectDynamicField(
2168+
UnresolvedExpression exp, CalcitePlanContext context, UnaryOperator<RexNode> nodeWrapper) {
2169+
if (exp != null) {
2170+
exp.accept(
2171+
new AbstractNodeVisitor<Void, CalcitePlanContext>() {
2172+
@Override
2173+
public Void visitField(Field field, CalcitePlanContext context) {
2174+
RexNode node = rexVisitor.analyze(field, context);
2175+
if (node.isA(SqlKind.ITEM)) {
2176+
RexNode alias =
2177+
context.relBuilder.alias(nodeWrapper.apply(node), field.getField().toString());
2178+
context.relBuilder.projectPlus(alias);
2179+
}
2180+
return null;
2181+
}
2182+
},
2183+
context);
2184+
}
2185+
}
2186+
21482187
/** Build top categories query - simpler approach that works better with OTHER handling */
21492188
private RelNode buildTopCategoriesQuery(
21502189
RelNode completeResults, int limit, CalcitePlanContext context) {
@@ -2349,6 +2388,7 @@ public RelNode visitTrendline(Trendline node, CalcitePlanContext context) {
23492388
.ifPresent(
23502389
sortField -> {
23512390
SortOption sortOption = analyzeSortOption(sortField.getFieldArgs());
2391+
projectDynamicFieldAsString(sortField, context);
23522392
RexNode field = rexVisitor.analyze(sortField, context);
23532393
if (sortOption == DEFAULT_DESC) {
23542394
context.relBuilder.sort(context.relBuilder.desc(field));
@@ -2362,7 +2402,9 @@ public RelNode visitTrendline(Trendline node, CalcitePlanContext context) {
23622402
node.getComputations()
23632403
.forEach(
23642404
trendlineComputation -> {
2405+
projectDynamicField(trendlineComputation.getDataField(), context);
23652406
RexNode field = rexVisitor.analyze(trendlineComputation.getDataField(), context);
2407+
23662408
context.relBuilder.filter(context.relBuilder.isNotNull(field));
23672409

23682410
WindowFrame windowFrame =

core/src/main/java/org/opensearch/sql/calcite/ExtendedRexBuilder.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,4 +163,11 @@ else if ((SqlTypeUtil.isApproximateNumeric(sourceType) || SqlTypeUtil.isDecimal(
163163
}
164164
return super.makeCast(pos, type, exp, matchNullability, safe, format);
165165
}
166+
167+
/** Cast node to string */
168+
public RexNode castToString(RexNode node) {
169+
RelDataType stringType = getTypeFactory().createSqlType(SqlTypeName.VARCHAR);
170+
RelDataType nullableStringType = getTypeFactory().createTypeWithNullability(stringType, true);
171+
return makeCast(nullableStringType, node, true, true);
172+
}
166173
}

core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,15 +182,15 @@ static RexNode makeOver(
182182
return variance(context, field, partitions, rows, lowerBound, upperBound, false, false);
183183
case ROW_NUMBER:
184184
return withOver(
185-
context.relBuilder.aggregateCall(SqlStdOperatorTable.ROW_NUMBER),
185+
makeAggCall(context, functionName, false, null, List.of()),
186186
partitions,
187187
orderKeys,
188188
true,
189189
lowerBound,
190190
upperBound);
191191
case NTH_VALUE:
192192
return withOver(
193-
context.relBuilder.aggregateCall(SqlStdOperatorTable.NTH_VALUE, field, argList.get(0)),
193+
makeAggCall(context, functionName, false, field, argList.subList(0, 1)),
194194
partitions,
195195
orderKeys,
196196
true,
@@ -215,7 +215,12 @@ private static RexNode sumOver(
215215
RexWindowBound lowerBound,
216216
RexWindowBound upperBound) {
217217
return withOver(
218-
ctx.relBuilder.sum(operation), partitions, List.of(), rows, lowerBound, upperBound);
218+
makeAggCall(ctx, BuiltinFunctionName.SUM, false, operation, List.of()),
219+
partitions,
220+
List.of(),
221+
rows,
222+
lowerBound,
223+
upperBound);
219224
}
220225

221226
private static RexNode countOver(

0 commit comments

Comments
 (0)