|
98 | 98 | import org.opensearch.sql.ast.tree.Append; |
99 | 99 | import org.opensearch.sql.ast.tree.AppendCol; |
100 | 100 | import org.opensearch.sql.ast.tree.Bin; |
| 101 | +import org.opensearch.sql.ast.tree.Chart; |
101 | 102 | import org.opensearch.sql.ast.tree.CloseCursor; |
102 | 103 | import org.opensearch.sql.ast.tree.Dedupe; |
103 | 104 | import org.opensearch.sql.ast.tree.Eval; |
@@ -1023,6 +1024,11 @@ private Pair<List<RexNode>, List<AggCall>> resolveAttributesForAggregation( |
1023 | 1024 |
|
1024 | 1025 | @Override |
1025 | 1026 | public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) { |
| 1027 | + return visitAggregationAndReturnProjection(node, context).getLeft(); |
| 1028 | + } |
| 1029 | + |
| 1030 | + private Pair<RelNode, List<RexNode>> visitAggregationAndReturnProjection( |
| 1031 | + Aggregation node, CalcitePlanContext context) { |
1026 | 1032 | visitChildren(node, context); |
1027 | 1033 |
|
1028 | 1034 | List<UnresolvedExpression> aggExprList = node.getAggExprList(); |
@@ -1100,14 +1106,14 @@ public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) { |
1100 | 1106 | aggregationAttributes.getLeft().stream() |
1101 | 1107 | .map(this::extractAliasLiteral) |
1102 | 1108 | .flatMap(Optional::stream) |
1103 | | - .map(ref -> ((RexLiteral) ref).getValueAs(String.class)) |
| 1109 | + .map(ref -> ref.getValueAs(String.class)) |
1104 | 1110 | .map(context.relBuilder::field) |
1105 | 1111 | .map(f -> (RexNode) f) |
1106 | 1112 | .toList(); |
1107 | 1113 | reordered.addAll(aliasedGroupByList); |
1108 | 1114 | context.relBuilder.project(reordered); |
1109 | 1115 |
|
1110 | | - return context.relBuilder.peek(); |
| 1116 | + return Pair.of(context.relBuilder.peek(), reordered); |
1111 | 1117 | } |
1112 | 1118 |
|
1113 | 1119 | private Optional<UnresolvedExpression> getTimeSpanField(UnresolvedExpression expr) { |
@@ -1947,6 +1953,90 @@ private String getValueFunctionName(UnresolvedExpression aggregateFunction) { |
1947 | 1953 | return sb.toString(); |
1948 | 1954 | } |
1949 | 1955 |
|
| 1956 | + @Override |
| 1957 | + public RelNode visitChart(Chart node, CalcitePlanContext context) { |
| 1958 | + visitChildren(node, context); |
| 1959 | + ArgumentMap argMap = ArgumentMap.of(node.getArguments()); |
| 1960 | + List<UnresolvedExpression> groupExprList = new ArrayList<>(); |
| 1961 | + UnresolvedExpression span; |
| 1962 | + if (node.getColumnSplit() instanceof Span && node.getRowSplit() instanceof Span) { |
| 1963 | + throw new UnsupportedOperationException("It is not supported to have two span splits"); |
| 1964 | + } else if (node.getRowSplit() instanceof Span) { |
| 1965 | + if (node.getColumnSplit() != null) { |
| 1966 | + groupExprList.add(node.getColumnSplit()); |
| 1967 | + } |
| 1968 | + span = node.getRowSplit(); |
| 1969 | + } else if (node.getColumnSplit() instanceof Span) { |
| 1970 | + if (node.getRowSplit() != null) { |
| 1971 | + groupExprList.add(node.getRowSplit()); |
| 1972 | + } |
| 1973 | + span = node.getColumnSplit(); |
| 1974 | + } else { |
| 1975 | + groupExprList.addAll( |
| 1976 | + Stream.of(node.getRowSplit(), node.getColumnSplit()).filter(Objects::nonNull).toList()); |
| 1977 | + span = null; |
| 1978 | + } |
| 1979 | + Aggregation aggregation = |
| 1980 | + new Aggregation(node.getAggregationFunctions(), List.of(), groupExprList, span, List.of()); |
| 1981 | + Pair<RelNode, List<RexNode>> aggregated = |
| 1982 | + visitAggregationAndReturnProjection(aggregation, context); |
| 1983 | + // If row or column split does not present or limit equals 0, this is the same as `stats agg |
| 1984 | + // [group by col]` |
| 1985 | + |
| 1986 | + Integer limit = |
| 1987 | + Optional.ofNullable(argMap.get("limit")).map(l -> (Integer) l.getValue()).orElse(10); |
| 1988 | + Boolean top = |
| 1989 | + Optional.ofNullable(argMap.get("top")).map(t -> (Boolean) t.getValue()).orElse(true); |
| 1990 | + if (node.getRowSplit() == null || node.getColumnSplit() == null || Objects.equals(limit, 0)) { |
| 1991 | + return aggregated.getLeft(); |
| 1992 | + } |
| 1993 | + List<RexNode> projected = aggregated.getRight(); |
| 1994 | + String columSplitName = aggregated.getLeft().getRowType().getFieldNames().getLast(); |
| 1995 | + RelBuilder relBuilder = context.relBuilder; |
| 1996 | + // 0: agg; 2: column-split |
| 1997 | + relBuilder.project(relBuilder.field(0), relBuilder.field(2)); |
| 1998 | + relBuilder.filter(relBuilder.isNotNull(relBuilder.field(1))); |
| 1999 | + // 1: column split; 0: agg |
| 2000 | + relBuilder.aggregate( |
| 2001 | + relBuilder.groupKey(relBuilder.field(1)), |
| 2002 | + relBuilder.sum(relBuilder.field(0)).as("__grand_total__")); // results: group key, agg calls |
| 2003 | + RexNode grandTotal = relBuilder.field("__grand_total__"); |
| 2004 | + if (top) { |
| 2005 | + grandTotal = relBuilder.desc(grandTotal); |
| 2006 | + } |
| 2007 | + RexNode rowNum = |
| 2008 | + PlanUtils.makeOver( |
| 2009 | + context, |
| 2010 | + BuiltinFunctionName.ROW_NUMBER, |
| 2011 | + relBuilder.literal(1), |
| 2012 | + List.of(), |
| 2013 | + List.of(), |
| 2014 | + List.of(grandTotal), |
| 2015 | + WindowFrame.toCurrentRow()); |
| 2016 | + relBuilder.projectPlus(relBuilder.alias(rowNum, "__row_number__")); |
| 2017 | + RelNode ranked = relBuilder.build(); |
| 2018 | + |
| 2019 | + relBuilder.push(aggregated.getLeft()); |
| 2020 | + relBuilder.push(ranked); |
| 2021 | + |
| 2022 | + // on column-split = group key |
| 2023 | + relBuilder.join( |
| 2024 | + JoinRelType.INNER, relBuilder.equals(relBuilder.field(2, 0, 2), relBuilder.field(2, 1, 0))); |
| 2025 | + RexNode caseExpr = |
| 2026 | + relBuilder.alias( |
| 2027 | + relBuilder.call( |
| 2028 | + SqlStdOperatorTable.CASE, |
| 2029 | + relBuilder.call( |
| 2030 | + SqlStdOperatorTable.LESS_THAN_OR_EQUAL, |
| 2031 | + relBuilder.field("__row_number__"), |
| 2032 | + relBuilder.literal(limit)), |
| 2033 | + relBuilder.field(2), |
| 2034 | + relBuilder.literal("OTHER")), |
| 2035 | + columSplitName); |
| 2036 | + relBuilder.project(relBuilder.field(0), relBuilder.field(1), caseExpr); |
| 2037 | + return relBuilder.peek(); |
| 2038 | + } |
| 2039 | + |
1950 | 2040 | /** Transforms timechart command into SQL-based operations. */ |
1951 | 2041 | @Override |
1952 | 2042 | public RelNode visitTimechart( |
@@ -2064,7 +2154,6 @@ private RelNode buildTopCategoriesQuery( |
2064 | 2154 | if (limit > 0) { |
2065 | 2155 | context.relBuilder.limit(0, limit); |
2066 | 2156 | } |
2067 | | - |
2068 | 2157 | return context.relBuilder.build(); |
2069 | 2158 | } |
2070 | 2159 |
|
|
0 commit comments