Skip to content

Commit 70b17ca

Browse files
committed
Support sort pushdown with aggregation when aggregated fields are not in sort by fields
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 0d46563 commit 70b17ca

7 files changed

Lines changed: 104 additions & 16 deletions

File tree

integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ public void testFilterAndAggPushDownExplain() throws IOException {
7474

7575
@Test
7676
public void testSortPushDownExplain() throws IOException {
77-
// TODO fix after https://github.com/opensearch-project/sql/issues/3380
7877
String expected =
7978
isCalciteEnabled()
8079
? loadFromFile("expectedOutput/calcite/explain_sort_push.json")
@@ -89,6 +88,25 @@ public void testSortPushDownExplain() throws IOException {
8988
+ "| fields age"));
9089
}
9190

91+
@Test
92+
public void testSortWithAggregationExplain() throws IOException {
93+
// Sorts whose by fields are aggregators should not be pushed down
94+
String expected =
95+
isCalciteEnabled()
96+
? loadFromFile("expectedOutput/calcite/explain_sort_agg_push.json")
97+
: loadFromFile("expectedOutput/ppl/explain_sort_agg_push.json");
98+
99+
assertJsonEqualsIgnoreId(
100+
expected,
101+
explainQueryToString(
102+
"source=opensearch-sql_test_index_account"
103+
+ "| stats avg(age) AS avg_age by state, city "
104+
+ "| sort avg_age "));
105+
106+
// sorts whose by fields are not aggregators can be pushed down.
107+
// This test is covered in testExplain
108+
}
109+
92110
@Test
93111
public void testLimitPushDownExplain() throws IOException {
94112
String expected =
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"calcite": {
3-
"logical": "LogicalProject(age2=[$2])\n LogicalFilter(condition=[<=($3, 1)])\n LogicalProject(avg_age=[$0], state=[$1], age2=[$2], _row_number_=[ROW_NUMBER() OVER (PARTITION BY $2 ORDER BY $2)])\n LogicalFilter(condition=[IS NOT NULL($2)])\n LogicalProject(avg_age=[$0], state=[$1], age2=[+($0, 2)])\n LogicalSort(sort0=[$1], dir0=[ASC])\n LogicalProject(avg_age=[$2], state=[$0], city=[$1])\n LogicalAggregate(group=[{0, 1}], avg_age=[AVG($2)])\n LogicalProject(state=[$7], city=[$5], age=[$8])\n LogicalFilter(condition=[>($8, 30)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n",
4-
"physical": "EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1], expr#4=[<=($t2, $t3)], age2=[$t1], $condition=[$t4])\n EnumerableWindow(window#0=[window(partition {1} order by [1] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])\n EnumerableCalc(expr#0..2=[{inputs}], expr#3=[2], expr#4=[+($t2, $t3)], expr#5=[IS NOT NULL($t2)], state=[$t0], age2=[$t4], $condition=[$t5])\n EnumerableSort(sort0=[$0], dir0=[ASC])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[FILTER->>($8, 30), PROJECT->[state, city, age], AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 1},avg_age=AVG($2))], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"_source\":{\"includes\":[\"state\",\"city\",\"age\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n"
3+
"logical":"LogicalProject(age2=[$2])\n LogicalFilter(condition=[<=($3, 1)])\n LogicalProject(avg_age=[$0], state=[$1], age2=[$2], _row_number_=[ROW_NUMBER() OVER (PARTITION BY $2 ORDER BY $2)])\n LogicalFilter(condition=[IS NOT NULL($2)])\n LogicalProject(avg_age=[$0], state=[$1], age2=[+($0, 2)])\n LogicalSort(sort0=[$1], dir0=[ASC])\n LogicalProject(avg_age=[$2], state=[$0], city=[$1])\n LogicalAggregate(group=[{0, 1}], avg_age=[AVG($2)])\n LogicalProject(state=[$7], city=[$5], age=[$8])\n LogicalFilter(condition=[>($8, 30)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n",
4+
"physical":"EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1], expr#4=[<=($t2, $t3)], age2=[$t1], $condition=[$t4])\n EnumerableWindow(window#0=[window(partition {1} order by [1] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])\n EnumerableCalc(expr#0..2=[{inputs}], expr#3=[2], expr#4=[+($t2, $t3)], expr#5=[IS NOT NULL($t2)], state=[$t1], age2=[$t4], $condition=[$t5])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[city, state, age], FILTER->>($2, 30), AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 1},avg_age=AVG($2)), SORT->[{\n \"state\" : {\n \"order\" : \"asc\"\n }\n}]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"query\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"_source\":{\"includes\":[\"city\",\"state\",\"age\"],\"excludes\":[]},\"sort\":[{\"state\":{\"order\":\"asc\"}}],\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n"
55
}
6-
}
6+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"calcite": {
3+
"logical": "LogicalSort(sort0=[$0], dir0=[ASC])\n LogicalProject(avg_age=[$2], state=[$0], city=[$1])\n LogicalAggregate(group=[{0, 1}], avg_age=[AVG($2)])\n LogicalProject(state=[$7], city=[$5], age=[$8])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n",
4+
"physical": "EnumerableSort(sort0=[$0], dir0=[ASC])\n EnumerableCalc(expr#0..2=[{inputs}], avg_age=[$t2], state=[$t0], city=[$t1])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[state, city, age], AGGREGATION->rel#16745:LogicalAggregate.NONE.[](input=RelSubset#16744,group={0, 1},avg_age=AVG($2))], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"state\",\"city\",\"age\"],\"excludes\":[]},\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n"
5+
}
6+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
{
2+
"root": {
3+
"name": "ProjectOperator",
4+
"description": {
5+
"fields": "[avg_age, state, city]"
6+
},
7+
"children": [
8+
{
9+
"name": "SortOperator",
10+
"description": {
11+
"sortList": {
12+
"avg_age": {
13+
"sortOrder": "ASC",
14+
"nullOrder": "NULL_FIRST"
15+
}
16+
}
17+
},
18+
"children": [
19+
{
20+
"name": "OpenSearchIndexScan",
21+
"description": {
22+
"request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_account, sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"state\":{\"terms\":{\"field\":\"state.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}},{\"city\":{\"terms\":{\"field\":\"city.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"avg_age\":{\"avg\":{\"field\":\"age\"}}}}}}, needClean=true, searchDone=false, pitId=null, cursorKeepAlive=null, searchAfter=null, searchResponse=null)"
23+
},
24+
"children": []
25+
}
26+
]
27+
}
28+
]
29+
}
30+
}

opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchSortIndexScanRule.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,7 @@ public interface Config extends RelRule.Config {
4040
b0 ->
4141
b0.operand(LogicalSort.class)
4242
.predicate(OpenSearchIndexScanRule::sortByFieldsOnly)
43-
.oneInput(
44-
b1 ->
45-
b1.operand(CalciteLogicalIndexScan.class)
46-
.predicate(OpenSearchIndexScanRule::noAggregatePushed)
47-
.noInputs()));
43+
.oneInput(b1 -> b1.operand(CalciteLogicalIndexScan.class).noInputs()));
4844

4945
@Override
5046
default OpenSearchSortIndexScanRule toRule() {

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ public double estimateRowCount(RelMetadataQuery mq) {
9898
public static class PushDownContext extends ArrayDeque<PushDownAction> {
9999

100100
private boolean isAggregatePushed = false;
101-
private boolean isLimitPushed = false;
101+
@Getter private boolean isLimitPushed = false;
102102

103103
@Override
104104
public PushDownContext clone() {
@@ -107,8 +107,6 @@ public PushDownContext clone() {
107107

108108
@Override
109109
public boolean add(PushDownAction pushDownAction) {
110-
// Defense check. It should never do push down to this context after aggregate push-down.
111-
assert !isAggregatePushed : "Aggregate has already been pushed!";
112110
if (pushDownAction.type == PushDownType.AGGREGATION) {
113111
isAggregatePushed = true;
114112
}
@@ -123,10 +121,6 @@ public boolean isAggregatePushed() {
123121
isAggregatePushed = !isEmpty() && super.peekLast().type == PushDownType.AGGREGATION;
124122
return isAggregatePushed;
125123
}
126-
127-
public boolean isLimitPushed() {
128-
return isLimitPushed;
129-
}
130124
}
131125

132126
protected enum PushDownType {

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import java.util.Map;
1313
import java.util.Objects;
1414
import java.util.stream.Collectors;
15+
import java.util.stream.Stream;
1516
import lombok.Getter;
1617
import org.apache.calcite.plan.Convention;
1718
import org.apache.calcite.plan.RelOptCluster;
@@ -25,6 +26,7 @@
2526
import org.apache.calcite.rel.core.Aggregate;
2627
import org.apache.calcite.rel.core.Filter;
2728
import org.apache.calcite.rel.hint.RelHint;
29+
import org.apache.calcite.rel.logical.LogicalAggregate;
2830
import org.apache.calcite.rel.type.RelDataType;
2931
import org.apache.calcite.rel.type.RelDataTypeFactory;
3032
import org.apache.calcite.rel.type.RelDataTypeField;
@@ -237,6 +239,17 @@ public CalciteLogicalIndexScan pushDownLimit(Integer limit, Integer offset) {
237239

238240
public CalciteLogicalIndexScan pushDownSort(List<RelFieldCollation> collations) {
239241
try {
242+
List<String> collationNames =
243+
collations.stream()
244+
.map(RelFieldCollation::getFieldIndex)
245+
.map(index -> this.getRowType().getFieldNames().get(index))
246+
.collect(Collectors.toList());
247+
if (getPushDownContext().isAggregatePushed() && hasAggregatorInSortBy(collationNames)) {
248+
// If an aggregation has been pushed down, a sort whose sort-by fields include any
249+
// aggregator output cannot be pushed down.
250+
return null;
251+
}
252+
240253
// Merge with existing sort if any
241254
RelCollation existingCollation = getTraitSet().getCollation();
242255
List<RelFieldCollation> existingFieldCollations =
@@ -312,4 +325,35 @@ private static List<RelFieldCollation> mergeCollations(
312325
}
313326
return new ArrayList<>(mergedCollations.values());
314327
}
328+
329+
/**
330+
* Check if the sort by collations contains any aggregators that are pushed down. E.g. In `stats
331+
* avg(age) as avg_age by state | sort avg_age`, the sort clause has `avg_age` which is an
332+
* aggregator. The function will return true in this case.
333+
*
334+
* @param collations List of collation names to check against aggregators.
335+
* @return True if any collation name matches an aggregator output, false otherwise.
336+
*/
337+
private boolean hasAggregatorInSortBy(List<String> collations) {
338+
Stream<LogicalAggregate> aggregates =
339+
pushDownContext.stream()
340+
.filter(action -> action.type() == PushDownType.AGGREGATION)
341+
.map(action -> ((LogicalAggregate) action.digest()));
342+
return aggregates
343+
.map(aggregate -> isAnyCollationNameInAggregateOutput(aggregate, collations))
344+
.reduce(false, Boolean::logicalOr);
345+
}
346+
347+
private static boolean isAnyCollationNameInAggregateOutput(
348+
LogicalAggregate aggregate, List<String> collations) {
349+
List<String> fieldNames = aggregate.getRowType().getFieldNames();
350+
// The output fields of the aggregate are in the format of
351+
// [...grouping fields, ...aggregator fields], so we set an offset to skip
352+
// the grouping fields.
353+
int groupOffset = aggregate.getGroupSet().cardinality();
354+
List<String> fieldsWithoutGrouping = fieldNames.subList(groupOffset, fieldNames.size());
355+
return collations.stream()
356+
.map(fieldsWithoutGrouping::contains)
357+
.reduce(false, Boolean::logicalOr);
358+
}
315359
}

0 commit comments

Comments
 (0)