Skip to content

Commit 5733337

Browse files
committed
Merge remote-tracking branch 'origin/main' into type-checker
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
2 parents f3f94ce + e6116bc commit 5733337

17 files changed

Lines changed: 923 additions & 99 deletions

File tree

core/src/main/java/org/opensearch/sql/ast/statement/Explain.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
package org.opensearch.sql.ast.statement;
1010

11+
import java.util.Locale;
1112
import lombok.EqualsAndHashCode;
1213
import lombok.Getter;
1314
import org.opensearch.sql.ast.AbstractNodeVisitor;
@@ -46,7 +47,7 @@ public enum ExplainFormat {
4647

4748
public static ExplainFormat format(String format) {
4849
try {
49-
return ExplainFormat.valueOf(format.toUpperCase());
50+
return ExplainFormat.valueOf(format.toUpperCase(Locale.ROOT));
5051
} catch (Exception e) {
5152
return ExplainFormat.STANDARD;
5253
}

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.apache.calcite.tools.RelBuilder;
3838
import org.apache.calcite.tools.RelBuilder.AggCall;
3939
import org.apache.calcite.util.Holder;
40+
import org.apache.calcite.util.Pair;
4041
import org.checkerframework.checker.nullness.qual.Nullable;
4142
import org.opensearch.sql.ast.AbstractNodeVisitor;
4243
import org.opensearch.sql.ast.Node;
@@ -78,6 +79,7 @@
7879
import org.opensearch.sql.ast.tree.Window;
7980
import org.opensearch.sql.calcite.plan.OpenSearchConstants;
8081
import org.opensearch.sql.calcite.utils.JoinAndLookupUtils;
82+
import org.opensearch.sql.calcite.utils.PlanUtils;
8183
import org.opensearch.sql.exception.CalciteUnsupportedException;
8284
import org.opensearch.sql.exception.SemanticCheckException;
8385
import org.opensearch.sql.expression.function.PPLFuncImpTable;
@@ -371,10 +373,9 @@ private void projectPlusOverriding(
371373
context.relBuilder.rename(expectedRenameFields);
372374
}
373375

374-
@Override
375-
public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
376-
visitChildren(node, context);
377-
List<AggCall> aggList =
376+
private Pair<List<AggCall>, List<RexNode>> resolveAggCallAndGroupBy(
377+
Aggregation node, CalcitePlanContext context) {
378+
List<AggCall> aggCallList =
378379
node.getAggExprList().stream()
379380
.map(expr -> aggVisitor.analyze(expr, context))
380381
.collect(Collectors.toList());
@@ -389,7 +390,37 @@ public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
389390
}
390391
groupByList.addAll(
391392
node.getGroupExprList().stream().map(expr -> rexVisitor.analyze(expr, context)).toList());
393+
return Pair.of(aggCallList, groupByList);
394+
}
392395

396+
@Override
397+
public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
398+
visitChildren(node, context);
399+
// Add a trimmed Project before the Aggregate
400+
// to avoid bugs in RelDecorrelator.decorrelateRel(Aggregate rel)
401+
// For example:
402+
// source=t | where a > 1 | stats avg(b+1) by c
403+
// Before:
404+
// Aggregate
405+
// \- Filter(a>1)
406+
// \- Scan t
407+
// After:
408+
// Aggregate
409+
// \- Project([c,b])
410+
// \- Filter(a>1)
411+
// \- Scan t
412+
Pair<List<AggCall>, List<RexNode>> resolved = resolveAggCallAndGroupBy(node, context);
413+
List<RexInputRef> trimmedRefs = new ArrayList<>();
414+
trimmedRefs.addAll(PlanUtils.getInputRefs(resolved.right)); // group-by keys first
415+
trimmedRefs.addAll(PlanUtils.getInputRefsFromAggCall(resolved.left));
416+
context.relBuilder.project(trimmedRefs);
417+
418+
// Re-resolve the aggCalls and the group-by list against the newly added trimmed Project.
419+
// Using re-resolving rather than Calcite Mapping (ref Calcite ProjectTableScanRule)
420+
// because that Mapping only works for RexNode, but we need both AggCall and RexNode list.
421+
Pair<List<AggCall>, List<RexNode>> reResolved = resolveAggCallAndGroupBy(node, context);
422+
List<AggCall> aggList = reResolved.left;
423+
List<RexNode> groupByList = reResolved.right;
393424
context.relBuilder.aggregate(context.relBuilder.groupKey(groupByList), aggList);
394425

395426
// schema reordering

core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
import java.util.ArrayList;
2121
import java.util.List;
2222
import javax.annotation.Nullable;
23+
import org.apache.calcite.rex.RexInputRef;
2324
import org.apache.calcite.rex.RexNode;
25+
import org.apache.calcite.rex.RexVisitorImpl;
2426
import org.apache.calcite.rex.RexWindowBound;
2527
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
2628
import org.apache.calcite.sql.type.ReturnTypes;
@@ -259,4 +261,34 @@ static RelBuilder.AggCall makeAggCall(
259261
"Unexpected aggregation: " + functionName.getName().getFunctionName());
260262
}
261263
}
264+
265+
/** Get all unique input references from a RexNode. */
266+
static List<RexInputRef> getInputRefs(RexNode node) {
267+
List<RexInputRef> inputRefs = new ArrayList<>();
268+
node.accept(
269+
new RexVisitorImpl<Void>(true) {
270+
@Override
271+
public Void visitInputRef(RexInputRef inputRef) {
272+
if (!inputRefs.contains(inputRef)) {
273+
inputRefs.add(inputRef);
274+
}
275+
return null;
276+
}
277+
});
278+
return inputRefs;
279+
}
280+
281+
/** Get the input references of every RexNode in a list (unique within each node, not across nodes). */
282+
static List<RexInputRef> getInputRefs(List<RexNode> nodes) {
283+
return nodes.stream().flatMap(node -> getInputRefs(node).stream()).toList();
284+
}
285+
286+
/** Get the input references of every agg call in a list (unique within each call, not across calls). */
287+
static List<RexInputRef> getInputRefsFromAggCall(List<RelBuilder.AggCall> aggCalls) {
288+
return aggCalls.stream()
289+
.map(RelBuilder.AggCall::over)
290+
.map(RelBuilder.OverCall::toRex)
291+
.flatMap(rex -> getInputRefs(rex).stream())
292+
.toList();
293+
}
262294
}

integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLExistsSubqueryIT.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import static org.opensearch.sql.util.MatcherUtils.verifyDataRowsInOrder;
1515
import static org.opensearch.sql.util.MatcherUtils.verifyNumOfRows;
1616
import static org.opensearch.sql.util.MatcherUtils.verifySchema;
17+
import static org.opensearch.sql.util.MatcherUtils.verifySchemaInOrder;
1718

1819
import java.io.IOException;
1920
import org.json.JSONObject;
@@ -293,4 +294,23 @@ public void testExistsSubqueryWithConjunction() {
293294
result, schema("id", "integer"), schema("name", "string"), schema("salary", "integer"));
294295
verifyDataRowsInOrder(result, rows(1003, "David", 120000), rows(1000, "Jake", 100000));
295296
}
297+
298+
@Test
299+
public void testIssue3566() {
300+
JSONObject result =
301+
executeQuery(
302+
String.format(
303+
"""
304+
source = %s
305+
| fields id, country
306+
| where exists [
307+
source = %s
308+
| where id = uid
309+
]
310+
| stats count() by country
311+
""",
312+
TEST_INDEX_WORKER, TEST_INDEX_WORK_INFORMATION));
313+
verifySchemaInOrder(result, schema("count()", "long"), schema("country", "string"));
314+
verifyDataRows(result, rows(1, null), rows(1, "England"), rows(1, "USA"), rows(2, "Canada"));
315+
}
296316
}
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"calcite": {
3-
"logical": "LogicalProject(avg_age=[$2], state=[$1], city=[$0])\n LogicalAggregate(group=[{5, 7}], avg_age=[AVG($8)])\n LogicalFilter(condition=[>($8, 30)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n",
4-
"physical": "EnumerableCalc(expr#0..2=[{inputs}], avg_age=[$t2], state=[$t1], city=[$t0])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[city, state, age], FILTER->>($2, 30), AGGREGATION->rel#12051:LogicalAggregate.NONE.[](input=RelSubset#12050,group={0, 1},avg_age=AVG($2))], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"_source\":{\"includes\":[\"city\",\"state\",\"age\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, requestedTotalSize=10000, pageSize=null, startFrom=0)])\n"
3+
"logical": "LogicalProject(avg_age=[$2], state=[$0], city=[$1])\n LogicalAggregate(group=[{0, 1}], avg_age=[AVG($2)])\n LogicalProject(state=[$7], city=[$5], age=[$8])\n LogicalFilter(condition=[>($8, 30)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n",
4+
"physical": "EnumerableCalc(expr#0..2=[{inputs}], avg_age=[$t2], state=[$t1], city=[$t0])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[city, state, age], FILTER->>($2, 30), AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 1},avg_age=AVG($2))], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"_source\":{\"includes\":[\"city\",\"state\",\"age\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, requestedTotalSize=10000, pageSize=null, startFrom=0)])\n"
55
}
6-
}
6+
}

integ-test/src/test/resources/expectedOutput/calcite/explain_filter_push.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
"logical": "LogicalProject(age=[$8])\n LogicalFilter(condition=[>($3, 10000)])\n LogicalFilter(condition=[<($8, 40)])\n LogicalFilter(condition=[>($8, 30)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n",
44
"physical": "CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[balance, age], FILTER->>($1, 30), FILTER-><($1, 40), FILTER->>($0, 10000), PROJECT->[age]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"query\":{\"bool\":{\"filter\":[{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},{\"range\":{\"age\":{\"from\":null,\"to\":40,\"include_lower\":true,\"include_upper\":false,\"boost\":1.0}}},{\"range\":{\"balance\":{\"from\":10000,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}},\"_source\":{\"includes\":[\"age\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}]}, requestedTotalSize=10000, pageSize=null, startFrom=0)])\n"
55
}
6-
}
6+
}
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"calcite": {
3-
"logical": "LogicalProject(age2=[$2])\n LogicalFilter(condition=[<=($3, 1)])\n LogicalProject(avg_age=[$0], state=[$1], age2=[$2], _row_number_=[ROW_NUMBER() OVER (PARTITION BY $2 ORDER BY $2)])\n LogicalFilter(condition=[IS NOT NULL($2)])\n LogicalProject(avg_age=[$0], state=[$1], age2=[+($0, 2)])\n LogicalSort(sort0=[$1], dir0=[ASC])\n LogicalProject(avg_age=[$2], state=[$1], city=[$0])\n LogicalAggregate(group=[{5, 7}], avg_age=[AVG($8)])\n LogicalFilter(condition=[>($8, 30)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n",
4-
"physical": "EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1], expr#4=[<=($t2, $t3)], age2=[$t1], $condition=[$t4])\n EnumerableWindow(window#0=[window(partition {1} order by [1] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])\n EnumerableCalc(expr#0..2=[{inputs}], expr#3=[2], expr#4=[+($t2, $t3)], expr#5=[IS NOT NULL($t2)], state=[$t1], age2=[$t4], $condition=[$t5])\n EnumerableSort(sort0=[$1], dir0=[ASC])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[city, state, age], FILTER->>($2, 30), AGGREGATION->rel#11061:LogicalAggregate.NONE.[](input=RelSubset#11060,group={0, 1},avg_age=AVG($2))], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"_source\":{\"includes\":[\"city\",\"state\",\"age\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, requestedTotalSize=10000, pageSize=null, startFrom=0)])\n"
3+
"logical": "LogicalProject(age2=[$2])\n LogicalFilter(condition=[<=($3, 1)])\n LogicalProject(avg_age=[$0], state=[$1], age2=[$2], _row_number_=[ROW_NUMBER() OVER (PARTITION BY $2 ORDER BY $2)])\n LogicalFilter(condition=[IS NOT NULL($2)])\n LogicalProject(avg_age=[$0], state=[$1], age2=[+($0, 2)])\n LogicalSort(sort0=[$1], dir0=[ASC])\n LogicalProject(avg_age=[$2], state=[$0], city=[$1])\n LogicalAggregate(group=[{0, 1}], avg_age=[AVG($2)])\n LogicalProject(state=[$7], city=[$5], age=[$8])\n LogicalFilter(condition=[>($8, 30)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n",
4+
"physical": "EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1], expr#4=[<=($t2, $t3)], age2=[$t1], $condition=[$t4])\n EnumerableWindow(window#0=[window(partition {1} order by [1] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])\n EnumerableCalc(expr#0..2=[{inputs}], expr#3=[2], expr#4=[+($t2, $t3)], expr#5=[IS NOT NULL($t2)], state=[$t0], age2=[$t4], $condition=[$t5])\n EnumerableSort(sort0=[$0], dir0=[ASC])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[FILTER->>($8, 30), PROJECT->[state, city, age], AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 1},avg_age=AVG($2))], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"_source\":{\"includes\":[\"state\",\"city\",\"age\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, requestedTotalSize=10000, pageSize=null, startFrom=0)])\n"
55
}
6-
}
6+
}

opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ private static Pair<Builder, List<MetricParser>> processAggregateCalls(
147147
List<AggregateCall> aggCalls,
148148
FieldExpressionCreator fieldExpressionCreator,
149149
List<String> outputFields) {
150+
assert aggCalls.size() + groupOffset == outputFields.size()
151+
: "groups size and agg calls size should match with output fields";
150152
Builder metricBuilder = new AggregatorFactories.Builder();
151153
List<MetricParser> metricParserList = new ArrayList<>();
152154

opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
import java.util.Map;
2020
import java.util.stream.Collectors;
2121
import lombok.EqualsAndHashCode;
22+
import lombok.Getter;
2223
import org.opensearch.search.aggregations.Aggregations;
2324
import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation;
2425

2526
/** Composite Aggregation Parser that includes composite aggregation and metric parsers. */
27+
@Getter
2628
@EqualsAndHashCode
2729
public class CompositeAggregationParser implements OpenSearchAggregationResponseParser {
2830

opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
import java.util.Map;
1919
import java.util.stream.Collectors;
2020
import lombok.EqualsAndHashCode;
21+
import lombok.Getter;
2122
import lombok.RequiredArgsConstructor;
2223
import org.opensearch.search.aggregations.Aggregation;
2324
import org.opensearch.search.aggregations.Aggregations;
2425
import org.opensearch.sql.common.utils.StringUtils;
2526

2627
/** Parse multiple metrics in one bucket. */
28+
@Getter
2729
@EqualsAndHashCode
2830
@RequiredArgsConstructor
2931
public class MetricParserHelper {

0 commit comments

Comments
 (0)