|
59 | 59 | import org.opensearch.sql.ast.expression.ParseMethod; |
60 | 60 | import org.opensearch.sql.ast.expression.UnresolvedExpression; |
61 | 61 | import org.opensearch.sql.ast.expression.WindowFrame; |
| 62 | +import org.opensearch.sql.ast.expression.WindowFrame.FrameType; |
62 | 63 | import org.opensearch.sql.ast.expression.subquery.SubqueryExpression; |
63 | 64 | import org.opensearch.sql.ast.tree.AD; |
64 | 65 | import org.opensearch.sql.ast.tree.Aggregation; |
|
86 | 87 | import org.opensearch.sql.ast.tree.SubqueryAlias; |
87 | 88 | import org.opensearch.sql.ast.tree.TableFunction; |
88 | 89 | import org.opensearch.sql.ast.tree.Trendline; |
| 90 | +import org.opensearch.sql.ast.tree.Trendline.TrendlineType; |
89 | 91 | import org.opensearch.sql.ast.tree.UnresolvedPlan; |
90 | 92 | import org.opensearch.sql.ast.tree.Window; |
91 | 93 | import org.opensearch.sql.calcite.plan.OpenSearchConstants; |
92 | 94 | import org.opensearch.sql.calcite.utils.JoinAndLookupUtils; |
93 | 95 | import org.opensearch.sql.calcite.utils.PlanUtils; |
| 96 | +import org.opensearch.sql.common.utils.StringUtils; |
94 | 97 | import org.opensearch.sql.exception.CalciteUnsupportedException; |
95 | 98 | import org.opensearch.sql.exception.SemanticCheckException; |
96 | 99 | import org.opensearch.sql.expression.function.BuiltinFunctionName; |
@@ -975,6 +978,126 @@ public RelNode visitTableFunction(TableFunction node, CalcitePlanContext context |
975 | 978 |
|
976 | 979 | @Override |
977 | 980 | public RelNode visitTrendline(Trendline node, CalcitePlanContext context) { |
978 | | - throw new CalciteUnsupportedException("Trendline command is unsupported in Calcite"); |
| 981 | + visitChildren(node, context); |
| 982 | + |
| 983 | + node.getSortByField() |
| 984 | + .ifPresent( |
| 985 | + sortField -> { |
| 986 | + SortOption sortOption = analyzeSortOption(sortField.getFieldArgs()); |
| 987 | + RexNode field = rexVisitor.analyze(sortField, context); |
| 988 | + if (sortOption == DEFAULT_DESC) { |
| 989 | + context.relBuilder.sort(context.relBuilder.desc(field)); |
| 990 | + } else { |
| 991 | + context.relBuilder.sort(field); |
| 992 | + } |
| 993 | + }); |
| 994 | + |
| 995 | + List<RexNode> trendlineNodes = new ArrayList<>(); |
| 996 | + List<String> aliases = new ArrayList<>(); |
| 997 | + node.getComputations() |
| 998 | + .forEach( |
| 999 | + trendlineComputation -> { |
| 1000 | + RexNode field = rexVisitor.analyze(trendlineComputation.getDataField(), context); |
| 1001 | + context.relBuilder.filter(context.relBuilder.isNotNull(field)); |
| 1002 | + |
| 1003 | + WindowFrame windowFrame = |
| 1004 | + WindowFrame.of( |
| 1005 | + FrameType.ROWS, |
| 1006 | + StringUtils.format( |
| 1007 | + "%d PRECEDING", trendlineComputation.getNumberOfDataPoints() - 1), |
| 1008 | + "CURRENT ROW"); |
| 1009 | + RexNode countExpr = |
| 1010 | + PlanUtils.makeOver( |
| 1011 | + context, |
| 1012 | + BuiltinFunctionName.COUNT, |
| 1013 | + null, |
| 1014 | + List.of(), |
| 1015 | + List.of(), |
| 1016 | + List.of(), |
| 1017 | + windowFrame); |
| 1018 | + // CASE WHEN count() over (ROWS (windowSize-1) PRECEDING) > windowSize - 1 |
| 1019 | + RexNode whenConditionExpr = |
| 1020 | + PPLFuncImpTable.INSTANCE.resolve( |
| 1021 | + context.rexBuilder, |
| 1022 | + ">", |
| 1023 | + countExpr, |
| 1024 | + context.relBuilder.literal(trendlineComputation.getNumberOfDataPoints() - 1)); |
| 1025 | + |
| 1026 | + RexNode thenExpr; |
| 1027 | + switch (trendlineComputation.getComputationType()) { |
| 1028 | + case TrendlineType.SMA: |
| 1029 | + // THEN avg(field) over (ROWS (windowSize-1) PRECEDING) |
| 1030 | + thenExpr = |
| 1031 | + PlanUtils.makeOver( |
| 1032 | + context, |
| 1033 | + BuiltinFunctionName.AVG, |
| 1034 | + field, |
| 1035 | + List.of(), |
| 1036 | + List.of(), |
| 1037 | + List.of(), |
| 1038 | + windowFrame); |
| 1039 | + break; |
| 1040 | + case TrendlineType.WMA: |
| 1041 | + // THEN wma expression |
| 1042 | + thenExpr = |
| 1043 | + buildWmaRexNode( |
| 1044 | + field, |
| 1045 | + trendlineComputation.getNumberOfDataPoints(), |
| 1046 | + windowFrame, |
| 1047 | + context); |
| 1048 | + break; |
| 1049 | + default: |
| 1050 | + throw new IllegalStateException("Unsupported trendline type"); |
| 1051 | + } |
| 1052 | + |
| 1053 | + // ELSE NULL |
| 1054 | + RexNode elseExpr = context.relBuilder.literal(null); |
| 1055 | + |
| 1056 | + List<RexNode> caseOperands = new ArrayList<>(); |
| 1057 | + caseOperands.add(whenConditionExpr); |
| 1058 | + caseOperands.add(thenExpr); |
| 1059 | + caseOperands.add(elseExpr); |
| 1060 | + RexNode trendlineNode = |
| 1061 | + context.rexBuilder.makeCall(SqlStdOperatorTable.CASE, caseOperands); |
| 1062 | + trendlineNodes.add(trendlineNode); |
| 1063 | + aliases.add(trendlineComputation.getAlias()); |
| 1064 | + }); |
| 1065 | + |
| 1066 | + projectPlusOverriding(trendlineNodes, aliases, context); |
| 1067 | + return context.relBuilder.peek(); |
| 1068 | + } |
| 1069 | + |
| 1070 | + private RexNode buildWmaRexNode( |
| 1071 | + RexNode field, |
| 1072 | + Integer numberOfDataPoints, |
| 1073 | + WindowFrame windowFrame, |
| 1074 | + CalcitePlanContext context) { |
| 1075 | + |
| 1076 | + // Divisor: 1 + 2 + 3 + ... + windowSize, aka (windowSize * (windowSize + 1) / 2) |
| 1077 | + RexNode divisor = context.relBuilder.literal(numberOfDataPoints * (numberOfDataPoints + 1) / 2); |
| 1078 | + |
| 1079 | + // Divider: 1 * NTH_VALUE(field, 1) + 2 * NTH_VALUE(field, 2) + ... + windowSize * |
| 1080 | + // NTH_VALUE(field, windowSize) |
| 1081 | + RexNode divider = context.relBuilder.literal(0); |
| 1082 | + for (int i = 1; i <= numberOfDataPoints; i++) { |
| 1083 | + RexNode nthValueExpr = |
| 1084 | + PlanUtils.makeOver( |
| 1085 | + context, |
| 1086 | + BuiltinFunctionName.NTH_VALUE, |
| 1087 | + field, |
| 1088 | + List.of(context.relBuilder.literal(i)), |
| 1089 | + List.of(), |
| 1090 | + List.of(), |
| 1091 | + windowFrame); |
| 1092 | + divider = |
| 1093 | + context.relBuilder.call( |
| 1094 | + SqlStdOperatorTable.PLUS, |
| 1095 | + divider, |
| 1096 | + context.relBuilder.call( |
| 1097 | + SqlStdOperatorTable.MULTIPLY, nthValueExpr, context.relBuilder.literal(i))); |
| 1098 | + } |
| 1099 | + // Divider / CAST(Divisor, DOUBLE) |
| 1100 | + return context.relBuilder.call( |
| 1101 | + SqlStdOperatorTable.DIVIDE, divider, context.relBuilder.cast(divisor, SqlTypeName.DOUBLE)); |
979 | 1102 | } |
980 | 1103 | } |
0 commit comments