Skip to content

Commit 933ec4e

Browse files
committed
Support top, rare commands with Calcite
Signed-off-by: Lantao Jin <ltjin@amazon.com>
1 parent 94fb171 commit 933ec4e

24 files changed

Lines changed: 628 additions & 66 deletions

File tree

core/src/main/java/org/opensearch/sql/analysis/Analyzer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ public LogicalPlan visitRareTopN(RareTopN node, AnalysisContext context) {
382382
fields.forEach(
383383
field -> newEnv.define(new Symbol(Namespace.FIELD_NAME, field.toString()), field.type()));
384384

385-
List<Argument> options = node.getNoOfResults();
385+
List<Argument> options = node.getArguments();
386386
Integer noOfResults = (Integer) options.get(0).getValue().getValue();
387387

388388
return new LogicalRareTopN(child, node.getCommandType(), noOfResults, fields, groupBys);

core/src/main/java/org/opensearch/sql/ast/expression/Argument.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import java.util.Arrays;
99
import java.util.List;
10+
import java.util.Map;
1011
import lombok.EqualsAndHashCode;
1112
import lombok.Getter;
1213
import lombok.RequiredArgsConstructor;
@@ -32,4 +33,22 @@ public List<UnresolvedExpression> getChild() {
3233
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
3334
return nodeVisitor.visitArgument(this, context);
3435
}
36+
37+
public static class ArgumentMap {
38+
private final Map<String, Literal> map;
39+
40+
public ArgumentMap(List<Argument> arguments) {
41+
this.map =
42+
arguments.stream()
43+
.collect(java.util.stream.Collectors.toMap(Argument::getArgName, Argument::getValue));
44+
}
45+
46+
public static ArgumentMap of(List<Argument> arguments) {
47+
return new ArgumentMap(arguments);
48+
}
49+
50+
public Literal get(String name) {
51+
return map.get(name);
52+
}
53+
}
3554
}

core/src/main/java/org/opensearch/sql/ast/expression/WindowFrame.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import lombok.Getter;
1111
import lombok.RequiredArgsConstructor;
1212
import lombok.ToString;
13+
import org.opensearch.sql.ast.dsl.AstDSL;
1314

1415
@EqualsAndHashCode(callSuper = false)
1516
@Getter
@@ -25,12 +26,25 @@ public enum FrameType {
2526
ROWS
2627
}
2728

28-
public static WindowFrame defaultFrame() {
29-
return new WindowFrame(
30-
FrameType.ROWS, createBound("UNBOUNDED PRECEDING"), createBound("UNBOUNDED FOLLOWING"));
29+
public static WindowFrame rowsUnbounded() {
30+
return WindowFrame.of(
31+
FrameType.ROWS,
32+
AstDSL.stringLiteral("UNBOUNDED PRECEDING"),
33+
AstDSL.stringLiteral("UNBOUNDED FOLLOWING"));
3134
}
3235

33-
public static WindowFrame create(FrameType type, Literal lower, Literal upper) {
36+
public static WindowFrame toCurrentRow() {
37+
return WindowFrame.of(
38+
FrameType.ROWS,
39+
AstDSL.stringLiteral("UNBOUNDED PRECEDING"),
40+
AstDSL.stringLiteral("CURRENT ROW"));
41+
}
42+
43+
public static WindowFrame of(FrameType type, String lower, String upper) {
44+
return WindowFrame.of(type, AstDSL.stringLiteral(lower), AstDSL.stringLiteral(upper));
45+
}
46+
47+
public static WindowFrame of(FrameType type, Literal lower, Literal upper) {
3448
WindowBound lowerBound = null;
3549
WindowBound upperBound = null;
3650
if (lower != null) {

core/src/main/java/org/opensearch/sql/ast/expression/WindowFunction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public class WindowFunction extends UnresolvedExpression {
2626
private final UnresolvedExpression function;
2727
@Setter private List<UnresolvedExpression> partitionByList = new ArrayList<>();
2828
@Setter private List<Pair<SortOption, UnresolvedExpression>> sortList = new ArrayList<>();
29-
@Setter private WindowFrame windowFrame = WindowFrame.defaultFrame();
29+
@Setter private WindowFrame windowFrame = WindowFrame.rowsUnbounded();
3030

3131
public WindowFunction(
3232
UnresolvedExpression function,

core/src/main/java/org/opensearch/sql/ast/tree/RareTopN.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ public class RareTopN extends UnresolvedPlan {
2929

3030
private UnresolvedPlan child;
3131
private final CommandType commandType;
32-
private final List<Argument> noOfResults;
32+
// arguments: noOfResults: Integer, countField: String, showCount: Boolean
33+
private final List<Argument> arguments;
3334
private final List<Field> fields;
3435
private final List<UnresolvedExpression> groupExprList;
3536

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC;
1212
import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
1313
import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;
14+
import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_NAME;
1415

1516
import com.google.common.base.Strings;
1617
import com.google.common.collect.ImmutableList;
@@ -44,11 +45,13 @@
4445
import org.opensearch.sql.ast.expression.AllFields;
4546
import org.opensearch.sql.ast.expression.AllFieldsExcludeMeta;
4647
import org.opensearch.sql.ast.expression.Argument;
48+
import org.opensearch.sql.ast.expression.Argument.ArgumentMap;
4749
import org.opensearch.sql.ast.expression.Field;
4850
import org.opensearch.sql.ast.expression.Let;
4951
import org.opensearch.sql.ast.expression.Literal;
5052
import org.opensearch.sql.ast.expression.ParseMethod;
5153
import org.opensearch.sql.ast.expression.UnresolvedExpression;
54+
import org.opensearch.sql.ast.expression.WindowFrame;
5255
import org.opensearch.sql.ast.expression.subquery.SubqueryExpression;
5356
import org.opensearch.sql.ast.tree.AD;
5457
import org.opensearch.sql.ast.tree.Aggregation;
@@ -82,6 +85,7 @@
8285
import org.opensearch.sql.calcite.utils.PlanUtils;
8386
import org.opensearch.sql.exception.CalciteUnsupportedException;
8487
import org.opensearch.sql.exception.SemanticCheckException;
88+
import org.opensearch.sql.expression.function.BuiltinFunctionName;
8589
import org.opensearch.sql.expression.function.PPLFuncImpTable;
8690
import org.opensearch.sql.utils.ParseUtils;
8791

@@ -708,9 +712,77 @@ public RelNode visitFillNull(FillNull fillNull, CalcitePlanContext context) {
708712
throw new CalciteUnsupportedException("FillNull command is unsupported in Calcite");
709713
}
710714

715+
List<RexNode> resolveGroupByPlusFieldList(RareTopN node, CalcitePlanContext context) {
716+
List<RexNode> groupByList = resolveGroupByList(node, context);
717+
List<RexNode> filedsList =
718+
node.getFields().stream().map(g -> rexVisitor.analyze(g, context)).toList();
719+
List<RexNode> all = new ArrayList<>(groupByList);
720+
all.addAll(filedsList);
721+
return all;
722+
}
723+
724+
List<RexNode> resolveGroupByList(RareTopN node, CalcitePlanContext context) {
725+
return node.getGroupExprList().stream().map(g -> rexVisitor.analyze(g, context)).toList();
726+
}
727+
711728
@Override
712729
public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) {
713-
throw new CalciteUnsupportedException("Rare and Top commands are unsupported in Calcite");
730+
visitChildren(node, context);
731+
732+
// 1. before aggregating, add a trim project
733+
List<RexInputRef> trimmedRefs =
734+
PlanUtils.getInputRefs(resolveGroupByPlusFieldList(node, context));
735+
context.relBuilder.project(trimmedRefs);
736+
737+
ArgumentMap arguments = ArgumentMap.of(node.getArguments());
738+
// 2. group the group-by list + field list and add a count() aggregation
739+
List<RexNode> firstGroupBy = resolveGroupByPlusFieldList(node, context);
740+
String countFieldName = (String) arguments.get("countField").getValue();
741+
if (context.relBuilder.peek().getRowType().getFieldNames().contains(countFieldName)) {
742+
throw new IllegalArgumentException(
743+
"Field `"
744+
+ countFieldName
745+
+ "` is existed, change the count field by setting countfield='xyz'");
746+
}
747+
AggCall aggCall =
748+
PlanUtils.makeAggCall(context, BuiltinFunctionName.COUNT, false, null, List.of())
749+
.as(countFieldName);
750+
context.relBuilder.aggregate(context.relBuilder.groupKey(firstGroupBy), aggCall);
751+
752+
// 3. add a window column
753+
List<RexNode> partitionKeys = new ArrayList<>(resolveGroupByList(node, context));
754+
RexNode countField;
755+
if (node.getCommandType() == RareTopN.CommandType.TOP) {
756+
countField = context.relBuilder.desc(context.relBuilder.field(countFieldName));
757+
} else {
758+
countField = context.relBuilder.field(countFieldName);
759+
}
760+
RexNode rowNumberWindowOver =
761+
PlanUtils.makeOver(
762+
context,
763+
BuiltinFunctionName.ROW_NUMBER,
764+
null,
765+
List.of(),
766+
partitionKeys,
767+
List.of(countField),
768+
WindowFrame.toCurrentRow());
769+
context.relBuilder.projectPlus(
770+
context.relBuilder.alias(rowNumberWindowOver, ROW_NUMBER_COLUMN_NAME));
771+
772+
// 4. filter row_number() <= k in each partition
773+
Integer N = (Integer) arguments.get("noOfResults").getValue();
774+
context.relBuilder.filter(
775+
context.relBuilder.lessThanOrEqual(
776+
context.relBuilder.field(ROW_NUMBER_COLUMN_NAME), context.relBuilder.literal(N)));
777+
778+
// 5. project final output. the default output is group by list + field list
779+
List<RexNode> finalProjectList = new ArrayList<>(resolveGroupByPlusFieldList(node, context));
780+
Boolean showCount = (Boolean) arguments.get("showCount").getValue();
781+
if (showCount) {
782+
finalProjectList.add(context.relBuilder.field(countFieldName));
783+
}
784+
context.relBuilder.project(finalProjectList);
785+
return context.relBuilder.peek();
714786
}
715787

716788
@Override

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ public RexNode visitWindowFunction(WindowFunction node, CalcitePlanContext conte
365365
? Collections.emptyList()
366366
: arguments.subList(1, arguments.size());
367367
return PlanUtils.makeOver(
368-
context, functionName, field, args, partitions, node.getWindowFrame());
368+
context, functionName, field, args, partitions, List.of(), node.getWindowFrame());
369369
})
370370
.orElseThrow(
371371
() ->

core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939

4040
public interface PlanUtils {
4141

42+
String ROW_NUMBER_COLUMN_NAME = "_row_number_";
43+
4244
static SpanUnit intervalUnitToSpanUnit(IntervalUnit unit) {
4345
return switch (unit) {
4446
case MICROSECOND -> SpanUnit.MILLISECOND;
@@ -61,9 +63,10 @@ static RexNode makeOver(
6163
RexNode field,
6264
List<RexNode> argList,
6365
List<RexNode> partitions,
66+
List<RexNode> orderKeys,
6467
@Nullable WindowFrame windowFrame) {
6568
if (windowFrame == null) {
66-
windowFrame = WindowFrame.defaultFrame();
69+
windowFrame = WindowFrame.rowsUnbounded();
6770
}
6871
boolean rows = windowFrame.getType() == WindowFrame.FrameType.ROWS;
6972
RexWindowBound lowerBound = convert(context, windowFrame.getLower());
@@ -99,6 +102,14 @@ static RexNode makeOver(
99102
return variance(context, field, partitions, rows, lowerBound, upperBound, true, false);
100103
case VARSAMP:
101104
return variance(context, field, partitions, rows, lowerBound, upperBound, false, false);
105+
case ROW_NUMBER:
106+
return withOver(
107+
context.relBuilder.aggregateCall(SqlStdOperatorTable.ROW_NUMBER),
108+
partitions,
109+
orderKeys,
110+
true,
111+
lowerBound,
112+
upperBound);
102113
default:
103114
return withOver(
104115
makeAggCall(context, functionName, false, field, argList),
@@ -151,6 +162,25 @@ private static RexNode withOver(
151162
.toRex();
152163
}
153164

165+
private static RexNode withOver(
166+
RelBuilder.AggCall aggCall,
167+
List<RexNode> partitions,
168+
List<RexNode> orderKeys,
169+
boolean rows,
170+
RexWindowBound lowerBound,
171+
RexWindowBound upperBound) {
172+
return aggCall
173+
.over()
174+
.partitionBy(partitions)
175+
.orderBy(orderKeys)
176+
.let(
177+
c ->
178+
rows
179+
? c.rowsBetween(lowerBound, upperBound)
180+
: c.rangeBetween(lowerBound, upperBound))
181+
.toRex();
182+
}
183+
154184
private static RexNode variance(
155185
CalcitePlanContext ctx,
156186
RexNode operator,

core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ public enum BuiltinFunctionName {
330330
.put("stddev", BuiltinFunctionName.STDDEV_POP)
331331
.put("stddev_pop", BuiltinFunctionName.STDDEV_POP)
332332
.put("stddev_samp", BuiltinFunctionName.STDDEV_SAMP)
333+
.put("rank", BuiltinFunctionName.RANK)
333334
.build();
334335

335336
public static Optional<BuiltinFunctionName> of(String str) {

docs/user/ppl/cmd/rare.rst

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,13 @@ Syntax
1919
============
2020
rare <field-list> [by-clause]
2121

22+
rare <field-list> [rare-options] [by-clause] ``(available from 3.1.0+)``
23+
2224
* field-list: mandatory. comma-delimited list of field names.
2325
* by-clause: optional. one or more fields to group the results by.
26+
* rare-options: optional. options for the rare command. Supported syntax is [countfield=<string>] [showcount=<bool>].
27+
* showcount=<bool>: optional. whether to create a field in output that represents a count of the tuple of values. Default value is ``true``.
28+
* countfield=<string>: optional. the name of the field that contains the count. Default value is ``'count'``.
2429

2530

2631
Example 1: Find the least common values in a field
@@ -58,6 +63,39 @@ PPL query::
5863
| M | 36 |
5964
+--------+-----+
6065

66+
Example 3: Rare command with Calcite enabled
67+
============================================
68+
69+
The example finds the least common gender of all the accounts when ``plugins.calcite.enabled`` is true.
70+
71+
PPL query::
72+
73+
PPL> source=accounts | rare gender;
74+
fetched row
75+
+--------+-------+
76+
| gender | count |
77+
|--------+-------|
78+
| F | 1 |
79+
| M | 3 |
80+
+--------+-------+
81+
82+
83+
Example 4: Specify the count field option
84+
=========================================
85+
86+
The example specifies the count field when ``plugins.calcite.enabled`` is true.
87+
88+
PPL query::
89+
90+
PPL> source=accounts | rare countfield='cnt' gender;
91+
fetched row
92+
+--------+-----+
93+
| gender | cnt |
94+
|--------+-----|
95+
| F | 1 |
96+
| M | 3 |
97+
+--------+-----+
98+
6199
Limitation
62100
==========
63101
The ``rare`` command is not rewritten to OpenSearch DSL; it is only executed on the coordinating node.

0 commit comments

Comments
 (0)