Skip to content

Commit 9cdcafe

Browse files
committed
Merge remote-tracking branch 'upstream/main' into issues/3614
2 parents e588158 + 92cb089 commit 9cdcafe

99 files changed

Lines changed: 1679 additions & 433 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

core/src/main/java/org/opensearch/sql/analysis/Analyzer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ public LogicalPlan visitRareTopN(RareTopN node, AnalysisContext context) {
382382
fields.forEach(
383383
field -> newEnv.define(new Symbol(Namespace.FIELD_NAME, field.toString()), field.type()));
384384

385-
List<Argument> options = node.getNoOfResults();
385+
List<Argument> options = node.getArguments();
386386
Integer noOfResults = (Integer) options.get(0).getValue().getValue();
387387

388388
return new LogicalRareTopN(child, node.getCommandType(), noOfResults, fields, groupBys);

core/src/main/java/org/opensearch/sql/ast/expression/Argument.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import java.util.Arrays;
99
import java.util.List;
10+
import java.util.Map;
1011
import lombok.EqualsAndHashCode;
1112
import lombok.Getter;
1213
import lombok.RequiredArgsConstructor;
@@ -32,4 +33,29 @@ public List<UnresolvedExpression> getChild() {
3233
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
3334
return nodeVisitor.visitArgument(this, context);
3435
}
36+
37+
/** ArgumentMap is a helper class to get argument value by name. */
38+
public static class ArgumentMap {
39+
private final Map<String, Literal> map;
40+
41+
public ArgumentMap(List<Argument> arguments) {
42+
this.map =
43+
arguments.stream()
44+
.collect(java.util.stream.Collectors.toMap(Argument::getArgName, Argument::getValue));
45+
}
46+
47+
public static ArgumentMap of(List<Argument> arguments) {
48+
return new ArgumentMap(arguments);
49+
}
50+
51+
/**
52+
* Get argument value by name.
53+
*
54+
* @param name argument name
55+
* @return argument value
56+
*/
57+
public Literal get(String name) {
58+
return map.get(name);
59+
}
60+
}
3561
}

core/src/main/java/org/opensearch/sql/ast/expression/WindowFrame.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import lombok.Getter;
1111
import lombok.RequiredArgsConstructor;
1212
import lombok.ToString;
13+
import org.opensearch.sql.ast.dsl.AstDSL;
1314

1415
@EqualsAndHashCode(callSuper = false)
1516
@Getter
@@ -25,12 +26,25 @@ public enum FrameType {
2526
ROWS
2627
}
2728

28-
public static WindowFrame defaultFrame() {
29-
return new WindowFrame(
30-
FrameType.ROWS, createBound("UNBOUNDED PRECEDING"), createBound("UNBOUNDED FOLLOWING"));
29+
public static WindowFrame rowsUnbounded() {
30+
return WindowFrame.of(
31+
FrameType.ROWS,
32+
AstDSL.stringLiteral("UNBOUNDED PRECEDING"),
33+
AstDSL.stringLiteral("UNBOUNDED FOLLOWING"));
3134
}
3235

33-
public static WindowFrame create(FrameType type, Literal lower, Literal upper) {
36+
public static WindowFrame toCurrentRow() {
37+
return WindowFrame.of(
38+
FrameType.ROWS,
39+
AstDSL.stringLiteral("UNBOUNDED PRECEDING"),
40+
AstDSL.stringLiteral("CURRENT ROW"));
41+
}
42+
43+
public static WindowFrame of(FrameType type, String lower, String upper) {
44+
return WindowFrame.of(type, AstDSL.stringLiteral(lower), AstDSL.stringLiteral(upper));
45+
}
46+
47+
public static WindowFrame of(FrameType type, Literal lower, Literal upper) {
3448
WindowBound lowerBound = null;
3549
WindowBound upperBound = null;
3650
if (lower != null) {

core/src/main/java/org/opensearch/sql/ast/expression/WindowFunction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public class WindowFunction extends UnresolvedExpression {
2626
private final UnresolvedExpression function;
2727
@Setter private List<UnresolvedExpression> partitionByList = new ArrayList<>();
2828
@Setter private List<Pair<SortOption, UnresolvedExpression>> sortList = new ArrayList<>();
29-
@Setter private WindowFrame windowFrame = WindowFrame.defaultFrame();
29+
@Setter private WindowFrame windowFrame = WindowFrame.rowsUnbounded();
3030

3131
public WindowFunction(
3232
UnresolvedExpression function,

core/src/main/java/org/opensearch/sql/ast/tree/RareTopN.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ public class RareTopN extends UnresolvedPlan {
2929

3030
private UnresolvedPlan child;
3131
private final CommandType commandType;
32-
private final List<Argument> noOfResults;
32+
// arguments: noOfResults: Integer, countField: String, showCount: Boolean
33+
private final List<Argument> arguments;
3334
private final List<Field> fields;
3435
private final List<UnresolvedExpression> groupExprList;
3536

core/src/main/java/org/opensearch/sql/calcite/CalciteAggCallVisitor.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ public AggCall visitAlias(Alias node, CalcitePlanContext context) {
3737

3838
@Override
3939
public AggCall visitAggregateFunction(AggregateFunction node, CalcitePlanContext context) {
40-
RexNode field = rexNodeVisitor.analyze(node.getField(), context);
40+
RexNode field =
41+
node.getField() == null ? null : rexNodeVisitor.analyze(node.getField(), context);
4142
List<RexNode> argList = new ArrayList<>();
4243
for (UnresolvedExpression arg : node.getArgList()) {
4344
argList.add(rexNodeVisitor.analyze(arg, context));

core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ public class CalcitePlanContext {
3131
public final ExtendedRexBuilder rexBuilder;
3232
public final FunctionProperties functionProperties;
3333
public final QueryType queryType;
34+
public final Integer querySizeLimit;
3435

3536
@Getter @Setter private boolean isResolvingJoinCondition = false;
3637
@Getter @Setter private boolean isResolvingSubquery = false;
@@ -46,8 +47,9 @@ public class CalcitePlanContext {
4647
private final Stack<RexCorrelVariable> correlVar = new Stack<>();
4748
private final Stack<List<RexNode>> windowPartitions = new Stack<>();
4849

49-
private CalcitePlanContext(FrameworkConfig config, QueryType queryType) {
50+
private CalcitePlanContext(FrameworkConfig config, Integer querySizeLimit, QueryType queryType) {
5051
this.config = config;
52+
this.querySizeLimit = querySizeLimit;
5153
this.queryType = queryType;
5254
this.connection = CalciteToolsHelper.connect(config, TYPE_FACTORY);
5355
this.relBuilder = CalciteToolsHelper.create(config, TYPE_FACTORY, connection);
@@ -84,7 +86,8 @@ public Optional<RexCorrelVariable> peekCorrelVar() {
8486
}
8587
}
8688

87-
public static CalcitePlanContext create(FrameworkConfig config, QueryType queryType) {
88-
return new CalcitePlanContext(config, queryType);
89+
public static CalcitePlanContext create(
90+
FrameworkConfig config, Integer querySizeLimit, QueryType queryType) {
91+
return new CalcitePlanContext(config, querySizeLimit, queryType);
8992
}
9093
}

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 138 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC;
1212
import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
1313
import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;
14+
import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_NAME;
1415

1516
import com.google.common.base.Strings;
1617
import com.google.common.collect.ImmutableList;
@@ -26,6 +27,7 @@
2627
import org.apache.calcite.plan.RelOptTable;
2728
import org.apache.calcite.plan.ViewExpanders;
2829
import org.apache.calcite.rel.RelNode;
30+
import org.apache.calcite.rel.core.Aggregate;
2931
import org.apache.calcite.rel.type.RelDataTypeField;
3032
import org.apache.calcite.rex.RexCall;
3133
import org.apache.calcite.rex.RexCorrelVariable;
@@ -42,14 +44,17 @@
4244
import org.checkerframework.checker.nullness.qual.Nullable;
4345
import org.opensearch.sql.ast.AbstractNodeVisitor;
4446
import org.opensearch.sql.ast.Node;
47+
import org.opensearch.sql.ast.dsl.AstDSL;
4548
import org.opensearch.sql.ast.expression.AllFields;
4649
import org.opensearch.sql.ast.expression.AllFieldsExcludeMeta;
4750
import org.opensearch.sql.ast.expression.Argument;
51+
import org.opensearch.sql.ast.expression.Argument.ArgumentMap;
4852
import org.opensearch.sql.ast.expression.Field;
4953
import org.opensearch.sql.ast.expression.Let;
5054
import org.opensearch.sql.ast.expression.Literal;
5155
import org.opensearch.sql.ast.expression.ParseMethod;
5256
import org.opensearch.sql.ast.expression.UnresolvedExpression;
57+
import org.opensearch.sql.ast.expression.WindowFrame;
5358
import org.opensearch.sql.ast.expression.subquery.SubqueryExpression;
5459
import org.opensearch.sql.ast.tree.AD;
5560
import org.opensearch.sql.ast.tree.Aggregation;
@@ -83,6 +88,7 @@
8388
import org.opensearch.sql.calcite.utils.PlanUtils;
8489
import org.opensearch.sql.exception.CalciteUnsupportedException;
8590
import org.opensearch.sql.exception.SemanticCheckException;
91+
import org.opensearch.sql.expression.function.BuiltinFunctionName;
8692
import org.opensearch.sql.expression.function.PPLFuncImpTable;
8793
import org.opensearch.sql.utils.ParseUtils;
8894

@@ -374,70 +380,102 @@ private void projectPlusOverriding(
374380
context.relBuilder.rename(expectedRenameFields);
375381
}
376382

377-
private Pair<List<AggCall>, List<RexNode>> resolveAggCallAndGroupBy(
378-
Aggregation node, CalcitePlanContext context) {
383+
/**
384+
* Resolve the aggregation with trimming unused fields to avoid bugs in {@link
385+
* org.apache.calcite.sql2rel.RelDecorrelator#decorrelateRel(Aggregate, boolean)}
386+
*
387+
* @param groupExprList group by expression list
388+
* @param aggExprList aggregate expression list
389+
* @param context CalcitePlanContext
390+
* @return Pair of (group-by list, field list, aggregate list)
391+
*/
392+
private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
393+
List<UnresolvedExpression> groupExprList,
394+
List<UnresolvedExpression> aggExprList,
395+
CalcitePlanContext context) {
396+
// Example 1: source=t | where a > 1 | stats avg(b + 1) by c
397+
// Before: Aggregate(avg(b + 1))
398+
// \- Filter(a > 1)
399+
// \- Scan t
400+
// After: Aggregate(avg(b + 1))
401+
// \- Project([c, b])
402+
// \- Filter(a > 1)
403+
// \- Scan t
404+
//
405+
// Example 2: source=t | where a > 1 | top b by c
406+
// Before: Aggregate(count)
407+
// \-Filter(a > 1)
408+
// \- Scan t
409+
// After: Aggregate(count)
410+
// \- Project([c, b])
411+
// \- Filter(a > 1)
412+
// \- Scan t
413+
Pair<List<RexNode>, List<AggCall>> resolved =
414+
resolveAttributesForAggregation(groupExprList, aggExprList, context);
415+
List<RexInputRef> trimmedRefs = new ArrayList<>();
416+
trimmedRefs.addAll(PlanUtils.getInputRefs(resolved.getLeft())); // group-by keys first
417+
trimmedRefs.addAll(PlanUtils.getInputRefsFromAggCall(resolved.getRight()));
418+
context.relBuilder.project(trimmedRefs);
419+
420+
// Re-resolve all attributes based on adding trimmed Project.
421+
// Using re-resolving rather than Calcite Mapping (ref Calcite ProjectTableScanRule)
422+
// because that Mapping only works for RexNode, but we need both AggCall and RexNode list.
423+
Pair<List<RexNode>, List<AggCall>> reResolved =
424+
resolveAttributesForAggregation(groupExprList, aggExprList, context);
425+
context.relBuilder.aggregate(
426+
context.relBuilder.groupKey(reResolved.getLeft()), reResolved.getRight());
427+
return Pair.of(reResolved.getLeft(), reResolved.getRight());
428+
}
429+
430+
/**
431+
* Resolve attributes for aggregation.
432+
*
433+
* @param groupExprList group by expression list
434+
* @param aggExprList aggregate expression list
435+
* @param context CalcitePlanContext
436+
* @return Pair of (group-by list, aggregate list)
437+
*/
438+
private Pair<List<RexNode>, List<AggCall>> resolveAttributesForAggregation(
439+
List<UnresolvedExpression> groupExprList,
440+
List<UnresolvedExpression> aggExprList,
441+
CalcitePlanContext context) {
379442
List<AggCall> aggCallList =
380-
node.getAggExprList().stream()
381-
.map(expr -> aggVisitor.analyze(expr, context))
382-
.collect(Collectors.toList());
383-
// The span column is always the first column in result whatever
384-
// the order of span in query is first or last one
385-
List<RexNode> groupByList = new ArrayList<>();
386-
UnresolvedExpression span = node.getSpan();
387-
if (!Objects.isNull(span)) {
388-
RexNode spanRex = rexVisitor.analyze(span, context);
389-
groupByList.add(spanRex);
390-
// add span's group alias field (most recent added expression)
391-
}
392-
groupByList.addAll(
393-
node.getGroupExprList().stream().map(expr -> rexVisitor.analyze(expr, context)).toList());
394-
return Pair.of(aggCallList, groupByList);
443+
aggExprList.stream().map(expr -> aggVisitor.analyze(expr, context)).toList();
444+
List<RexNode> groupByList =
445+
groupExprList.stream().map(expr -> rexVisitor.analyze(expr, context)).toList();
446+
return Pair.of(groupByList, aggCallList);
395447
}
396448

397449
@Override
398450
public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
399451
visitChildren(node, context);
400-
// Add a trimmed Project before Aggregate.
401-
// to avoid bugs in RelDecorrelator.decorrelateRel(Aggregate rel)
402-
// For example:
403-
// source=t | where a > 1 | stats avg(b+1) by c
404-
// Before:
405-
// Aggregate
406-
// \- Filter(a>1)
407-
// \- Scan t
408-
// After:
409-
// Aggregate
410-
// \- Project([c,b])
411-
// \- Filter(a>1)
412-
// \- Scan t
413-
Pair<List<AggCall>, List<RexNode>> resolved = resolveAggCallAndGroupBy(node, context);
414-
List<RexInputRef> trimmedRefs = new ArrayList<>();
415-
trimmedRefs.addAll(PlanUtils.getInputRefs(resolved.getRight())); // group-by keys first
416-
trimmedRefs.addAll(PlanUtils.getInputRefsFromAggCall(resolved.getLeft()));
417-
context.relBuilder.project(trimmedRefs);
418452

419-
// Re-resolve aggCalls and group-by list based on adding trimmed Project.
420-
// Using re-resolving rather than Calcite Mapping (ref Calcite ProjectTableScanRule)
421-
// because that Mapping only works for RexNode, but we need both AggCall and RexNode list.
422-
Pair<List<AggCall>, List<RexNode>> reResolved = resolveAggCallAndGroupBy(node, context);
423-
List<AggCall> aggList = reResolved.getLeft();
424-
List<RexNode> groupByList = reResolved.getRight();
425-
context.relBuilder.aggregate(context.relBuilder.groupKey(groupByList), aggList);
453+
List<UnresolvedExpression> aggExprList = node.getAggExprList();
454+
List<UnresolvedExpression> groupExprList = new ArrayList<>();
455+
// The span column is always the first column in result whatever
456+
// the order of span in query is first or last one
457+
UnresolvedExpression span = node.getSpan();
458+
if (!Objects.isNull(span)) {
459+
groupExprList.add(span);
460+
}
461+
groupExprList.addAll(node.getGroupExprList());
462+
Pair<List<RexNode>, List<AggCall>> aggregationAttributes =
463+
aggregateWithTrimming(groupExprList, aggExprList, context);
426464

427465
// schema reordering
428466
// As an example, in command `stats count() by colA, colB`,
429467
// the sequence of output schema is "count, colA, colB".
430468
List<RexNode> outputFields = context.relBuilder.fields();
431469
int numOfOutputFields = outputFields.size();
432-
int numOfAggList = aggList.size();
470+
int numOfAggList = aggExprList.size();
433471
List<RexNode> reordered = new ArrayList<>(numOfOutputFields);
434472
// Add aggregation results first
435473
List<RexNode> aggRexList =
436474
outputFields.subList(numOfOutputFields - numOfAggList, numOfOutputFields);
437475
reordered.addAll(aggRexList);
438476
// Add group by columns
439477
List<RexNode> aliasedGroupByList =
440-
groupByList.stream()
478+
aggregationAttributes.getLeft().stream()
441479
.map(this::extractAliasLiteral)
442480
.flatMap(Optional::stream)
443481
.map(ref -> ((RexLiteral) ref).getValueAs(String.class))
@@ -742,7 +780,62 @@ public RelNode visitKmeans(Kmeans node, CalcitePlanContext context) {
742780

743781
@Override
744782
public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) {
745-
throw new CalciteUnsupportedException("Rare and Top commands are unsupported in Calcite");
783+
visitChildren(node, context);
784+
785+
ArgumentMap arguments = ArgumentMap.of(node.getArguments());
786+
String countFieldName = (String) arguments.get("countField").getValue();
787+
if (context.relBuilder.peek().getRowType().getFieldNames().contains(countFieldName)) {
788+
throw new IllegalArgumentException(
789+
"Field `"
790+
+ countFieldName
791+
+ "` is existed, change the count field by setting countfield='xyz'");
792+
}
793+
794+
// 1. group the group-by list + field list and add a count() aggregation
795+
List<UnresolvedExpression> groupExprList = new ArrayList<>(node.getGroupExprList());
796+
List<UnresolvedExpression> fieldList =
797+
node.getFields().stream().map(f -> (UnresolvedExpression) f).toList();
798+
groupExprList.addAll(fieldList);
799+
List<UnresolvedExpression> aggExprList =
800+
List.of(AstDSL.alias(countFieldName, AstDSL.aggregate("count", null)));
801+
aggregateWithTrimming(groupExprList, aggExprList, context);
802+
803+
// 2. add a window column
804+
List<RexNode> partitionKeys = rexVisitor.analyze(node.getGroupExprList(), context);
805+
RexNode countField;
806+
if (node.getCommandType() == RareTopN.CommandType.TOP) {
807+
countField = context.relBuilder.desc(context.relBuilder.field(countFieldName));
808+
} else {
809+
countField = context.relBuilder.field(countFieldName);
810+
}
811+
RexNode rowNumberWindowOver =
812+
PlanUtils.makeOver(
813+
context,
814+
BuiltinFunctionName.ROW_NUMBER,
815+
null,
816+
List.of(),
817+
partitionKeys,
818+
List.of(countField),
819+
WindowFrame.toCurrentRow());
820+
context.relBuilder.projectPlus(
821+
context.relBuilder.alias(rowNumberWindowOver, ROW_NUMBER_COLUMN_NAME));
822+
823+
// 3. filter row_number() <= k in each partition
824+
Integer N = (Integer) arguments.get("noOfResults").getValue();
825+
context.relBuilder.filter(
826+
context.relBuilder.lessThanOrEqual(
827+
context.relBuilder.field(ROW_NUMBER_COLUMN_NAME), context.relBuilder.literal(N)));
828+
829+
// 4. project final output. the default output is group by list + field list
830+
Boolean showCount = (Boolean) arguments.get("showCount").getValue();
831+
if (showCount) {
832+
context.relBuilder.projectExcept(context.relBuilder.field(ROW_NUMBER_COLUMN_NAME));
833+
} else {
834+
context.relBuilder.projectExcept(
835+
context.relBuilder.field(ROW_NUMBER_COLUMN_NAME),
836+
context.relBuilder.field(countFieldName));
837+
}
838+
return context.relBuilder.peek();
746839
}
747840

748841
@Override

0 commit comments

Comments
 (0)