|
11 | 11 | import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC; |
12 | 12 | import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC; |
13 | 13 | import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC; |
| 14 | +import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_NAME; |
14 | 15 |
|
15 | 16 | import com.google.common.base.Strings; |
16 | 17 | import com.google.common.collect.ImmutableList; |
|
44 | 45 | import org.opensearch.sql.ast.expression.AllFields; |
45 | 46 | import org.opensearch.sql.ast.expression.AllFieldsExcludeMeta; |
46 | 47 | import org.opensearch.sql.ast.expression.Argument; |
| 48 | +import org.opensearch.sql.ast.expression.Argument.ArgumentMap; |
47 | 49 | import org.opensearch.sql.ast.expression.Field; |
48 | 50 | import org.opensearch.sql.ast.expression.Let; |
49 | 51 | import org.opensearch.sql.ast.expression.Literal; |
50 | 52 | import org.opensearch.sql.ast.expression.ParseMethod; |
51 | 53 | import org.opensearch.sql.ast.expression.UnresolvedExpression; |
| 54 | +import org.opensearch.sql.ast.expression.WindowFrame; |
52 | 55 | import org.opensearch.sql.ast.expression.subquery.SubqueryExpression; |
53 | 56 | import org.opensearch.sql.ast.tree.AD; |
54 | 57 | import org.opensearch.sql.ast.tree.Aggregation; |
|
82 | 85 | import org.opensearch.sql.calcite.utils.PlanUtils; |
83 | 86 | import org.opensearch.sql.exception.CalciteUnsupportedException; |
84 | 87 | import org.opensearch.sql.exception.SemanticCheckException; |
| 88 | +import org.opensearch.sql.expression.function.BuiltinFunctionName; |
85 | 89 | import org.opensearch.sql.expression.function.PPLFuncImpTable; |
86 | 90 | import org.opensearch.sql.utils.ParseUtils; |
87 | 91 |
|
@@ -708,9 +712,77 @@ public RelNode visitFillNull(FillNull fillNull, CalcitePlanContext context) { |
708 | 712 | throw new CalciteUnsupportedException("FillNull command is unsupported in Calcite"); |
709 | 713 | } |
710 | 714 |
|
| 715 | + List<RexNode> resolveGroupByPlusFieldList(RareTopN node, CalcitePlanContext context) { |
| 716 | + List<RexNode> groupByList = resolveGroupByList(node, context); |
| 717 | + List<RexNode> filedsList = |
| 718 | + node.getFields().stream().map(g -> rexVisitor.analyze(g, context)).toList(); |
| 719 | + List<RexNode> all = new ArrayList<>(groupByList); |
| 720 | + all.addAll(filedsList); |
| 721 | + return all; |
| 722 | + } |
| 723 | + |
| 724 | + List<RexNode> resolveGroupByList(RareTopN node, CalcitePlanContext context) { |
| 725 | + return node.getGroupExprList().stream().map(g -> rexVisitor.analyze(g, context)).toList(); |
| 726 | + } |
| 727 | + |
711 | 728 | @Override |
712 | 729 | public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) { |
713 | | - throw new CalciteUnsupportedException("Rare and Top commands are unsupported in Calcite"); |
| 730 | + visitChildren(node, context); |
| 731 | + |
| 732 | + // 1. before aggregating, add a trim project |
| 733 | + List<RexInputRef> trimmedRefs = |
| 734 | + PlanUtils.getInputRefs(resolveGroupByPlusFieldList(node, context)); |
| 735 | + context.relBuilder.project(trimmedRefs); |
| 736 | + |
| 737 | + ArgumentMap arguments = ArgumentMap.of(node.getArguments()); |
| 738 | + // 2. group the group-by list + field list and add a count() aggregation |
| 739 | + List<RexNode> firstGroupBy = resolveGroupByPlusFieldList(node, context); |
| 740 | + String countFieldName = (String) arguments.get("countField").getValue(); |
| 741 | + if (context.relBuilder.peek().getRowType().getFieldNames().contains(countFieldName)) { |
| 742 | + throw new IllegalArgumentException( |
| 743 | + "Field `" |
| 744 | + + countFieldName |
| 745 | + + "` is existed, change the count field by setting countfield='xyz'"); |
| 746 | + } |
| 747 | + AggCall aggCall = |
| 748 | + PlanUtils.makeAggCall(context, BuiltinFunctionName.COUNT, false, null, List.of()) |
| 749 | + .as(countFieldName); |
| 750 | + context.relBuilder.aggregate(context.relBuilder.groupKey(firstGroupBy), aggCall); |
| 751 | + |
| 752 | + // 3. add a window column |
| 753 | + List<RexNode> partitionKeys = new ArrayList<>(resolveGroupByList(node, context)); |
| 754 | + RexNode countField; |
| 755 | + if (node.getCommandType() == RareTopN.CommandType.TOP) { |
| 756 | + countField = context.relBuilder.desc(context.relBuilder.field(countFieldName)); |
| 757 | + } else { |
| 758 | + countField = context.relBuilder.field(countFieldName); |
| 759 | + } |
| 760 | + RexNode rowNumberWindowOver = |
| 761 | + PlanUtils.makeOver( |
| 762 | + context, |
| 763 | + BuiltinFunctionName.ROW_NUMBER, |
| 764 | + null, |
| 765 | + List.of(), |
| 766 | + partitionKeys, |
| 767 | + List.of(countField), |
| 768 | + WindowFrame.toCurrentRow()); |
| 769 | + context.relBuilder.projectPlus( |
| 770 | + context.relBuilder.alias(rowNumberWindowOver, ROW_NUMBER_COLUMN_NAME)); |
| 771 | + |
| 772 | + // 4. filter row_number() <= k in each partition |
| 773 | + Integer N = (Integer) arguments.get("noOfResults").getValue(); |
| 774 | + context.relBuilder.filter( |
| 775 | + context.relBuilder.lessThanOrEqual( |
| 776 | + context.relBuilder.field(ROW_NUMBER_COLUMN_NAME), context.relBuilder.literal(N))); |
| 777 | + |
| 778 | + // 5. project final output. the default output is group by list + field list |
| 779 | + List<RexNode> finalProjectList = new ArrayList<>(resolveGroupByPlusFieldList(node, context)); |
| 780 | + Boolean showCount = (Boolean) arguments.get("showCount").getValue(); |
| 781 | + if (showCount) { |
| 782 | + finalProjectList.add(context.relBuilder.field(countFieldName)); |
| 783 | + } |
| 784 | + context.relBuilder.project(finalProjectList); |
| 785 | + return context.relBuilder.peek(); |
714 | 786 | } |
715 | 787 |
|
716 | 788 | @Override |
|
0 commit comments