Skip to content

Commit a696618

Browse files
songkant-awsxinyual
authored andcommitted
Support trendline command in Calcite (opensearch-project#3741)
* Support trendline command in Calcite Signed-off-by: Songkan Tang <songkant@amazon.com> * Fix CalciteExplainIT for trendline command Signed-off-by: Songkan Tang <songkant@amazon.com> * Update trendline.rst doc to callout new wma algorithm supported by Calcite Signed-off-by: Songkan Tang <songkant@amazon.com> * Fix typo in the doctest and rephrase the doc and formula Signed-off-by: Songkan Tang <songkant@amazon.com> * Add missing trendline pushdown IT Signed-off-by: Songkan Tang <songkant@amazon.com> * Remove unexpected change Signed-off-by: Songkan Tang <songkant@amazon.com> --------- Signed-off-by: Songkan Tang <songkant@amazon.com>
1 parent 9896166 commit a696618

14 files changed

Lines changed: 448 additions & 42 deletions

File tree

core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
6666
}
6767

6868
public enum TrendlineType {
69-
SMA
69+
SMA,
70+
WMA
7071
}
7172
}

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import org.opensearch.sql.ast.expression.ParseMethod;
6060
import org.opensearch.sql.ast.expression.UnresolvedExpression;
6161
import org.opensearch.sql.ast.expression.WindowFrame;
62+
import org.opensearch.sql.ast.expression.WindowFrame.FrameType;
6263
import org.opensearch.sql.ast.expression.subquery.SubqueryExpression;
6364
import org.opensearch.sql.ast.tree.AD;
6465
import org.opensearch.sql.ast.tree.Aggregation;
@@ -86,11 +87,13 @@
8687
import org.opensearch.sql.ast.tree.SubqueryAlias;
8788
import org.opensearch.sql.ast.tree.TableFunction;
8889
import org.opensearch.sql.ast.tree.Trendline;
90+
import org.opensearch.sql.ast.tree.Trendline.TrendlineType;
8991
import org.opensearch.sql.ast.tree.UnresolvedPlan;
9092
import org.opensearch.sql.ast.tree.Window;
9193
import org.opensearch.sql.calcite.plan.OpenSearchConstants;
9294
import org.opensearch.sql.calcite.utils.JoinAndLookupUtils;
9395
import org.opensearch.sql.calcite.utils.PlanUtils;
96+
import org.opensearch.sql.common.utils.StringUtils;
9497
import org.opensearch.sql.exception.CalciteUnsupportedException;
9598
import org.opensearch.sql.exception.SemanticCheckException;
9699
import org.opensearch.sql.expression.function.BuiltinFunctionName;
@@ -979,6 +982,126 @@ public RelNode visitTableFunction(TableFunction node, CalcitePlanContext context
979982

980983
@Override
981984
public RelNode visitTrendline(Trendline node, CalcitePlanContext context) {
982-
throw new CalciteUnsupportedException("Trendline command is unsupported in Calcite");
985+
visitChildren(node, context);
986+
987+
node.getSortByField()
988+
.ifPresent(
989+
sortField -> {
990+
SortOption sortOption = analyzeSortOption(sortField.getFieldArgs());
991+
RexNode field = rexVisitor.analyze(sortField, context);
992+
if (sortOption == DEFAULT_DESC) {
993+
context.relBuilder.sort(context.relBuilder.desc(field));
994+
} else {
995+
context.relBuilder.sort(field);
996+
}
997+
});
998+
999+
List<RexNode> trendlineNodes = new ArrayList<>();
1000+
List<String> aliases = new ArrayList<>();
1001+
node.getComputations()
1002+
.forEach(
1003+
trendlineComputation -> {
1004+
RexNode field = rexVisitor.analyze(trendlineComputation.getDataField(), context);
1005+
context.relBuilder.filter(context.relBuilder.isNotNull(field));
1006+
1007+
WindowFrame windowFrame =
1008+
WindowFrame.of(
1009+
FrameType.ROWS,
1010+
StringUtils.format(
1011+
"%d PRECEDING", trendlineComputation.getNumberOfDataPoints() - 1),
1012+
"CURRENT ROW");
1013+
RexNode countExpr =
1014+
PlanUtils.makeOver(
1015+
context,
1016+
BuiltinFunctionName.COUNT,
1017+
null,
1018+
List.of(),
1019+
List.of(),
1020+
List.of(),
1021+
windowFrame);
1022+
// CASE WHEN count() over (ROWS (windowSize-1) PRECEDING) > windowSize - 1
1023+
RexNode whenConditionExpr =
1024+
PPLFuncImpTable.INSTANCE.resolve(
1025+
context.rexBuilder,
1026+
">",
1027+
countExpr,
1028+
context.relBuilder.literal(trendlineComputation.getNumberOfDataPoints() - 1));
1029+
1030+
RexNode thenExpr;
1031+
switch (trendlineComputation.getComputationType()) {
1032+
case TrendlineType.SMA:
1033+
// THEN avg(field) over (ROWS (windowSize-1) PRECEDING)
1034+
thenExpr =
1035+
PlanUtils.makeOver(
1036+
context,
1037+
BuiltinFunctionName.AVG,
1038+
field,
1039+
List.of(),
1040+
List.of(),
1041+
List.of(),
1042+
windowFrame);
1043+
break;
1044+
case TrendlineType.WMA:
1045+
// THEN wma expression
1046+
thenExpr =
1047+
buildWmaRexNode(
1048+
field,
1049+
trendlineComputation.getNumberOfDataPoints(),
1050+
windowFrame,
1051+
context);
1052+
break;
1053+
default:
1054+
throw new IllegalStateException("Unsupported trendline type");
1055+
}
1056+
1057+
// ELSE NULL
1058+
RexNode elseExpr = context.relBuilder.literal(null);
1059+
1060+
List<RexNode> caseOperands = new ArrayList<>();
1061+
caseOperands.add(whenConditionExpr);
1062+
caseOperands.add(thenExpr);
1063+
caseOperands.add(elseExpr);
1064+
RexNode trendlineNode =
1065+
context.rexBuilder.makeCall(SqlStdOperatorTable.CASE, caseOperands);
1066+
trendlineNodes.add(trendlineNode);
1067+
aliases.add(trendlineComputation.getAlias());
1068+
});
1069+
1070+
projectPlusOverriding(trendlineNodes, aliases, context);
1071+
return context.relBuilder.peek();
1072+
}
1073+
1074+
private RexNode buildWmaRexNode(
1075+
RexNode field,
1076+
Integer numberOfDataPoints,
1077+
WindowFrame windowFrame,
1078+
CalcitePlanContext context) {
1079+
1080+
// Divisor: 1 + 2 + 3 + ... + windowSize, aka (windowSize * (windowSize + 1) / 2)
1081+
RexNode divisor = context.relBuilder.literal(numberOfDataPoints * (numberOfDataPoints + 1) / 2);
1082+
1083+
// Divider: 1 * NTH_VALUE(field, 1) + 2 * NTH_VALUE(field, 2) + ... + windowSize *
1084+
// NTH_VALUE(field, windowSize)
1085+
RexNode divider = context.relBuilder.literal(0);
1086+
for (int i = 1; i <= numberOfDataPoints; i++) {
1087+
RexNode nthValueExpr =
1088+
PlanUtils.makeOver(
1089+
context,
1090+
BuiltinFunctionName.NTH_VALUE,
1091+
field,
1092+
List.of(context.relBuilder.literal(i)),
1093+
List.of(),
1094+
List.of(),
1095+
windowFrame);
1096+
divider =
1097+
context.relBuilder.call(
1098+
SqlStdOperatorTable.PLUS,
1099+
divider,
1100+
context.relBuilder.call(
1101+
SqlStdOperatorTable.MULTIPLY, nthValueExpr, context.relBuilder.literal(i)));
1102+
}
1103+
// Divider / CAST(Divisor, DOUBLE)
1104+
return context.relBuilder.call(
1105+
SqlStdOperatorTable.DIVIDE, divider, context.relBuilder.cast(divisor, SqlTypeName.DOUBLE));
9831106
}
9841107
}

core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,14 @@ static RexNode makeOver(
141141
true,
142142
lowerBound,
143143
upperBound);
144+
case NTH_VALUE:
145+
return withOver(
146+
context.relBuilder.aggregateCall(SqlStdOperatorTable.NTH_VALUE, field, argList.get(0)),
147+
partitions,
148+
orderKeys,
149+
true,
150+
lowerBound,
151+
upperBound);
144152
default:
145153
return withOver(
146154
makeAggCall(context, functionName, false, field, argList),

core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ public enum BuiltinFunctionName {
249249
IS_BLANK(FunctionName.of("isblank")),
250250

251251
ROW_NUMBER(FunctionName.of("row_number")),
252+
NTH_VALUE(FunctionName.of("nth_value")),
252253
RANK(FunctionName.of("rank")),
253254
DENSE_RANK(FunctionName.of("dense_rank")),
254255

docs/user/ppl/cmd/trendline.rst

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,36 @@ Description
1515
1616
Syntax
1717
============
18-
`TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...`
18+
`TRENDLINE [sort <[+|-] sort-field>] [SMA|WMA](number-of-datapoints, field) [AS alias] [[SMA|WMA](number-of-datapoints, field) [AS alias]]...`
1919

2020
* [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first.
2121
* sort-field: mandatory when sorting is used. The field used to sort.
2222
* number-of-datapoints: mandatory. The number of datapoints to calculate the moving average (must be greater than zero).
2323
* field: mandatory. The name of the field the moving average should be calculated for.
2424
* alias: optional. The name of the resulting column containing the moving average (defaults to the field name with "_trendline").
2525

26-
At the moment only the Simple Moving Average (SMA) type is supported.
26+
Starting with version 3.1.0, two trendline algorithms are supported, aka Simple Moving Average (SMA) and Weighted Moving Average (WMA).
2727

28-
It is calculated like
28+
Suppose:
2929

30-
f[i]: The value of field 'f' in the i-th data-point
31-
n: The number of data-points in the moving window (period)
32-
t: The current time index
30+
* f[i]: The value of field 'f' in the i-th data-point
31+
* n: The number of data-points in the moving window (period)
32+
* t: The current time index
33+
34+
SMA is calculated like
3335

3436
SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t
3537

36-
Example 1: Calculate the moving average on one field.
38+
WMA places more weights on recent values compared to equal-weighted SMA algorithm
39+
40+
WMA(t) = (1/(1 + 2 + ... + n)) * Σ(1 * f[i-n+1] + 2 * f[t-n+2] + ... + n * f[t])
41+
= (2/(n * (n + 1))) * Σ((i - t + n) * f[i]), where i = t-n+1 to t
42+
43+
44+
Example 1: Calculate the simple moving average on one field.
3745
=====================================================
3846

39-
The example shows how to calculate the moving average on one field.
47+
The example shows how to calculate the simple moving average on one field.
4048

4149
PPL query::
4250

@@ -52,10 +60,10 @@ PPL query::
5260
+------+
5361

5462

55-
Example 2: Calculate the moving average on multiple fields.
63+
Example 2: Calculate the simple moving average on multiple fields.
5664
===========================================================
5765

58-
The example shows how to calculate the moving average on multiple fields.
66+
The example shows how to calculate the simple moving average on multiple fields.
5967

6068
PPL query::
6169

@@ -70,10 +78,10 @@ PPL query::
7078
| 15.5 | 30.5 |
7179
+------+-----------+
7280

73-
Example 4: Calculate the moving average on one field without specifying an alias.
81+
Example 3: Calculate the simple moving average on one field without specifying an alias.
7482
=================================================================================
7583

76-
The example shows how to calculate the moving average on one field.
84+
The example shows how to calculate the simple moving average on one field.
7785

7886
PPL query::
7987

@@ -88,3 +96,40 @@ PPL query::
8896
| 15.5 |
8997
+--------------------------+
9098

99+
Example 4: Calculate the weighted moving average on one field.
100+
=================================================================================
101+
102+
Version
103+
-------
104+
3.1.0
105+
106+
Configuration
107+
-------------
108+
wma algorithm requires Calcite enabled.
109+
110+
Enable Calcite:
111+
112+
>> curl -H 'Content-Type: application/json' -X PUT localhost:9200/_plugins/_query/settings -d '{
113+
"persistent" : {
114+
"plugins.calcite.enabled" : true
115+
}
116+
}'
117+
118+
The example shows how to calculate the weighted moving average on one field.
119+
120+
PPL query::
121+
122+
PPL> source=accounts | trendline wma(2, account_number) | fields account_number_trendline;
123+
fetched rows / total rows = 4/4
124+
+--------------------------+
125+
| account_number_trendline |
126+
|--------------------------|
127+
| null |
128+
| 4.333333333333333 |
129+
| 10.666666666666666 |
130+
| 16.333333333333332 |
131+
+--------------------------+
132+
133+
Limitation
134+
==========
135+
Starting with version 3.1.0, the ``trendline`` command requires all values in the specified ``field`` to be non-null. Any rows with null values present in the calculation field will be automatically excluded from the command's output.

integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,6 @@ public void init() throws Exception {
1717
disallowCalciteFallback();
1818
}
1919

20-
@Override
21-
public void testTrendlinePushDownExplain() throws Exception {
22-
withFallbackEnabled(
23-
() -> {
24-
try {
25-
super.testTrendlinePushDownExplain();
26-
} catch (Exception e) {
27-
throw new RuntimeException(e);
28-
}
29-
},
30-
"https://github.com/opensearch-project/sql/issues/3466");
31-
}
32-
33-
@Override
34-
public void testTrendlineWithSortPushDownExplain() throws Exception {
35-
withFallbackEnabled(
36-
() -> {
37-
try {
38-
super.testTrendlineWithSortPushDownExplain();
39-
} catch (Exception e) {
40-
throw new RuntimeException(e);
41-
}
42-
},
43-
"https://github.com/opensearch-project/sql/issues/3466");
44-
}
45-
4620
@Override
4721
@Ignore("test only in v2")
4822
public void testExplainModeUnsupportedInV2() throws IOException {}

0 commit comments

Comments
 (0)