Skip to content

Commit 6afdcb6

Browse files
committed
Correct case analyzer
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 0acdd51 commit 6afdcb6

1 file changed

Lines changed: 125 additions & 98 deletions

File tree

opensearch/src/main/java/org/opensearch/sql/opensearch/request/CaseRangeAnalyzer.java

Lines changed: 125 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,14 @@
55

66
package org.opensearch.sql.opensearch.request;
77

8-
import java.math.BigDecimal;
9-
import java.util.List;
10-
import java.util.Objects;
11-
import java.util.Optional;
12-
138
import com.google.common.collect.BoundType;
149
import com.google.common.collect.Range;
1510
import com.google.common.collect.RangeSet;
1611
import com.google.common.collect.TreeRangeSet;
17-
import lombok.RequiredArgsConstructor;
12+
import java.math.BigDecimal;
13+
import java.util.List;
14+
import java.util.Objects;
15+
import java.util.Optional;
1816
import org.apache.calcite.rel.type.RelDataType;
1917
import org.apache.calcite.rex.RexCall;
2018
import org.apache.calcite.rex.RexInputRef;
@@ -26,26 +24,33 @@
2624
import org.apache.calcite.util.Sarg;
2725
import org.opensearch.search.aggregations.AggregationBuilders;
2826
import org.opensearch.search.aggregations.bucket.range.RangeAggregationBuilder;
29-
import org.opensearch.search.aggregations.bucket.range.RangeAggregator;
3027

3128
/**
3229
* Analyzer to detect CASE expressions that can be converted to OpenSearch range aggregations.
3330
*
3431
* <p>Strict validation rules:
32+
*
3533
* <ul>
36-
* <li>All conditions must compare the same field with literals</li>
37-
* <li>Only simple comparison operators (>, >=, <, <=) are allowed</li>
38-
* <li>Ranges must be non-overlapping and contiguous</li>
39-
* <li>Return values must be string literals</li>
34+
* <li>All conditions must compare the same field with literals
35+
* <li>Only closed-open, at-least, and less-than ranges are allowed
36+
* <li>Return values must be string literals
4037
* </ul>
4138
*/
4239
public class CaseRangeAnalyzer {
40+
/** The default key to use if there isn't a key specified for the else case */
41+
public static final String DEFAULT_ELSE_KEY = "null";
42+
43+
/** The name for the range aggregation */
44+
public static final String NAME = "case_range";
45+
4346
private final RelDataType rowType;
44-
private final RangeSet<BigDecimal> rangeSet;
47+
private final RangeSet<Double> takenRange;
48+
private final RangeAggregationBuilder builder;
4549

4650
public CaseRangeAnalyzer(RelDataType rowType) {
4751
this.rowType = rowType;
48-
this.rangeSet = TreeRangeSet.create();
52+
this.takenRange = TreeRangeSet.create();
53+
this.builder = AggregationBuilders.range(NAME);
4954
}
5055

5156
/**
@@ -70,28 +75,34 @@ public Optional<RangeAggregationBuilder> analyze(RexCall caseCall) {
7075
}
7176

7277
List<RexNode> operands = caseCall.getOperands();
73-
RangeAggregationBuilder aggregationBuilder = AggregationBuilders.range("case_range");
7478

7579
// Process WHEN-THEN pairs
7680
for (int i = 0; i < operands.size() - 1; i += 2) {
7781
RexNode condition = operands.get(i);
78-
RexNode result = operands.get(i + 1);
82+
RexNode expr = operands.get(i + 1);
7983
// Result must be a literal
80-
if (!(result instanceof RexLiteral)) {
84+
if (!(expr instanceof RexLiteral)) {
8185
return Optional.empty();
8286
}
83-
String key = ((RexLiteral) result).getValueAs(String.class);
84-
analyzeCondition(aggregationBuilder, condition, key);
87+
String key = ((RexLiteral) expr).getValueAs(String.class);
88+
analyzeCondition(condition, key);
8589
}
8690

8791
// Check ELSE clause
88-
// TODO: Currently, we ignore else clause
89-
// Process the case without else clause and check range completeness later
90-
return Optional.of(aggregationBuilder);
92+
RexNode elseExpr = operands.getLast();
93+
String elseKey;
94+
if (RexLiteral.isNullLiteral(elseExpr)) {
95+
// range key doesn't support values of type: VALUE_NULL
96+
elseKey = DEFAULT_ELSE_KEY;
97+
} else {
98+
elseKey = ((RexLiteral) elseExpr).getValueAs(String.class);
99+
}
100+
addRangeSet(elseKey, takenRange.complement());
101+
return Optional.of(builder);
91102
}
92103

93104
/** Analyzes a single condition in the CASE WHEN clause. */
94-
private void analyzeCondition(RangeAggregationBuilder builder, RexNode condition, String key) {
105+
private void analyzeCondition(RexNode condition, String key) {
95106
if (!(condition instanceof RexCall)) {
96107
throwUnsupported("condition must be a RexCall");
97108
}
@@ -100,28 +111,35 @@ private void analyzeCondition(RangeAggregationBuilder builder, RexNode condition
100111
SqlKind kind = call.getKind();
101112

102113
// Handle simple comparisons
103-
if (kind == SqlKind.GREATER_THAN_OR_EQUAL || kind == SqlKind.LESS_THAN || kind == SqlKind.LESS_THAN_OR_EQUAL || kind == SqlKind.GREATER_THAN) {
104-
builder.addRange(analyzeSimpleComparison(builder, call, key));
114+
if (kind == SqlKind.GREATER_THAN_OR_EQUAL
115+
|| kind == SqlKind.LESS_THAN
116+
|| kind == SqlKind.LESS_THAN_OR_EQUAL
117+
|| kind == SqlKind.GREATER_THAN) {
118+
analyzeSimpleComparison(call, key);
119+
} else if (kind == SqlKind.SEARCH) {
120+
analyzeSearchCondition(call, key);
105121
}
106-
// Handle AND conditions (for range conditions like x >= 10 AND x < 100)
122+
// AND / OR will only appear when users try to create a complex condition on multiple fields
123+
// E.g. (a > 3 and b < 5). Otherwise, the complex conditions will be converted to a SEARCH call.
107124
else if (kind == SqlKind.AND || kind == SqlKind.OR) {
108-
analyzeCompositeCondition(builder, call, key);
109-
} else if (kind == SqlKind.SEARCH) {
110-
analyzeSearchCondition(builder, call, key);
125+
throwUnsupported("Range queries must be performed on the same field");
126+
} else {
127+
throwUnsupported("Can not analyze condition as a range query: " + call);
111128
}
112129
}
113130

114-
private RangeAggregator.Range analyzeSimpleComparison(RangeAggregationBuilder builder, RexCall call, String key) {
131+
private void analyzeSimpleComparison(RexCall call, String key) {
115132
List<RexNode> operands = call.getOperands();
116133
if (operands.size() != 2 || !(call.getOperator() instanceof SqlBinaryOperator)) {
117134
throwUnsupported();
118135
}
119136
RexNode left = operands.get(0);
120137
RexNode right = operands.get(1);
121-
SqlOperator operator = call.getOperator();
138+
SqlOperator operator = call.getOperator();
122139
RexInputRef inputRef = null;
123140
RexLiteral literal = null;
124141

142+
// Swap inputRef to the left if necessary
125143
if (left instanceof RexInputRef && right instanceof RexLiteral) {
126144
inputRef = (RexInputRef) left;
127145
literal = (RexLiteral) right;
@@ -145,103 +163,112 @@ private RangeAggregator.Range analyzeSimpleComparison(RangeAggregationBuilder bu
145163
}
146164

147165
Double value = literal.getValueAs(Double.class);
148-
return switch (operator.getKind()) {
149-
case GREATER_THAN_OR_EQUAL -> new RangeAggregator.Range(key, value, null);
150-
case LESS_THAN -> new RangeAggregator.Range(key, null, value);
151-
default -> throw new UnsupportedOperationException("ranges must equivalents of field >= constant or field < constant");
152-
};
153-
}
154-
155-
private void analyzeCompositeCondition(RangeAggregationBuilder builder, RexCall compositeCall, String key) {
156-
RexNode left = compositeCall.getOperands().get(0);
157-
RexNode right = compositeCall.getOperands().get(1);
158-
159-
if (!(left instanceof RexCall && right instanceof RexCall && ((RexCall) left).getOperator() instanceof SqlBinaryOperator && ((RexCall) right).getOperator() instanceof SqlBinaryOperator)) {
160-
throwUnsupported("cannot analyze deep nested comparison");
166+
if (value == null) {
167+
throwUnsupported("Cannot parse value for comparison");
161168
}
162-
163-
// For AND conditions, we need to analyze them separately and combine
164-
// Create temporary ranges to analyze the conditions
165-
RangeAggregator.Range leftRange = analyzeSimpleComparison(builder, (RexCall) left, key);
166-
RangeAggregator.Range rightRange = analyzeSimpleComparison(builder, (RexCall) right, key);
167-
168-
// Combine into single range
169-
if (compositeCall.getKind() == SqlKind.AND) {
170-
and(builder, leftRange, rightRange, key);
171-
} else if (compositeCall.getKind() == SqlKind.OR) {
172-
or(builder, leftRange, rightRange, key);
169+
switch (operator.getKind()) {
170+
case GREATER_THAN_OR_EQUAL -> {
171+
addFrom(key, value);
172+
}
173+
case LESS_THAN -> {
174+
addTo(key, value);
175+
}
176+
default -> throw new UnsupportedOperationException(
177+
"ranges must equivalents of field >= constant or field < constant");
173178
}
179+
;
174180
}
175181

176-
private void analyzeSearchCondition(RangeAggregationBuilder builder, RexCall searchCall, String key) {
182+
private void analyzeSearchCondition(RexCall searchCall, String key) {
177183
RexNode field = searchCall.getOperands().getFirst();
178-
if (!(field instanceof RexInputRef) || !Objects.equals(getFieldName((RexInputRef) field), builder.field())) {
184+
if (!(field instanceof RexInputRef)) {
185+
throwUnsupported("Range query must be performed on a field");
186+
}
187+
String fieldName = getFieldName((RexInputRef) field);
188+
if (builder.field() == null) {
189+
builder.field(fieldName);
190+
} else if (!Objects.equals(builder.field(), fieldName)) {
179191
throwUnsupported("Range query must be performed on the same field");
180192
}
181193
RexLiteral literal = (RexLiteral) searchCall.getOperands().getLast();
182194
Sarg<?> sarg = literal.getValueAs(Sarg.class);
183-
for(Object r: sarg.rangeSet.asRanges()){
195+
for (Object r : sarg.rangeSet.asRanges()) {
184196
Range<BigDecimal> range = (Range<BigDecimal>) r;
185-
if ((range.hasLowerBound() && range.lowerBoundType() != BoundType.CLOSED) || (range.hasUpperBound() && range.upperBoundType() != BoundType.OPEN)){
186-
throwUnsupported("Range query only supports closed-open ranges");
187-
}
197+
validateRange(range);
188198
if (!range.hasLowerBound() && range.hasUpperBound()) {
189-
builder.addUnboundedTo(key, range.upperEndpoint().doubleValue());
199+
// doubleValue() is lossy: a BigDecimal whose magnitude is too large for a double becomes Double.POSITIVE_INFINITY (or Double.NEGATIVE_INFINITY), not Double.MAX_VALUE
200+
double upper = range.upperEndpoint().doubleValue();
201+
addTo(key, upper);
190202
} else if (range.hasLowerBound() && !range.hasUpperBound()) {
191-
builder.addUnboundedFrom(key, range.lowerEndpoint().doubleValue());
203+
double lower = range.lowerEndpoint().doubleValue();
204+
addFrom(key, lower);
192205
} else if (range.hasLowerBound()) {
193-
builder.addRange(key, range.lowerEndpoint().doubleValue(), range.upperEndpoint().doubleValue());
206+
double lower = range.lowerEndpoint().doubleValue();
207+
double upper = range.upperEndpoint().doubleValue();
208+
addBetween(key, lower, upper);
194209
} else {
195-
builder.addRange(key, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
210+
addBetween(key, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
196211
}
197212
}
198213
}
199214

200-
private static void and(
201-
RangeAggregationBuilder builder, RangeAggregator.Range left, RangeAggregator.Range right, String key) {
202-
double mergedFrom = Math.max(left.getFrom(), right.getFrom());
203-
double mergedTo = Math.min(left.getTo(), right.getTo());
204-
if (mergedFrom > Double.NEGATIVE_INFINITY && mergedTo < Double.POSITIVE_INFINITY) {
205-
// Closed range: both bounds are finite
206-
builder.addRange(key, mergedFrom, mergedTo);
207-
} else if (mergedFrom > Double.NEGATIVE_INFINITY) {
208-
// Unbounded from: only lower bound (e.g., x >= 10)
209-
builder.addUnboundedFrom(key, mergedFrom);
210-
} else if (mergedTo < Double.POSITIVE_INFINITY) {
211-
// Unbounded to: only upper bound (e.g., x < 50)
212-
builder.addUnboundedTo(key, mergedTo);
213-
} // If no overlapping, do nothing
214-
}
215-
216-
private static void or(
217-
RangeAggregationBuilder builder, RangeAggregator.Range left, RangeAggregator.Range right, String key) {
218-
// sort left and right by swapping if necessary
219-
if(right.getFrom() < left.getFrom() || (left.getFrom() == right.getFrom() && right.getTo() < left.getTo())) {
220-
var tmp = right;
221-
right = left;
222-
left = tmp;
223-
}
224-
boolean overlap = left.getTo() > right.getFrom();
225-
if (overlap) {
226-
// Ranges overlap, meaning they cover all ranges - add both unbounded ranges
227-
double mergedFrom = Math.min(left.getFrom(), right.getFrom());
228-
double mergedTo = Math.max(left.getTo(), right.getTo());
229-
builder.addRange(key, mergedFrom, mergedTo);
215+
/**
 * Registers an at-least range (lower bound {@code value}, no upper bound) under {@code key}.
 *
 * <p>The range is handed to {@code updateRange}, so only the part not already taken is added.
 */
private void addFrom(String key, Double value) {
216+
var from = Range.atLeast(value);
217+
updateRange(key, from);
218+
}
219+
220+
/**
 * Registers a less-than range (no lower bound, exclusive upper bound {@code value}) under
 * {@code key}.
 *
 * <p>The range is handed to {@code updateRange}, so only the part not already taken is added.
 */
private void addTo(String key, Double value) {
221+
var to = Range.lessThan(value);
222+
updateRange(key, to);
223+
}
224+
225+
/**
 * Registers a closed-open range (inclusive {@code from}, exclusive {@code to}) under
 * {@code key}.
 *
 * <p>The range is handed to {@code updateRange}, so only the part not already taken is added.
 */
private void addBetween(String key, Double from, Double to) {
226+
var range = Range.closedOpen(from, to);
227+
updateRange(key, range);
228+
}
229+
230+
/**
 * Adds the not-yet-taken portion of {@code range} to the aggregation under {@code key}, then
 * records the whole of {@code range} as taken.
 *
 * <p>Because the taken set grows with every call, a later call can only contribute the parts of
 * its range that earlier calls did not already cover.
 */
private void updateRange(String key, Range<Double> range) {
231+
// The range to add: remaining space ∩ new range
232+
RangeSet<Double> toAdd = takenRange.complement().subRangeSet(range);
233+
addRangeSet(key, toAdd);
234+
// Mark the full requested range as taken, including any part that was already covered
takenRange.add(range);
235+
}
236+
237+
// Adds every range contained in the given range set to the builder under the same key.
// Does not update the taken range; callers that need that bookkeeping do it themselves.
238+
private void addRangeSet(String key, RangeSet<Double> rangeSet) {
239+
rangeSet.asRanges().forEach(range -> addRange(key, range));
240+
}
241+
242+
// Add range without updating taken range
243+
private void addRange(String key, Range<Double> range) {
244+
validateRange(range);
245+
if (range.hasLowerBound() && range.hasUpperBound()) {
246+
builder.addRange(key, range.lowerEndpoint(), range.upperEndpoint());
247+
} else if (range.hasLowerBound()) {
248+
builder.addUnboundedFrom(key, range.lowerEndpoint());
249+
} else if (range.hasUpperBound()) {
250+
builder.addUnboundedTo(key, range.upperEndpoint());
230251
} else {
231-
builder.addRange(left);
232-
builder.addRange(right);
252+
builder.addRange(key, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
233253
}
234254
}
235255

236256
/** Returns the name found at the input reference's index in {@code rowType}'s field-name list. */
private String getFieldName(RexInputRef field) {
237257
return rowType.getFieldNames().get(field.getIndex());
238258
}
239259

260+
/**
 * Rejects ranges that are not closed-open.
 *
 * <p>A lower bound, when present, must be closed; an upper bound, when present, must be open.
 * A side without a bound is accepted as is.
 *
 * @throws UnsupportedOperationException if either present bound has the wrong type
 */
private static void validateRange(Range<?> range) {
261+
if ((range.hasLowerBound() && range.lowerBoundType() != BoundType.CLOSED)
262+
|| (range.hasUpperBound() && range.upperBoundType() != BoundType.OPEN)) {
263+
throwUnsupported("Range query only supports closed-open ranges");
264+
}
265+
}
266+
240267
private static void throwUnsupported() {
241-
throw new UnsupportedOperationException("Cannot create range aggregator");
268+
throw new UnsupportedOperationException("Cannot create range aggregator from case");
242269
}
243270

244271
/**
 * Always throws; never returns normally.
 *
 * @param message detail appended after the fixed "Cannot create range aggregator: " prefix
 * @throws UnsupportedOperationException unconditionally
 */
private static void throwUnsupported(String message) {
245272
throw new UnsupportedOperationException("Cannot create range aggregator: " + message);
246273
}
247-
}
274+
}

0 commit comments

Comments
 (0)