|
5 | 5 |
|
6 | 6 | package org.opensearch.sql.calcite.utils; |
7 | 7 |
|
| 8 | +import java.util.ArrayList; |
8 | 9 | import java.util.LinkedHashSet; |
9 | 10 | import java.util.List; |
10 | 11 | import java.util.Set; |
| 12 | +import org.apache.calcite.linq4j.Ord; |
11 | 13 | import org.apache.calcite.plan.RelOptUtil; |
12 | 14 | import org.apache.calcite.rel.RelNode; |
| 15 | +import org.apache.calcite.rel.core.Aggregate; |
| 16 | +import org.apache.calcite.rel.core.CorrelationId; |
| 17 | +import org.apache.calcite.rel.core.Project; |
| 18 | +import org.apache.calcite.rel.core.Values; |
13 | 19 | import org.apache.calcite.rel.type.RelDataType; |
14 | 20 | import org.apache.calcite.rel.type.RelDataTypeField; |
| 21 | +import org.apache.calcite.rex.RexLiteral; |
15 | 22 | import org.apache.calcite.rex.RexNode; |
16 | 23 | import org.apache.calcite.rex.RexPermuteInputsShuttle; |
| 24 | +import org.apache.calcite.rex.RexSubQuery; |
| 25 | +import org.apache.calcite.rex.RexUtil; |
17 | 26 | import org.apache.calcite.rex.RexVisitor; |
18 | 27 | import org.apache.calcite.sql.validate.SqlValidator; |
19 | 28 | import org.apache.calcite.sql2rel.RelFieldTrimmer; |
20 | 29 | import org.apache.calcite.tools.RelBuilder; |
21 | 30 | import org.apache.calcite.util.ImmutableBitSet; |
22 | 31 | import org.apache.calcite.util.mapping.Mapping; |
| 32 | +import org.apache.calcite.util.mapping.MappingType; |
23 | 33 | import org.apache.calcite.util.mapping.Mappings; |
24 | 34 | import org.checkerframework.checker.nullness.qual.Nullable; |
25 | 35 | import org.opensearch.sql.calcite.plan.rel.Dedup; |
|
30 | 40 | * <p>This class extends Calcite's RelFieldTrimmer to support trimming customized operators. |
31 | 41 | */ |
32 | 42 | public class OpenSearchRelFieldTrimmer extends RelFieldTrimmer { |
| 43 | + private final RelBuilder openSearchRelBuilder; |
33 | 44 |
|
34 | 45 | public OpenSearchRelFieldTrimmer(@Nullable SqlValidator validator, RelBuilder relBuilder) { |
35 | 46 | super(validator, relBuilder); |
| 47 | + this.openSearchRelBuilder = relBuilder; |
| 48 | + } |
| 49 | + |
| 50 | + @Override |
| 51 | + public TrimResult trimFields( |
| 52 | + Project project, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) { |
| 53 | + final RelDataType rowType = project.getRowType(); |
| 54 | + final int fieldCount = rowType.getFieldCount(); |
| 55 | + final RelNode input = project.getInput(); |
| 56 | + |
| 57 | + final Set<RelDataTypeField> inputExtraFields = new LinkedHashSet<>(extraFields); |
| 58 | + RelOptUtil.InputFinder inputFinder = new RelOptUtil.InputFinder(inputExtraFields); |
| 59 | + for (Ord<RexNode> ord : Ord.zip(project.getProjects())) { |
| 60 | + if (fieldsUsed.get(ord.i)) { |
| 61 | + ord.e.accept(inputFinder); |
| 62 | + } |
| 63 | + } |
| 64 | + |
| 65 | + List<RexSubQuery> subQueries = RexUtil.SubQueryCollector.collect(project); |
| 66 | + Set<CorrelationId> correlationIds = RelOptUtil.getVariablesUsed(subQueries); |
| 67 | + ImmutableBitSet requiredColumns = ImmutableBitSet.of(); |
| 68 | + if (!correlationIds.isEmpty()) { |
| 69 | + assert correlationIds.size() == 1; |
| 70 | + requiredColumns = RelOptUtil.correlationColumns(correlationIds.iterator().next(), project); |
| 71 | + } |
| 72 | + |
| 73 | + ImmutableBitSet finderFields = inputFinder.build(); |
| 74 | + ImmutableBitSet inputFieldsUsed = |
| 75 | + ImmutableBitSet.builder().addAll(requiredColumns).addAll(finderFields).build(); |
| 76 | + |
| 77 | + TrimResult trimResult = trimChild(project, input, inputFieldsUsed, inputExtraFields); |
| 78 | + RelNode newInput = trimResult.left; |
| 79 | + final Mapping inputMapping = trimResult.right; |
| 80 | + |
| 81 | + if (newInput == input && fieldsUsed.cardinality() == fieldCount) { |
| 82 | + return result(project, Mappings.createIdentity(fieldCount)); |
| 83 | + } |
| 84 | + |
| 85 | + if (fieldsUsed.cardinality() == 0) { |
| 86 | + return dummyProject(fieldCount, newInput, project); |
| 87 | + } |
| 88 | + |
| 89 | + final List<RexNode> newProjects = new ArrayList<>(); |
| 90 | + final RexVisitor<RexNode> shuttle; |
| 91 | + if (!correlationIds.isEmpty()) { |
| 92 | + assert correlationIds.size() == 1; |
| 93 | + shuttle = |
| 94 | + new RexPermuteInputsShuttle(inputMapping, newInput) { |
| 95 | + @Override |
| 96 | + public RexNode visitSubQuery(RexSubQuery subQuery) { |
| 97 | + subQuery = (RexSubQuery) super.visitSubQuery(subQuery); |
| 98 | + return RelOptUtil.remapCorrelatesInSuqQuery( |
| 99 | + openSearchRelBuilder.getRexBuilder(), |
| 100 | + subQuery, |
| 101 | + correlationIds.iterator().next(), |
| 102 | + newInput.getRowType(), |
| 103 | + inputMapping); |
| 104 | + } |
| 105 | + }; |
| 106 | + } else { |
| 107 | + shuttle = new RexPermuteInputsShuttle(inputMapping, newInput); |
| 108 | + } |
| 109 | + |
| 110 | + final Mapping mapping = |
| 111 | + Mappings.create(MappingType.INVERSE_SURJECTION, fieldCount, fieldsUsed.cardinality()); |
| 112 | + for (Ord<RexNode> ord : Ord.zip(project.getProjects())) { |
| 113 | + if (fieldsUsed.get(ord.i)) { |
| 114 | + mapping.set(ord.i, newProjects.size()); |
| 115 | + RexNode newProjectExpr = ord.e.accept(shuttle); |
| 116 | + newProjects.add(newProjectExpr); |
| 117 | + } |
| 118 | + } |
| 119 | + |
| 120 | + final RelDataType newRowType = |
| 121 | + RelOptUtil.permute(project.getCluster().getTypeFactory(), rowType, mapping); |
| 122 | + |
| 123 | + if (shouldAvoidSimplifyValues(newProjects, newInput)) { |
| 124 | + return result( |
| 125 | + project.copy(project.getTraitSet(), newInput, newProjects, newRowType), mapping, project); |
| 126 | + } |
| 127 | + |
| 128 | + openSearchRelBuilder.push(newInput); |
| 129 | + openSearchRelBuilder.project(newProjects, newRowType.getFieldNames(), false, correlationIds); |
| 130 | + return result(openSearchRelBuilder.build(), mapping, project); |
36 | 131 | } |
37 | 132 |
|
38 | 133 | public TrimResult trimFields( |
@@ -67,4 +162,19 @@ public TrimResult trimFields( |
67 | 162 | // needs them for its condition. |
68 | 163 | return result(dedup.copy(newInput, newDedupFields), inputMapping); |
69 | 164 | } |
| 165 | + |
| 166 | + private boolean shouldAvoidSimplifyValues(List<RexNode> projects, RelNode input) { |
| 167 | + return projects.stream().allMatch(RexLiteral.class::isInstance) && isFixedRowCount(input); |
| 168 | + } |
| 169 | + |
| 170 | + private boolean isFixedRowCount(RelNode input) { |
| 171 | + if (input instanceof Values) { |
| 172 | + return true; |
| 173 | + } |
| 174 | + if (input instanceof Aggregate aggregate) { |
| 175 | + return aggregate.getGroupSet().isEmpty() |
| 176 | + && aggregate.getGroupType() == Aggregate.Group.SIMPLE; |
| 177 | + } |
| 178 | + return false; |
| 179 | + } |
70 | 180 | } |
0 commit comments