|
5 | 5 |
|
6 | 6 | package org.opensearch.sql.calcite.utils; |
7 | 7 |
|
8 | | -import java.util.ArrayList; |
9 | 8 | import java.util.LinkedHashSet; |
10 | 9 | import java.util.List; |
11 | 10 | import java.util.Set; |
12 | | -import org.apache.calcite.linq4j.Ord; |
13 | 11 | import org.apache.calcite.plan.RelOptUtil; |
14 | 12 | import org.apache.calcite.rel.RelNode; |
15 | | -import org.apache.calcite.rel.core.Aggregate; |
16 | | -import org.apache.calcite.rel.core.CorrelationId; |
17 | | -import org.apache.calcite.rel.core.Project; |
18 | | -import org.apache.calcite.rel.core.Values; |
19 | 13 | import org.apache.calcite.rel.type.RelDataType; |
20 | 14 | import org.apache.calcite.rel.type.RelDataTypeField; |
21 | | -import org.apache.calcite.rex.RexLiteral; |
22 | 15 | import org.apache.calcite.rex.RexNode; |
23 | 16 | import org.apache.calcite.rex.RexPermuteInputsShuttle; |
24 | | -import org.apache.calcite.rex.RexSubQuery; |
25 | | -import org.apache.calcite.rex.RexUtil; |
26 | 17 | import org.apache.calcite.rex.RexVisitor; |
27 | 18 | import org.apache.calcite.sql.validate.SqlValidator; |
28 | 19 | import org.apache.calcite.sql2rel.RelFieldTrimmer; |
29 | 20 | import org.apache.calcite.tools.RelBuilder; |
30 | 21 | import org.apache.calcite.util.ImmutableBitSet; |
31 | 22 | import org.apache.calcite.util.mapping.Mapping; |
32 | | -import org.apache.calcite.util.mapping.MappingType; |
33 | 23 | import org.apache.calcite.util.mapping.Mappings; |
34 | 24 | import org.checkerframework.checker.nullness.qual.Nullable; |
35 | 25 | import org.opensearch.sql.calcite.plan.rel.Dedup; |
|
40 | 30 | * <p>This class extends Calcite's RelFieldTrimmer to support trimming customized operators. |
41 | 31 | */ |
42 | 32 | public class OpenSearchRelFieldTrimmer extends RelFieldTrimmer { |
43 | | - private final RelBuilder openSearchRelBuilder; |
44 | 33 |
|
45 | 34 | public OpenSearchRelFieldTrimmer(@Nullable SqlValidator validator, RelBuilder relBuilder) { |
46 | 35 | super(validator, relBuilder); |
47 | | - this.openSearchRelBuilder = relBuilder; |
48 | | - } |
49 | | - |
50 | | - @Override |
51 | | - public TrimResult trimFields( |
52 | | - Project project, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) { |
53 | | - final RelDataType rowType = project.getRowType(); |
54 | | - final int fieldCount = rowType.getFieldCount(); |
55 | | - final RelNode input = project.getInput(); |
56 | | - |
57 | | - final Set<RelDataTypeField> inputExtraFields = new LinkedHashSet<>(extraFields); |
58 | | - RelOptUtil.InputFinder inputFinder = new RelOptUtil.InputFinder(inputExtraFields); |
59 | | - for (Ord<RexNode> ord : Ord.zip(project.getProjects())) { |
60 | | - if (fieldsUsed.get(ord.i)) { |
61 | | - ord.e.accept(inputFinder); |
62 | | - } |
63 | | - } |
64 | | - |
65 | | - List<RexSubQuery> subQueries = RexUtil.SubQueryCollector.collect(project); |
66 | | - Set<CorrelationId> correlationIds = RelOptUtil.getVariablesUsed(subQueries); |
67 | | - ImmutableBitSet requiredColumns = ImmutableBitSet.of(); |
68 | | - if (!correlationIds.isEmpty()) { |
69 | | - assert correlationIds.size() == 1; |
70 | | - requiredColumns = RelOptUtil.correlationColumns(correlationIds.iterator().next(), project); |
71 | | - } |
72 | | - |
73 | | - ImmutableBitSet finderFields = inputFinder.build(); |
74 | | - ImmutableBitSet inputFieldsUsed = |
75 | | - ImmutableBitSet.builder().addAll(requiredColumns).addAll(finderFields).build(); |
76 | | - |
77 | | - TrimResult trimResult = trimChild(project, input, inputFieldsUsed, inputExtraFields); |
78 | | - RelNode newInput = trimResult.left; |
79 | | - final Mapping inputMapping = trimResult.right; |
80 | | - |
81 | | - if (newInput == input && fieldsUsed.cardinality() == fieldCount) { |
82 | | - return result(project, Mappings.createIdentity(fieldCount)); |
83 | | - } |
84 | | - |
85 | | - if (fieldsUsed.cardinality() == 0) { |
86 | | - return dummyProject(fieldCount, newInput, project); |
87 | | - } |
88 | | - |
89 | | - final List<RexNode> newProjects = new ArrayList<>(); |
90 | | - final RexVisitor<RexNode> shuttle; |
91 | | - if (!correlationIds.isEmpty()) { |
92 | | - assert correlationIds.size() == 1; |
93 | | - shuttle = |
94 | | - new RexPermuteInputsShuttle(inputMapping, newInput) { |
95 | | - @Override |
96 | | - public RexNode visitSubQuery(RexSubQuery subQuery) { |
97 | | - subQuery = (RexSubQuery) super.visitSubQuery(subQuery); |
98 | | - return RelOptUtil.remapCorrelatesInSuqQuery( |
99 | | - openSearchRelBuilder.getRexBuilder(), |
100 | | - subQuery, |
101 | | - correlationIds.iterator().next(), |
102 | | - newInput.getRowType(), |
103 | | - inputMapping); |
104 | | - } |
105 | | - }; |
106 | | - } else { |
107 | | - shuttle = new RexPermuteInputsShuttle(inputMapping, newInput); |
108 | | - } |
109 | | - |
110 | | - final Mapping mapping = |
111 | | - Mappings.create(MappingType.INVERSE_SURJECTION, fieldCount, fieldsUsed.cardinality()); |
112 | | - for (Ord<RexNode> ord : Ord.zip(project.getProjects())) { |
113 | | - if (fieldsUsed.get(ord.i)) { |
114 | | - mapping.set(ord.i, newProjects.size()); |
115 | | - RexNode newProjectExpr = ord.e.accept(shuttle); |
116 | | - newProjects.add(newProjectExpr); |
117 | | - } |
118 | | - } |
119 | | - |
120 | | - final RelDataType newRowType = |
121 | | - RelOptUtil.permute(project.getCluster().getTypeFactory(), rowType, mapping); |
122 | | - |
123 | | - if (shouldAvoidSimplifyValues(newProjects, newInput)) { |
124 | | - return result( |
125 | | - project.copy(project.getTraitSet(), newInput, newProjects, newRowType), mapping, project); |
126 | | - } |
127 | | - |
128 | | - openSearchRelBuilder.push(newInput); |
129 | | - openSearchRelBuilder.project(newProjects, newRowType.getFieldNames(), false, correlationIds); |
130 | | - return result(openSearchRelBuilder.build(), mapping, project); |
131 | 36 | } |
132 | 37 |
|
133 | 38 | public TrimResult trimFields( |
@@ -162,19 +67,4 @@ public TrimResult trimFields( |
162 | 67 | // needs them for its condition. |
163 | 68 | return result(dedup.copy(newInput, newDedupFields), inputMapping); |
164 | 69 | } |
165 | | - |
166 | | - private boolean shouldAvoidSimplifyValues(List<RexNode> projects, RelNode input) { |
167 | | - return projects.stream().allMatch(RexLiteral.class::isInstance) && isFixedRowCount(input); |
168 | | - } |
169 | | - |
170 | | - private boolean isFixedRowCount(RelNode input) { |
171 | | - if (input instanceof Values) { |
172 | | - return true; |
173 | | - } |
174 | | - if (input instanceof Aggregate aggregate) { |
175 | | - return aggregate.getGroupSet().isEmpty() |
176 | | - && aggregate.getGroupType() == Aggregate.Group.SIMPLE; |
177 | | - } |
178 | | - return false; |
179 | | - } |
180 | 70 | } |
0 commit comments