|
11 | 11 | import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC; |
12 | 12 | import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC; |
13 | 13 | import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC; |
| 14 | +import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_NAME; |
14 | 15 |
|
15 | 16 | import com.google.common.base.Strings; |
16 | 17 | import com.google.common.collect.ImmutableList; |
|
26 | 27 | import org.apache.calcite.plan.RelOptTable; |
27 | 28 | import org.apache.calcite.plan.ViewExpanders; |
28 | 29 | import org.apache.calcite.rel.RelNode; |
| 30 | +import org.apache.calcite.rel.core.Aggregate; |
29 | 31 | import org.apache.calcite.rel.type.RelDataTypeField; |
30 | 32 | import org.apache.calcite.rex.RexCall; |
31 | 33 | import org.apache.calcite.rex.RexCorrelVariable; |
|
42 | 44 | import org.checkerframework.checker.nullness.qual.Nullable; |
43 | 45 | import org.opensearch.sql.ast.AbstractNodeVisitor; |
44 | 46 | import org.opensearch.sql.ast.Node; |
| 47 | +import org.opensearch.sql.ast.dsl.AstDSL; |
45 | 48 | import org.opensearch.sql.ast.expression.AllFields; |
46 | 49 | import org.opensearch.sql.ast.expression.AllFieldsExcludeMeta; |
47 | 50 | import org.opensearch.sql.ast.expression.Argument; |
| 51 | +import org.opensearch.sql.ast.expression.Argument.ArgumentMap; |
48 | 52 | import org.opensearch.sql.ast.expression.Field; |
49 | 53 | import org.opensearch.sql.ast.expression.Let; |
50 | 54 | import org.opensearch.sql.ast.expression.Literal; |
51 | 55 | import org.opensearch.sql.ast.expression.ParseMethod; |
52 | 56 | import org.opensearch.sql.ast.expression.UnresolvedExpression; |
| 57 | +import org.opensearch.sql.ast.expression.WindowFrame; |
53 | 58 | import org.opensearch.sql.ast.expression.subquery.SubqueryExpression; |
54 | 59 | import org.opensearch.sql.ast.tree.AD; |
55 | 60 | import org.opensearch.sql.ast.tree.Aggregation; |
|
83 | 88 | import org.opensearch.sql.calcite.utils.PlanUtils; |
84 | 89 | import org.opensearch.sql.exception.CalciteUnsupportedException; |
85 | 90 | import org.opensearch.sql.exception.SemanticCheckException; |
| 91 | +import org.opensearch.sql.expression.function.BuiltinFunctionName; |
86 | 92 | import org.opensearch.sql.expression.function.PPLFuncImpTable; |
87 | 93 | import org.opensearch.sql.utils.ParseUtils; |
88 | 94 |
|
@@ -374,70 +380,102 @@ private void projectPlusOverriding( |
374 | 380 | context.relBuilder.rename(expectedRenameFields); |
375 | 381 | } |
376 | 382 |
|
377 | | - private Pair<List<AggCall>, List<RexNode>> resolveAggCallAndGroupBy( |
378 | | - Aggregation node, CalcitePlanContext context) { |
| 383 | + /** |
| 384 | + * Push a trimming Project and then the Aggregate onto the builder, to avoid bugs in {@link |
| 385 | + * org.apache.calcite.sql2rel.RelDecorrelator#decorrelateRel(Aggregate, boolean)} |
| 386 | + * |
| 387 | + * @param groupExprList group by expression list |
| 388 | + * @param aggExprList aggregate expression list |
| 389 | + * @param context CalcitePlanContext |
| 390 | + * @return Pair of (group-by list, aggregate list), both re-resolved on top of the trimmed Project |
| 391 | + */ |
| 392 | + private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming( |
| 393 | + List<UnresolvedExpression> groupExprList, |
| 394 | + List<UnresolvedExpression> aggExprList, |
| 395 | + CalcitePlanContext context) { |
| 396 | + // Example 1: source=t | where a > 1 | stats avg(b + 1) by c |
| 397 | + // Before: Aggregate(avg(b + 1)) |
| 398 | + // \- Filter(a > 1) |
| 399 | + // \- Scan t |
| 400 | + // After: Aggregate(avg(b + 1)) |
| 401 | + // \- Project([c, b]) |
| 402 | + // \- Filter(a > 1) |
| 403 | + // \- Scan t |
| 404 | + // |
| 405 | + // Example 2: source=t | where a > 1 | top b by c |
| 406 | + // Before: Aggregate(count) |
| 407 | + // \- Filter(a > 1) |
| 408 | + // \- Scan t |
| 409 | + // After: Aggregate(count) |
| 410 | + // \- Project([c, b]) |
| 411 | + // \- Filter(a > 1) |
| 412 | + // \- Scan t |
| 413 | + Pair<List<RexNode>, List<AggCall>> resolved = |
| 414 | + resolveAttributesForAggregation(groupExprList, aggExprList, context); |
| 415 | + List<RexInputRef> trimmedRefs = new ArrayList<>(); |
| 416 | + trimmedRefs.addAll(PlanUtils.getInputRefs(resolved.getLeft())); // group-by keys first |
| 417 | + trimmedRefs.addAll(PlanUtils.getInputRefsFromAggCall(resolved.getRight())); |
| 418 | + context.relBuilder.project(trimmedRefs); |
| 419 | + |
| 420 | + // Re-resolve all attributes based on adding trimmed Project. |
| 421 | + // Using re-resolving rather than Calcite Mapping (ref Calcite ProjectTableScanRule) |
| 422 | + // because that Mapping only works for RexNode, but we need both AggCall and RexNode list. |
| 423 | + Pair<List<RexNode>, List<AggCall>> reResolved = |
| 424 | + resolveAttributesForAggregation(groupExprList, aggExprList, context); |
| 425 | + context.relBuilder.aggregate( |
| 426 | + context.relBuilder.groupKey(reResolved.getLeft()), reResolved.getRight()); |
| 427 | + return Pair.of(reResolved.getLeft(), reResolved.getRight()); |
| 428 | + } |
| 429 | + |
| 430 | + /** |
| 431 | + * Resolve attributes for aggregation: group-by expressions via rexVisitor, aggregates via aggVisitor. |
| 432 | + * |
| 433 | + * @param groupExprList group by expression list |
| 434 | + * @param aggExprList aggregate expression list |
| 435 | + * @param context CalcitePlanContext |
| 436 | + * @return Pair of (group-by list, aggregate list) |
| 437 | + */ |
| 438 | + private Pair<List<RexNode>, List<AggCall>> resolveAttributesForAggregation( |
| 439 | + List<UnresolvedExpression> groupExprList, |
| 440 | + List<UnresolvedExpression> aggExprList, |
| 441 | + CalcitePlanContext context) { |
379 | 442 | List<AggCall> aggCallList = |
380 | | - node.getAggExprList().stream() |
381 | | - .map(expr -> aggVisitor.analyze(expr, context)) |
382 | | - .collect(Collectors.toList()); |
383 | | - // The span column is always the first column in result whatever |
384 | | - // the order of span in query is first or last one |
385 | | - List<RexNode> groupByList = new ArrayList<>(); |
386 | | - UnresolvedExpression span = node.getSpan(); |
387 | | - if (!Objects.isNull(span)) { |
388 | | - RexNode spanRex = rexVisitor.analyze(span, context); |
389 | | - groupByList.add(spanRex); |
390 | | - // add span's group alias field (most recent added expression) |
391 | | - } |
392 | | - groupByList.addAll( |
393 | | - node.getGroupExprList().stream().map(expr -> rexVisitor.analyze(expr, context)).toList()); |
394 | | - return Pair.of(aggCallList, groupByList); |
| 443 | + aggExprList.stream().map(expr -> aggVisitor.analyze(expr, context)).toList(); |
| 444 | + List<RexNode> groupByList = |
| 445 | + groupExprList.stream().map(expr -> rexVisitor.analyze(expr, context)).toList(); |
| 446 | + return Pair.of(groupByList, aggCallList); |
395 | 447 | } |
396 | 448 |
|
397 | 449 | @Override |
398 | 450 | public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) { |
399 | 451 | visitChildren(node, context); |
400 | | - // Add a trimmed Project before Aggregate. |
401 | | - // to avoid bugs in RelDecorrelator.decorrelateRel(Aggregate rel) |
402 | | - // For example: |
403 | | - // source=t | where a > 1 | stats avg(b+1) by c |
404 | | - // Before: |
405 | | - // Aggregate |
406 | | - // \- Filter(a>1) |
407 | | - // \- Scan t |
408 | | - // After: |
409 | | - // Aggregate |
410 | | - // \- Project([c,b]) |
411 | | - // \- Filter(a>1) |
412 | | - // \- Scan t |
413 | | - Pair<List<AggCall>, List<RexNode>> resolved = resolveAggCallAndGroupBy(node, context); |
414 | | - List<RexInputRef> trimmedRefs = new ArrayList<>(); |
415 | | - trimmedRefs.addAll(PlanUtils.getInputRefs(resolved.getRight())); // group-by keys first |
416 | | - trimmedRefs.addAll(PlanUtils.getInputRefsFromAggCall(resolved.getLeft())); |
417 | | - context.relBuilder.project(trimmedRefs); |
418 | 452 |
|
419 | | - // Re-resolve aggCalls and group-by list based on adding trimmed Project. |
420 | | - // Using re-resolving rather than Calcite Mapping (ref Calcite ProjectTableScanRule) |
421 | | - // because that Mapping only works for RexNode, but we need both AggCall and RexNode list. |
422 | | - Pair<List<AggCall>, List<RexNode>> reResolved = resolveAggCallAndGroupBy(node, context); |
423 | | - List<AggCall> aggList = reResolved.getLeft(); |
424 | | - List<RexNode> groupByList = reResolved.getRight(); |
425 | | - context.relBuilder.aggregate(context.relBuilder.groupKey(groupByList), aggList); |
| 453 | + List<UnresolvedExpression> aggExprList = node.getAggExprList(); |
| 454 | + List<UnresolvedExpression> groupExprList = new ArrayList<>(); |
| 455 | + // The span column is always the first column in result whatever |
| 456 | + // the order of span in query is first or last one |
| 457 | + UnresolvedExpression span = node.getSpan(); |
| 458 | + if (!Objects.isNull(span)) { |
| 459 | + groupExprList.add(span); |
| 460 | + } |
| 461 | + groupExprList.addAll(node.getGroupExprList()); |
| 462 | + Pair<List<RexNode>, List<AggCall>> aggregationAttributes = |
| 463 | + aggregateWithTrimming(groupExprList, aggExprList, context); |
426 | 464 |
|
427 | 465 | // schema reordering |
428 | 466 | // As an example, in command `stats count() by colA, colB`, |
429 | 467 | // the sequence of output schema is "count, colA, colB". |
430 | 468 | List<RexNode> outputFields = context.relBuilder.fields(); |
431 | 469 | int numOfOutputFields = outputFields.size(); |
432 | | - int numOfAggList = aggList.size(); |
| 470 | + int numOfAggList = aggExprList.size(); |
433 | 471 | List<RexNode> reordered = new ArrayList<>(numOfOutputFields); |
434 | 472 | // Add aggregation results first |
435 | 473 | List<RexNode> aggRexList = |
436 | 474 | outputFields.subList(numOfOutputFields - numOfAggList, numOfOutputFields); |
437 | 475 | reordered.addAll(aggRexList); |
438 | 476 | // Add group by columns |
439 | 477 | List<RexNode> aliasedGroupByList = |
440 | | - groupByList.stream() |
| 478 | + aggregationAttributes.getLeft().stream() |
441 | 479 | .map(this::extractAliasLiteral) |
442 | 480 | .flatMap(Optional::stream) |
443 | 481 | .map(ref -> ((RexLiteral) ref).getValueAs(String.class)) |
@@ -742,7 +780,62 @@ public RelNode visitKmeans(Kmeans node, CalcitePlanContext context) { |
742 | 780 |
|
743 | 781 | @Override |
744 | 782 | public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) { |
745 | | - throw new CalciteUnsupportedException("Rare and Top commands are unsupported in Calcite"); |
| 783 | + visitChildren(node, context); |
| 784 | + |
| 785 | + ArgumentMap arguments = ArgumentMap.of(node.getArguments()); |
| 786 | + String countFieldName = (String) arguments.get("countField").getValue(); |
| 787 | + if (context.relBuilder.peek().getRowType().getFieldNames().contains(countFieldName)) { |
| 788 | + throw new IllegalArgumentException( |
| 789 | + "Field `" |
| 790 | + + countFieldName |
| 791 | + + "` is existed, change the count field by setting countfield='xyz'"); |
| 792 | + } |
| 793 | + |
| 794 | + // 1. group by (group-by list + field list) and add a count() aggregation aliased as countFieldName |
| 795 | + List<UnresolvedExpression> groupExprList = new ArrayList<>(node.getGroupExprList()); |
| 796 | + List<UnresolvedExpression> fieldList = |
| 797 | + node.getFields().stream().map(f -> (UnresolvedExpression) f).toList(); |
| 798 | + groupExprList.addAll(fieldList); |
| 799 | + List<UnresolvedExpression> aggExprList = |
| 800 | + List.of(AstDSL.alias(countFieldName, AstDSL.aggregate("count", null))); |
| 801 | + aggregateWithTrimming(groupExprList, aggExprList, context); |
| 802 | + |
| 803 | + // 2. add a row_number() window column: partitioned by the group-by list, ordered by count (desc for TOP) |
| 804 | + List<RexNode> partitionKeys = rexVisitor.analyze(node.getGroupExprList(), context); |
| 805 | + RexNode countField; |
| 806 | + if (node.getCommandType() == RareTopN.CommandType.TOP) { |
| 807 | + countField = context.relBuilder.desc(context.relBuilder.field(countFieldName)); |
| 808 | + } else { |
| 809 | + countField = context.relBuilder.field(countFieldName); |
| 810 | + } |
| 811 | + RexNode rowNumberWindowOver = |
| 812 | + PlanUtils.makeOver( |
| 813 | + context, |
| 814 | + BuiltinFunctionName.ROW_NUMBER, |
| 815 | + null, |
| 816 | + List.of(), |
| 817 | + partitionKeys, |
| 818 | + List.of(countField), |
| 819 | + WindowFrame.toCurrentRow()); |
| 820 | + context.relBuilder.projectPlus( |
| 821 | + context.relBuilder.alias(rowNumberWindowOver, ROW_NUMBER_COLUMN_NAME)); |
| 822 | + |
| 823 | + // 3. filter row_number() <= noOfResults in each partition |
| 824 | + Integer N = (Integer) arguments.get("noOfResults").getValue(); |
| 825 | + context.relBuilder.filter( |
| 826 | + context.relBuilder.lessThanOrEqual( |
| 827 | + context.relBuilder.field(ROW_NUMBER_COLUMN_NAME), context.relBuilder.literal(N))); |
| 828 | + |
| 829 | + // 4. project final output: always drop the row-number column; drop the count column too unless showCount is true |
| 830 | + Boolean showCount = (Boolean) arguments.get("showCount").getValue(); |
| 831 | + if (showCount) { |
| 832 | + context.relBuilder.projectExcept(context.relBuilder.field(ROW_NUMBER_COLUMN_NAME)); |
| 833 | + } else { |
| 834 | + context.relBuilder.projectExcept( |
| 835 | + context.relBuilder.field(ROW_NUMBER_COLUMN_NAME), |
| 836 | + context.relBuilder.field(countFieldName)); |
| 837 | + } |
| 838 | + return context.relBuilder.peek(); |
746 | 839 | } |
747 | 840 |
|
748 | 841 | @Override |
|
0 commit comments