Skip to content

Commit 4c6448d

Browse files
committed
Merge origin/main into cmd-expand
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
2 parents 5ea6d60 + 1841139 commit 4c6448d

21 files changed

Lines changed: 754 additions & 118 deletions

File tree

core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
6666
}
6767

6868
public enum TrendlineType {
69-
SMA
69+
SMA,
70+
WMA
7071
}
7172
}

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
import org.opensearch.sql.ast.expression.ParseMethod;
6161
import org.opensearch.sql.ast.expression.UnresolvedExpression;
6262
import org.opensearch.sql.ast.expression.WindowFrame;
63+
import org.opensearch.sql.ast.expression.WindowFrame.FrameType;
6364
import org.opensearch.sql.ast.expression.subquery.SubqueryExpression;
6465
import org.opensearch.sql.ast.tree.AD;
6566
import org.opensearch.sql.ast.tree.Aggregation;
@@ -88,11 +89,13 @@
8889
import org.opensearch.sql.ast.tree.SubqueryAlias;
8990
import org.opensearch.sql.ast.tree.TableFunction;
9091
import org.opensearch.sql.ast.tree.Trendline;
92+
import org.opensearch.sql.ast.tree.Trendline.TrendlineType;
9193
import org.opensearch.sql.ast.tree.UnresolvedPlan;
9294
import org.opensearch.sql.ast.tree.Window;
9395
import org.opensearch.sql.calcite.plan.OpenSearchConstants;
9496
import org.opensearch.sql.calcite.utils.JoinAndLookupUtils;
9597
import org.opensearch.sql.calcite.utils.PlanUtils;
98+
import org.opensearch.sql.common.utils.StringUtils;
9699
import org.opensearch.sql.exception.CalciteUnsupportedException;
97100
import org.opensearch.sql.exception.SemanticCheckException;
98101
import org.opensearch.sql.expression.function.BuiltinFunctionName;
@@ -977,7 +980,127 @@ public RelNode visitTableFunction(TableFunction node, CalcitePlanContext context
977980

978981
@Override
979982
public RelNode visitTrendline(Trendline node, CalcitePlanContext context) {
980-
throw new CalciteUnsupportedException("Trendline command is unsupported in Calcite");
983+
visitChildren(node, context);
984+
985+
node.getSortByField()
986+
.ifPresent(
987+
sortField -> {
988+
SortOption sortOption = analyzeSortOption(sortField.getFieldArgs());
989+
RexNode field = rexVisitor.analyze(sortField, context);
990+
if (sortOption == DEFAULT_DESC) {
991+
context.relBuilder.sort(context.relBuilder.desc(field));
992+
} else {
993+
context.relBuilder.sort(field);
994+
}
995+
});
996+
997+
List<RexNode> trendlineNodes = new ArrayList<>();
998+
List<String> aliases = new ArrayList<>();
999+
node.getComputations()
1000+
.forEach(
1001+
trendlineComputation -> {
1002+
RexNode field = rexVisitor.analyze(trendlineComputation.getDataField(), context);
1003+
context.relBuilder.filter(context.relBuilder.isNotNull(field));
1004+
1005+
WindowFrame windowFrame =
1006+
WindowFrame.of(
1007+
FrameType.ROWS,
1008+
StringUtils.format(
1009+
"%d PRECEDING", trendlineComputation.getNumberOfDataPoints() - 1),
1010+
"CURRENT ROW");
1011+
RexNode countExpr =
1012+
PlanUtils.makeOver(
1013+
context,
1014+
BuiltinFunctionName.COUNT,
1015+
null,
1016+
List.of(),
1017+
List.of(),
1018+
List.of(),
1019+
windowFrame);
1020+
// CASE WHEN count() over (ROWS (windowSize-1) PRECEDING) > windowSize - 1
1021+
RexNode whenConditionExpr =
1022+
PPLFuncImpTable.INSTANCE.resolve(
1023+
context.rexBuilder,
1024+
">",
1025+
countExpr,
1026+
context.relBuilder.literal(trendlineComputation.getNumberOfDataPoints() - 1));
1027+
1028+
RexNode thenExpr;
1029+
switch (trendlineComputation.getComputationType()) {
1030+
case TrendlineType.SMA:
1031+
// THEN avg(field) over (ROWS (windowSize-1) PRECEDING)
1032+
thenExpr =
1033+
PlanUtils.makeOver(
1034+
context,
1035+
BuiltinFunctionName.AVG,
1036+
field,
1037+
List.of(),
1038+
List.of(),
1039+
List.of(),
1040+
windowFrame);
1041+
break;
1042+
case TrendlineType.WMA:
1043+
// THEN wma expression
1044+
thenExpr =
1045+
buildWmaRexNode(
1046+
field,
1047+
trendlineComputation.getNumberOfDataPoints(),
1048+
windowFrame,
1049+
context);
1050+
break;
1051+
default:
1052+
throw new IllegalStateException("Unsupported trendline type");
1053+
}
1054+
1055+
// ELSE NULL
1056+
RexNode elseExpr = context.relBuilder.literal(null);
1057+
1058+
List<RexNode> caseOperands = new ArrayList<>();
1059+
caseOperands.add(whenConditionExpr);
1060+
caseOperands.add(thenExpr);
1061+
caseOperands.add(elseExpr);
1062+
RexNode trendlineNode =
1063+
context.rexBuilder.makeCall(SqlStdOperatorTable.CASE, caseOperands);
1064+
trendlineNodes.add(trendlineNode);
1065+
aliases.add(trendlineComputation.getAlias());
1066+
});
1067+
1068+
projectPlusOverriding(trendlineNodes, aliases, context);
1069+
return context.relBuilder.peek();
1070+
}
1071+
1072+
private RexNode buildWmaRexNode(
1073+
RexNode field,
1074+
Integer numberOfDataPoints,
1075+
WindowFrame windowFrame,
1076+
CalcitePlanContext context) {
1077+
1078+
// Divisor: 1 + 2 + 3 + ... + windowSize, aka (windowSize * (windowSize + 1) / 2)
1079+
RexNode divisor = context.relBuilder.literal(numberOfDataPoints * (numberOfDataPoints + 1) / 2);
1080+
1081+
// Divider: 1 * NTH_VALUE(field, 1) + 2 * NTH_VALUE(field, 2) + ... + windowSize *
1082+
// NTH_VALUE(field, windowSize)
1083+
RexNode divider = context.relBuilder.literal(0);
1084+
for (int i = 1; i <= numberOfDataPoints; i++) {
1085+
RexNode nthValueExpr =
1086+
PlanUtils.makeOver(
1087+
context,
1088+
BuiltinFunctionName.NTH_VALUE,
1089+
field,
1090+
List.of(context.relBuilder.literal(i)),
1091+
List.of(),
1092+
List.of(),
1093+
windowFrame);
1094+
divider =
1095+
context.relBuilder.call(
1096+
SqlStdOperatorTable.PLUS,
1097+
divider,
1098+
context.relBuilder.call(
1099+
SqlStdOperatorTable.MULTIPLY, nthValueExpr, context.relBuilder.literal(i)));
1100+
}
1101+
// Divider / CAST(Divisor, DOUBLE)
1102+
return context.relBuilder.call(
1103+
SqlStdOperatorTable.DIVIDE, divider, context.relBuilder.cast(divisor, SqlTypeName.DOUBLE));
9811104
}
9821105

9831106
/**

core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java

Lines changed: 10 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,6 @@
1010
import static org.apache.calcite.rex.RexWindowBounds.UNBOUNDED_PRECEDING;
1111
import static org.apache.calcite.rex.RexWindowBounds.following;
1212
import static org.apache.calcite.rex.RexWindowBounds.preceding;
13-
import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.STDDEV_POP_NULLABLE;
14-
import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.STDDEV_SAMP_NULLABLE;
15-
import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.VAR_POP_NULLABLE;
16-
import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.VAR_SAMP_NULLABLE;
17-
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.TransferUserDefinedAggFunction;
1813

1914
import com.google.common.collect.ImmutableList;
2015
import java.util.ArrayList;
@@ -25,7 +20,6 @@
2520
import org.apache.calcite.rex.RexVisitorImpl;
2621
import org.apache.calcite.rex.RexWindowBound;
2722
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
28-
import org.apache.calcite.sql.type.ReturnTypes;
2923
import org.apache.calcite.sql.type.SqlTypeName;
3024
import org.apache.calcite.tools.RelBuilder;
3125
import org.opensearch.sql.ast.AbstractNodeVisitor;
@@ -37,9 +31,8 @@
3731
import org.opensearch.sql.ast.tree.Relation;
3832
import org.opensearch.sql.ast.tree.UnresolvedPlan;
3933
import org.opensearch.sql.calcite.CalcitePlanContext;
40-
import org.opensearch.sql.calcite.udf.udaf.PercentileApproxFunction;
41-
import org.opensearch.sql.calcite.udf.udaf.TakeAggFunction;
4234
import org.opensearch.sql.expression.function.BuiltinFunctionName;
35+
import org.opensearch.sql.expression.function.PPLFuncImpTable;
4336

4437
public interface PlanUtils {
4538

@@ -116,6 +109,14 @@ static RexNode makeOver(
116109
true,
117110
lowerBound,
118111
upperBound);
112+
case NTH_VALUE:
113+
return withOver(
114+
context.relBuilder.aggregateCall(SqlStdOperatorTable.NTH_VALUE, field, argList.get(0)),
115+
partitions,
116+
orderKeys,
117+
true,
118+
lowerBound,
119+
upperBound);
119120
default:
120121
return withOver(
121122
makeAggCall(context, functionName, false, field, argList),
@@ -232,56 +233,7 @@ static RelBuilder.AggCall makeAggCall(
232233
boolean distinct,
233234
RexNode field,
234235
List<RexNode> argList) {
235-
switch (functionName) {
236-
case MAX:
237-
return context.relBuilder.max(field);
238-
case MIN:
239-
return context.relBuilder.min(field);
240-
case AVG:
241-
return context.relBuilder.avg(distinct, null, field);
242-
case COUNT:
243-
return context.relBuilder.count(
244-
distinct, null, field == null ? ImmutableList.of() : ImmutableList.of(field));
245-
case SUM:
246-
return context.relBuilder.sum(distinct, null, field);
247-
// case MEAN:
248-
// throw new UnsupportedOperationException("MEAN is not supported in PPL");
249-
// case STDDEV:
250-
// return context.relBuilder.aggregateCall(SqlStdOperatorTable.STDDEV,
251-
// field);
252-
case VARSAMP:
253-
return context.relBuilder.aggregateCall(VAR_SAMP_NULLABLE, field);
254-
case VARPOP:
255-
return context.relBuilder.aggregateCall(VAR_POP_NULLABLE, field);
256-
case STDDEV_POP:
257-
return context.relBuilder.aggregateCall(STDDEV_POP_NULLABLE, field);
258-
case STDDEV_SAMP:
259-
return context.relBuilder.aggregateCall(STDDEV_SAMP_NULLABLE, field);
260-
// case PERCENTILE_APPROX:
261-
// return
262-
// context.relBuilder.aggregateCall(SqlStdOperatorTable.PERCENTILE_CONT, field);
263-
case TAKE:
264-
return TransferUserDefinedAggFunction(
265-
TakeAggFunction.class,
266-
"TAKE",
267-
UserDefinedFunctionUtils.getReturnTypeInferenceForArray(),
268-
List.of(field),
269-
argList,
270-
context.relBuilder);
271-
case PERCENTILE_APPROX:
272-
List<RexNode> newArgList = new ArrayList<>(argList);
273-
newArgList.add(context.rexBuilder.makeFlag(field.getType().getSqlTypeName()));
274-
return TransferUserDefinedAggFunction(
275-
PercentileApproxFunction.class,
276-
"percentile_approx",
277-
ReturnTypes.ARG0_FORCE_NULLABLE,
278-
List.of(field),
279-
newArgList,
280-
context.relBuilder);
281-
default:
282-
throw new UnsupportedOperationException(
283-
"Unexpected aggregation: " + functionName.getName().getFunctionName());
284-
}
236+
return PPLFuncImpTable.INSTANCE.resolveAgg(functionName, distinct, field, argList, context);
285237
}
286238

287239
/** Get all uniq input references from a RexNode. */

core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public static RelBuilder.AggCall TransferUserDefinedAggFunction(
7676
return relBuilder.aggregateCall(sqlUDAF, addArgList);
7777
}
7878

79-
static SqlReturnTypeInference getReturnTypeInferenceForArray() {
79+
public static SqlReturnTypeInference getReturnTypeInferenceForArray() {
8080
return opBinding -> {
8181
RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
8282

core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ public enum BuiltinFunctionName {
195195
TAKE(FunctionName.of("take")),
196196
// t-digest percentile which is used in OpenSearch core by default.
197197
PERCENTILE_APPROX(FunctionName.of("percentile_approx")),
198+
DISTINCT_COUNT_APPROX(FunctionName.of("distinct_count_approx")),
198199
// Not always an aggregation query
199200
NESTED(FunctionName.of("nested")),
200201

@@ -252,6 +253,7 @@ public enum BuiltinFunctionName {
252253
IS_BLANK(FunctionName.of("isblank")),
253254

254255
ROW_NUMBER(FunctionName.of("row_number")),
256+
NTH_VALUE(FunctionName.of("nth_value")),
255257
RANK(FunctionName.of("rank")),
256258
DENSE_RANK(FunctionName.of("dense_rank")),
257259

@@ -336,6 +338,7 @@ public enum BuiltinFunctionName {
336338
.put("take", BuiltinFunctionName.TAKE)
337339
.put("percentile", BuiltinFunctionName.PERCENTILE_APPROX)
338340
.put("percentile_approx", BuiltinFunctionName.PERCENTILE_APPROX)
341+
.put("distinct_count_approx", BuiltinFunctionName.DISTINCT_COUNT_APPROX)
339342
.build();
340343

341344
private static final Map<String, BuiltinFunctionName> WINDOW_FUNC_MAPPING =

0 commit comments

Comments
 (0)