Skip to content

Commit c4ae5ca

Browse files
committed
Merge multiple sort when occurs
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent fb380de commit c4ae5ca

3 files changed

Lines changed: 43 additions & 4 deletions

File tree

opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchSortIndexScanRule.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public interface Config extends RelRule.Config {
4343
.oneInput(
4444
b1 ->
4545
b1.operand(CalciteLogicalIndexScan.class)
46-
.predicate(OpenSearchIndexScanRule::test)
46+
.predicate(OpenSearchIndexScanRule::noAggregatePushed)
4747
.noInputs()));
4848

4949
@Override

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ public double estimateRowCount(RelMetadataQuery mq) {
8585
(rowCount, action) ->
8686
switch (action.type) {
8787
case AGGREGATION -> mq.getRowCount((RelNode) action.digest);
88-
case PROJECT -> rowCount;
88+
case PROJECT, SORT -> rowCount;
8989
case FILTER -> NumberUtil.multiply(
9090
rowCount, RelMdUtil.guessSelectivity((RexNode) action.digest));
9191
case LIMIT -> (Integer) action.digest;

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,20 @@
77

88
import com.google.common.collect.ImmutableList;
99
import java.util.ArrayList;
10+
import java.util.LinkedList;
1011
import java.util.List;
1112
import java.util.Map;
13+
import java.util.OptionalInt;
1214
import java.util.stream.Collectors;
15+
import java.util.stream.IntStream;
1316
import lombok.Getter;
1417
import org.apache.calcite.plan.Convention;
1518
import org.apache.calcite.plan.RelOptCluster;
1619
import org.apache.calcite.plan.RelOptPlanner;
1720
import org.apache.calcite.plan.RelOptRule;
1821
import org.apache.calcite.plan.RelOptTable;
1922
import org.apache.calcite.plan.RelTraitSet;
23+
import org.apache.calcite.rel.RelCollation;
2024
import org.apache.calcite.rel.RelCollations;
2125
import org.apache.calcite.rel.RelFieldCollation;
2226
import org.apache.calcite.rel.core.Aggregate;
@@ -199,8 +203,15 @@ public CalciteLogicalIndexScan pushDownLimit(Integer limit, Integer offset) {
199203

200204
public CalciteLogicalIndexScan pushDownSort(List<RelFieldCollation> collations) {
201205
try {
206+
// Merge with existing sort if any
207+
RelCollation existingCollation = getTraitSet().getCollation();
208+
List<RelFieldCollation> existingFieldCollations =
209+
existingCollation == null ? List.of() : existingCollation.getFieldCollations();
210+
List<RelFieldCollation> mergedCollations =
211+
mergeCollations(existingFieldCollations, collations);
212+
202213
// Propagate the sort to the new scan
203-
RelTraitSet newTraitSet = getTraitSet().plus(RelCollations.of(collations));
214+
RelTraitSet newTraitSet = getTraitSet().plus(RelCollations.of(mergedCollations));
204215
CalciteLogicalIndexScan newScan =
205216
new CalciteLogicalIndexScan(
206217
getCluster(),
@@ -212,7 +223,7 @@ public CalciteLogicalIndexScan pushDownSort(List<RelFieldCollation> collations)
212223
pushDownContext.clone());
213224

214225
List<SortBuilder<?>> builders = new ArrayList<>();
215-
for (RelFieldCollation collation : collations) {
226+
for (RelFieldCollation collation : mergedCollations) {
216227
int index = collation.getFieldIndex();
217228
String fieldName = this.getRowType().getFieldNames().get(index);
218229
RelFieldCollation.Direction direction = collation.getDirection();
@@ -245,4 +256,32 @@ public CalciteLogicalIndexScan pushDownSort(List<RelFieldCollation> collations)
245256
}
246257
return null;
247258
}
259+
260+
/**
261+
* Merges existing and new collations, ensuring that the last occurrence of each field index takes
262+
* precedence.
263+
*
264+
* @param existingCollations Existing collation list.
265+
* @param newCollations New collation list to be merged.
266+
* @return Merged list of collations.
267+
*/
268+
private static List<RelFieldCollation> mergeCollations(
269+
List<RelFieldCollation> existingCollations, List<RelFieldCollation> newCollations) {
270+
List<RelFieldCollation> concatenatedCollations = new ArrayList<>(existingCollations);
271+
concatenatedCollations.addAll(newCollations);
272+
LinkedList<RelFieldCollation> mergedCollations = new LinkedList<>();
273+
for (RelFieldCollation collation : concatenatedCollations) {
274+
// If the collation is already in the merged list, remove it from the list before adding
275+
// This is because the sort that comes later in the list should take precedence
276+
OptionalInt index =
277+
IntStream.range(0, mergedCollations.size())
278+
.filter(i -> mergedCollations.get(i).getFieldIndex() == collation.getFieldIndex())
279+
.findFirst();
280+
if (index.isPresent()) {
281+
mergedCollations.remove(index.getAsInt());
282+
}
283+
mergedCollations.add(collation);
284+
}
285+
return mergedCollations;
286+
}
248287
}

0 commit comments

Comments
 (0)