Skip to content

Commit 9f1ee08

Browse files
authored
Support pushdown sort by simple expressions (opensearch-project#4071)
* Support pushdown sort by simple expressions Signed-off-by: Songkan Tang <songkant@amazon.com> * Fix IT for no pushdown case Signed-off-by: Songkan Tang <songkant@amazon.com> * Add minor case to allow sort pushdown for casted floating number Signed-off-by: Songkan Tang <songkant@amazon.com> * Fix the issue of using wrong fromCollation Signed-off-by: Songkan Tang <songkant@amazon.com> * Add some unit tests for OpenSearchRelOptUtil Signed-off-by: Songkan Tang <songkant@amazon.com> * Fix checkstyle Signed-off-by: Songkan Tang <songkant@amazon.com> --------- Signed-off-by: Songkan Tang <songkant@amazon.com>
1 parent e2678a1 commit 9f1ee08

12 files changed

Lines changed: 860 additions & 1 deletion

File tree

integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,24 @@ public void testRegexNegatedExplain() throws IOException {
413413
assertJsonEqualsIgnoreId(expected, result);
414414
}
415415

416+
@Test
417+
public void testSimpleSortExpressionPushDownExplain() throws Exception {
418+
String query =
419+
"source=opensearch-sql_test_index_bank| eval age2 = age + 2 | sort age2 | fields age, age2";
420+
var result = explainQueryToString(query);
421+
String expected = loadExpectedPlan("explain_simple_sort_expr_push.json");
422+
assertJsonEqualsIgnoreId(expected, result);
423+
}
424+
425+
@Test
426+
public void testSimpleSortExpressionPushDownWithOnlyExprProjected() throws Exception {
427+
String query =
428+
"source=opensearch-sql_test_index_bank| eval b = balance + 1 | sort b | fields b";
429+
var result = explainQueryToString(query);
430+
String expected = loadExpectedPlan("explain_simple_sort_expr_single_expr_output_push.json");
431+
assertJsonEqualsIgnoreId(expected, result);
432+
}
433+
416434
/**
417435
* Executes the PPL query and returns the result as a string with windows-style line breaks
418436
* replaced with Unix-style ones.

integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteSortCommandIT.java

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77

88
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK;
99
import static org.opensearch.sql.util.MatcherUtils.rows;
10+
import static org.opensearch.sql.util.MatcherUtils.schema;
1011
import static org.opensearch.sql.util.MatcherUtils.verifyOrder;
12+
import static org.opensearch.sql.util.MatcherUtils.verifySchema;
1113

1214
import java.io.IOException;
15+
import java.util.Locale;
1316
import org.json.JSONObject;
1417
import org.junit.Test;
1518
import org.opensearch.sql.ppl.SortCommandIT;
@@ -28,4 +31,148 @@ public void testHeadThenSort() throws IOException {
2831
executeQuery(String.format("source=%s | head 2 | sort age | fields age", TEST_INDEX_BANK));
2932
verifyOrder(result, rows(32), rows(36));
3033
}
34+
35+
@Test
36+
public void testPushdownSortPlusExpression() throws IOException {
37+
String ppl =
38+
String.format(
39+
Locale.ROOT,
40+
"source=%s | eval age2 = age + 2 | sort age2 | fields age | head 2",
41+
TEST_INDEX_BANK);
42+
String explained = explainQueryToString(ppl);
43+
if (isPushdownEnabled()) {
44+
assertTrue(
45+
explained.contains(
46+
"[SORT->[{\\n"
47+
+ " \\\"age\\\" : {\\n"
48+
+ " \\\"order\\\" : \\\"asc\\\",\\n"
49+
+ " \\\"missing\\\" : \\\"_first\\\"\\n"
50+
+ " }\\n"
51+
+ "}]"));
52+
}
53+
54+
JSONObject result = executeQuery(ppl);
55+
verifyOrder(result, rows(28), rows(32));
56+
}
57+
58+
@Test
59+
public void testPushdownSortMinusExpression() throws IOException {
60+
String ppl =
61+
String.format(
62+
Locale.ROOT,
63+
"source=%s | eval age2 = 1 - age | sort age2 | fields age | head 2",
64+
TEST_INDEX_BANK);
65+
String explained = explainQueryToString(ppl);
66+
if (isPushdownEnabled()) {
67+
assertTrue(
68+
explained.contains(
69+
"[SORT->[{\\n"
70+
+ " \\\"age\\\" : {\\n"
71+
+ " \\\"order\\\" : \\\"desc\\\",\\n"
72+
+ " \\\"missing\\\" : \\\"_first\\\"\\n"
73+
+ " }\\n"
74+
+ "}]"));
75+
}
76+
77+
JSONObject result = executeQuery(ppl);
78+
verifyOrder(result, rows(39), rows(36));
79+
}
80+
81+
@Test
82+
public void testPushdownSortTimesExpression() throws IOException {
83+
String ppl =
84+
String.format(
85+
Locale.ROOT,
86+
"source=%s | eval age2 = 5 * age | sort age2 | fields age | head 2",
87+
TEST_INDEX_BANK);
88+
String explained = explainQueryToString(ppl);
89+
if (isPushdownEnabled()) {
90+
assertTrue(
91+
explained.contains(
92+
"[SORT->[{\\n"
93+
+ " \\\"age\\\" : {\\n"
94+
+ " \\\"order\\\" : \\\"asc\\\",\\n"
95+
+ " \\\"missing\\\" : \\\"_first\\\"\\n"
96+
+ " }\\n"
97+
+ "}]"));
98+
}
99+
100+
JSONObject result = executeQuery(ppl);
101+
verifyOrder(result, rows(28), rows(32));
102+
}
103+
104+
@Test
105+
public void testPushdownSortByMultiExpressions() throws IOException {
106+
String ppl =
107+
String.format(
108+
Locale.ROOT,
109+
"source=%s | eval age2 = 5 * age | sort gender, age2 | fields gender, age | head 2",
110+
TEST_INDEX_BANK);
111+
String explained = explainQueryToString(ppl);
112+
if (isPushdownEnabled()) {
113+
assertTrue(
114+
explained.contains(
115+
"[SORT->[{\\n"
116+
+ " \\\"gender.keyword\\\" : {\\n"
117+
+ " \\\"order\\\" : \\\"asc\\\",\\n"
118+
+ " \\\"missing\\\" : \\\"_first\\\"\\n"
119+
+ " }\\n"
120+
+ "}, {\\n"
121+
+ " \\\"age\\\" : {\\n"
122+
+ " \\\"order\\\" : \\\"asc\\\",\\n"
123+
+ " \\\"missing\\\" : \\\"_first\\\"\\n"
124+
+ " }\\n"
125+
+ "}]"));
126+
}
127+
128+
JSONObject result = executeQuery(ppl);
129+
verifyOrder(result, rows("F", 28), rows("F", 34));
130+
}
131+
132+
@Test
133+
public void testPushdownSortCastExpression() throws IOException {
134+
String ppl =
135+
String.format(
136+
Locale.ROOT,
137+
"source=%s | eval age2 = cast(age * 5 as long) | sort age2 | fields age | head 2",
138+
TEST_INDEX_BANK);
139+
String explained = explainQueryToString(ppl);
140+
if (isPushdownEnabled()) {
141+
assertTrue(
142+
explained.contains(
143+
"[SORT->[{\\n"
144+
+ " \\\"age\\\" : {\\n"
145+
+ " \\\"order\\\" : \\\"asc\\\",\\n"
146+
+ " \\\"missing\\\" : \\\"_first\\\"\\n"
147+
+ " }\\n"
148+
+ "}]"));
149+
}
150+
151+
JSONObject result = executeQuery(ppl);
152+
verifyOrder(result, rows(28), rows(32));
153+
}
154+
155+
@Test
156+
public void testPushdownSortCastToDoubleExpression() throws IOException {
157+
// Similar to query: 'source=%s | sort num(age)'. But left query doesn't output casted column.
158+
String ppl =
159+
String.format(
160+
"source=%s | eval age2 = cast(age as double) | sort age2 | fields age, age2 | head 2",
161+
TEST_INDEX_BANK);
162+
String explained = explainQueryToString(ppl);
163+
if (isPushdownEnabled()) {
164+
assertTrue(
165+
explained.contains(
166+
"SORT->[{\\n"
167+
+ " \\\"age\\\" : {\\n"
168+
+ " \\\"order\\\" : \\\"asc\\\",\\n"
169+
+ " \\\"missing\\\" : \\\"_first\\\"\\n"
170+
+ " }\\n"
171+
+ "}]"));
172+
}
173+
174+
JSONObject result = executeQuery(ppl);
175+
verifySchema(result, schema("age", "int"), schema("age2", "double"));
176+
verifyOrder(result, rows(28, 28d), rows(32, 32d));
177+
}
31178
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"calcite": {
3+
"logical": "LogicalSystemLimit(sort0=[$1], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(age=[$10], age2=[$19])\n LogicalSort(sort0=[$19], dir0=[ASC-nulls-first])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12], _id=[$13], _index=[$14], _score=[$15], _maxscore=[$16], _sort=[$17], _routing=[$18], age2=[+($10, 2)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n",
4+
"physical": "EnumerableCalc(expr#0=[{inputs}], expr#1=[2], expr#2=[+($t0, $t1)], age=[$t0], $f1=[$t2])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]], PushDownContext=[[SORT->[{\n \"age\" : {\n \"order\" : \"asc\",\n \"missing\" : \"_first\"\n }\n}], LIMIT->10000, PROJECT->[age]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"age\"],\"excludes\":[]},\"sort\":[{\"age\":{\"order\":\"asc\",\"missing\":\"_first\"}}]}, requestedTotalSize=10000, pageSize=null, startFrom=0)])\n"
5+
}
6+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"calcite": {
3+
"logical": "LogicalSystemLimit(sort0=[$0], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(b=[$19])\n LogicalSort(sort0=[$19], dir0=[ASC-nulls-first])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12], _id=[$13], _index=[$14], _score=[$15], _maxscore=[$16], _sort=[$17], _routing=[$18], b=[+($7, 1)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n",
4+
"physical": "EnumerableCalc(expr#0=[{inputs}], expr#1=[1], expr#2=[+($t0, $t1)], $f0=[$t2])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]], PushDownContext=[[SORT->[{\n \"balance\" : {\n \"order\" : \"asc\",\n \"missing\" : \"_first\"\n }\n}], LIMIT->10000, PROJECT->[balance]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":10000,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"balance\"],\"excludes\":[]},\"sort\":[{\"balance\":{\"order\":\"asc\",\"missing\":\"_first\"}}]}, requestedTotalSize=10000, pageSize=null, startFrom=0)])\n"
5+
}
6+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"calcite": {
3+
"logical": "LogicalSystemLimit(sort0=[$1], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(age=[$10], age2=[$19])\n LogicalSort(sort0=[$19], dir0=[ASC-nulls-first])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12], _id=[$13], _index=[$14], _score=[$15], _maxscore=[$16], _sort=[$17], _routing=[$18], age2=[+($10, 2)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n",
4+
"physical": "EnumerableLimit(fetch=[10000])\n EnumerableSort(sort0=[$1], dir0=[ASC-nulls-first])\n EnumerableCalc(expr#0..18=[{inputs}], expr#19=[2], expr#20=[+($t10, $t19)], age=[$t10], age2=[$t20])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n"
5+
}
6+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"calcite": {
3+
"logical": "LogicalSystemLimit(sort0=[$0], dir0=[ASC-nulls-first], fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(b=[$19])\n LogicalSort(sort0=[$19], dir0=[ASC-nulls-first])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], birthdate=[$3], gender=[$4], city=[$5], lastname=[$6], balance=[$7], employer=[$8], state=[$9], age=[$10], email=[$11], male=[$12], _id=[$13], _index=[$14], _score=[$15], _maxscore=[$16], _sort=[$17], _routing=[$18], b=[+($7, 1)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n",
4+
"physical": "EnumerableLimit(fetch=[10000])\n EnumerableSort(sort0=[$0], dir0=[ASC-nulls-first])\n EnumerableCalc(expr#0..18=[{inputs}], expr#19=[1], expr#20=[+($t7, $t19)], b=[$t20])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n"
5+
}
6+
}
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.opensearch.planner.physical;
7+
8+
import java.util.Optional;
9+
import org.apache.calcite.adapter.enumerable.EnumerableProject;
10+
import org.apache.calcite.plan.RelOptRuleCall;
11+
import org.apache.calcite.plan.RelRule;
12+
import org.apache.calcite.plan.RelTrait;
13+
import org.apache.calcite.plan.RelTraitSet;
14+
import org.apache.calcite.plan.volcano.AbstractConverter;
15+
import org.apache.calcite.rel.RelCollation;
16+
import org.apache.calcite.rel.RelCollationTraitDef;
17+
import org.apache.calcite.rel.RelFieldCollation;
18+
import org.apache.calcite.rel.RelFieldCollation.Direction;
19+
import org.apache.calcite.rel.core.Project;
20+
import org.apache.commons.lang3.tuple.Pair;
21+
import org.immutables.value.Value;
22+
import org.opensearch.sql.opensearch.util.OpenSearchRelOptUtil;
23+
24+
/**
25+
* When ENUMERABLE convention physical node is converted from logical node, each enumerable node's
26+
* collation is recalculated based on input collations. However, if SortProjectExprTransposeRule
27+
* takes effect, the input collation is changed to a sort over field instead of original sort over
28+
* expression. It changes the collation requirement of the whole query.
29+
*
30+
* <p>AbstractConverter physical node is supposed to resolve the problem of inconsistent collation
31+
* requirement between physical node input and output. This optimization rule finds equivalent
32+
* output expression collations and input field collations. If their collation traits are satisfied,
33+
* generate a new RelSubset without top sort
34+
*/
35+
@Value.Enclosing
36+
public class ExpandCollationOnProjectExprRule
37+
extends RelRule<ExpandCollationOnProjectExprRule.Config> {
38+
39+
protected ExpandCollationOnProjectExprRule(Config config) {
40+
super(config);
41+
}
42+
43+
@Override
44+
public void onMatch(RelOptRuleCall call) {
45+
final AbstractConverter converter = call.rel(0);
46+
final Project project = call.rel(1);
47+
final RelTraitSet toTraits = converter.getTraitSet();
48+
final RelCollation toCollation = toTraits.getTrait(RelCollationTraitDef.INSTANCE);
49+
final RelTrait fromTrait =
50+
project.getInput().getTraitSet().getTrait(RelCollationTraitDef.INSTANCE);
51+
// In case of fromTrait is an instance of RelCompositeTrait, it most likely finds equivalence by
52+
// default.
53+
// Let it go through default ExpandConversionRule to determine trait satisfaction.
54+
if (fromTrait != null && fromTrait instanceof RelCollation) {
55+
RelCollation fromCollation = (RelCollation) fromTrait;
56+
// TODO: Handle the case where multi expr collations are mapped to the same source field
57+
if (toCollation == null
58+
|| toCollation.getFieldCollations().isEmpty()
59+
|| fromCollation == null
60+
|| fromCollation.getFieldCollations().size() < toCollation.getFieldCollations().size()) {
61+
return;
62+
}
63+
64+
for (int i = 0; i < toCollation.getFieldCollations().size(); i++) {
65+
RelFieldCollation targetFieldCollation = toCollation.getFieldCollations().get(i);
66+
Optional<Pair<Integer, Boolean>> equivalentCollationInputInfo =
67+
OpenSearchRelOptUtil.getOrderEquivalentInputInfo(
68+
project.getProjects().get(targetFieldCollation.getFieldIndex()));
69+
70+
if (equivalentCollationInputInfo.isEmpty()) {
71+
return;
72+
}
73+
74+
RelFieldCollation sourceFieldCollation = fromCollation.getFieldCollations().get(i);
75+
int equivalentSourceIndex = equivalentCollationInputInfo.get().getLeft();
76+
Direction equivalentSourceDirection =
77+
equivalentCollationInputInfo.get().getRight()
78+
? targetFieldCollation.getDirection().reverse()
79+
: targetFieldCollation.getDirection();
80+
if (!(equivalentSourceIndex == sourceFieldCollation.getFieldIndex()
81+
&& equivalentSourceDirection == sourceFieldCollation.getDirection())) {
82+
return;
83+
}
84+
}
85+
86+
// After collation equivalence analysis, fromTrait satisfies toTrait. Copy the target trait
87+
// set
88+
// to new EnumerableProject.
89+
Project newProject =
90+
project.copy(toTraits, project.getInput(), project.getProjects(), project.getRowType());
91+
call.transformTo(newProject);
92+
}
93+
}
94+
95+
@Value.Immutable
96+
public interface Config extends RelRule.Config {
97+
98+
/**
99+
* Only match ENUMERABLE convention RelNode combination like below to narrow the optimization
100+
* searching space: - AbstractConverter - EnumerableProject
101+
*/
102+
ExpandCollationOnProjectExprRule.Config DEFAULT =
103+
ImmutableExpandCollationOnProjectExprRule.Config.builder()
104+
.build()
105+
.withOperandSupplier(
106+
b0 ->
107+
b0.operand(AbstractConverter.class)
108+
.oneInput(
109+
b1 ->
110+
b1.operand(EnumerableProject.class)
111+
.predicate(OpenSearchIndexScanRule::projectContainsExpr)
112+
.predicate(p -> !p.containsOver())
113+
.anyInputs()));
114+
115+
@Override
116+
default ExpandCollationOnProjectExprRule toRule() {
117+
return new ExpandCollationOnProjectExprRule(this);
118+
}
119+
}
120+
}

opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ public class OpenSearchIndexRules {
2424
OpenSearchSortIndexScanRule.Config.DEFAULT.toRule();
2525
private static final OpenSearchDedupPushdownRule DEDUP_PUSH_DOWN =
2626
OpenSearchDedupPushdownRule.Config.DEFAULT.toRule();
27+
private static final SortProjectExprTransposeRule SORT_PROJECT_EXPR_TRANSPOSE =
28+
SortProjectExprTransposeRule.Config.DEFAULT.toRule();
29+
private static final ExpandCollationOnProjectExprRule EXPAND_COLLATION_ON_PROJECT_EXPR =
30+
ExpandCollationOnProjectExprRule.Config.DEFAULT.toRule();
2731

2832
public static final List<RelOptRule> OPEN_SEARCH_INDEX_SCAN_RULES =
2933
ImmutableList.of(
@@ -33,7 +37,9 @@ public class OpenSearchIndexRules {
3337
COUNT_STAR_INDEX_SCAN,
3438
LIMIT_INDEX_SCAN,
3539
SORT_INDEX_SCAN,
36-
DEDUP_PUSH_DOWN);
40+
DEDUP_PUSH_DOWN,
41+
SORT_PROJECT_EXPR_TRANSPOSE,
42+
EXPAND_COLLATION_ON_PROJECT_EXPR);
3743

3844
// prevent instantiation
3945
private OpenSearchIndexRules() {}

opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexScanRule.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
import java.util.HashSet;
99
import java.util.Set;
1010
import org.apache.calcite.plan.RelOptTable;
11+
import org.apache.calcite.rel.core.Project;
1112
import org.apache.calcite.rel.core.Sort;
1213
import org.apache.calcite.rel.logical.LogicalProject;
1314
import org.apache.calcite.rel.logical.LogicalSort;
15+
import org.apache.calcite.rex.RexCall;
1416
import org.apache.calcite.rex.RexNode;
1517
import org.apache.calcite.rex.RexOver;
1618
import org.opensearch.sql.opensearch.storage.OpenSearchIndex;
@@ -59,6 +61,10 @@ static boolean isLogicalSortLimit(LogicalSort sort) {
5961
return sort.fetch != null;
6062
}
6163

64+
static boolean projectContainsExpr(Project project) {
65+
return project.getProjects().stream().anyMatch(p -> p instanceof RexCall);
66+
}
67+
6268
static boolean sortByFieldsOnly(Sort sort) {
6369
return !sort.getCollation().getFieldCollations().isEmpty() && sort.fetch == null;
6470
}

0 commit comments

Comments
 (0)