Skip to content

Commit 0ca81aa

Browse files
committed
Make case pushdown a private method
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 7a8db58 commit 0ca81aa

1 file changed

Lines changed: 63 additions & 31 deletions

File tree

opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java

Lines changed: 63 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -211,14 +211,6 @@ public static Pair<List<AggregationBuilder>, OpenSearchAggregationResponseParser
211211
Builder metricBuilder = builderAndParser.getLeft();
212212
List<MetricParser> metricParsers = builderAndParser.getRight();
213213

214-
List<Pair<Integer, RangeAggregationBuilder>> groupsByCase =
215-
analyzeCaseInProject(groupList, project, rowType);
216-
// Remove groups that are converted to ranges from groupList
217-
Set<Integer> toRemoveFromGroupList =
218-
groupsByCase.stream().map(Pair::getLeft).collect(Collectors.toSet());
219-
// The group-by list after removing CASE that can be converted to range queries
220-
groupList = groupList.stream().filter(i -> !toRemoveFromGroupList.contains(i)).toList();
221-
222214
// both count() and count(FIELD) can apply doc_count optimization in non-bucket aggregation,
223215
// but only count() can apply doc_count optimization in bucket aggregation.
224216
boolean countAllOnly = !groupList.isEmpty();
@@ -227,32 +219,21 @@ public static Pair<List<AggregationBuilder>, OpenSearchAggregationResponseParser
227219
Builder newMetricBuilder = countAggNameAndBuilderPair.getRight();
228220
List<String> countAggNames = countAggNameAndBuilderPair.getLeft();
229221

230-
// Cascade aggregations in such a way:
231-
// RangeAggregation
232-
// ...Any other range aggregations
233-
// Metric Aggregation comes at last
234-
// Note that a composite aggregation cannot be a sub-aggregation of a range aggregation,
235-
// but a range aggregation can be a sub-aggregation of a composite aggregation.
236-
AggregationBuilder rangeAggregationBuilder = null;
237-
if (!groupsByCase.isEmpty()) {
238-
for (int i = 0; i < groupsByCase.size(); i++) {
239-
Pair<Integer, RangeAggregationBuilder> p = groupsByCase.get(i);
240-
if (i == 0) {
241-
rangeAggregationBuilder = p.getRight();
242-
} else {
243-
groupsByCase.get(i - 1).getRight().subAggregation(p.getRight());
244-
}
245-
}
246-
groupsByCase.getLast().getRight().subAggregations(newMetricBuilder);
247-
}
222+
Pair<Set<Integer>, AggregationBuilder> caseAggPushedAndRangeBuilder =
223+
pushCaseAsRanges(groupList, project, rowType, newMetricBuilder);
224+
// Remove groups that are converted to ranges from groupList
225+
Set<Integer> aggPushedAsRanges = caseAggPushedAndRangeBuilder.getLeft();
226+
AggregationBuilder rangeAggregationBuilder = caseAggPushedAndRangeBuilder.getRight();
227+
// The group-by list after removing CASE that can be converted to range queries
228+
groupList = groupList.stream().filter(i -> !aggPushedAsRanges.contains(i)).toList();
248229

249230
// The top-level query is a range query:
250231
// - stats avg() by range_field
251232
// - stats count() by range_field
252233
// - stats avg(), count() by range_field
253234
// RangeAgg
254235
// Metric
255-
if (!groupsByCase.isEmpty() && groupList.isEmpty()) {
236+
if (!aggPushedAsRanges.isEmpty() && groupList.isEmpty()) {
256237
return Pair.of(
257238
List.of(rangeAggregationBuilder),
258239
new BucketAggregationParser(metricParsers, countAggNames));
@@ -292,7 +273,7 @@ && isAutoDateSpan(
292273
// CompositeAgg
293274
// RangeAgg
294275
// Metric
295-
else if (!groupsByCase.isEmpty()) {
276+
else if (!aggPushedAsRanges.isEmpty()) {
296277
List<CompositeValuesSourceBuilder<?>> buckets =
297278
createCompositeBuckets(groupList, project, helper);
298279
return Pair.of(
@@ -381,8 +362,37 @@ private static boolean supportCountFiled(
381362
== 1;
382363
}
383364

384-
private static List<Pair<Integer, RangeAggregationBuilder>> analyzeCaseInProject(
385-
List<Integer> groupList, Project project, RelDataType rowType) {
365+
/**
366+
* Analyzes and converts CASE expressions in GROUP BY clauses to OpenSearch range aggregations.
367+
*
368+
* <p>This method identifies group by fields that are derived from CASE functions and transforms
369+
* them into range aggregation builders. The resulting aggregations are cascaded in a hierarchical
370+
* structure where range aggregations contain other range aggregations as sub-aggregations, with
371+
* metric aggregations placed at the deepest level.
372+
*
373+
* <p>The aggregation hierarchy follows this pattern:
374+
*
375+
* <pre>
376+
* RangeAggregation
377+
* └── RangeAggregation (nested)
378+
* └── ... (more range aggregations)
379+
* └── Metric Aggregation (at the bottom)
380+
* </pre>
381+
*
382+
* @param groupList the list of group by field indices from the query
383+
* @param project the projection containing the expressions to analyze, may be null
384+
* @param rowType the data type information for the current row structure
385+
* @param metricBuilder the metric aggregation builder to be placed at the bottom of the hierarchy
386+
* @return a pair containing:
387+
* <ul>
388+
* <li>A set of integers representing the indices of group fields that were successfully
389+
* converted to range aggregations
390+
* <li>The root range aggregation builder, or null if no CASE expressions were found or
391+
* converted
392+
* </ul>
393+
*/
394+
private static Pair<Set<Integer>, AggregationBuilder> pushCaseAsRanges(
395+
List<Integer> groupList, Project project, RelDataType rowType, Builder metricBuilder) {
386396
// Find group by fields derived from CASE functions and convert them to range queries
387397
List<Pair<Integer, RangeAggregationBuilder>> groupsByCase =
388398
groupList.stream()
@@ -401,7 +411,29 @@ private static List<Pair<Integer, RangeAggregationBuilder>> analyzeCaseInProject
401411
.filter(p -> p.getRight().isPresent())
402412
.map(p -> Pair.of(p.getLeft(), p.getRight().get()))
403413
.toList();
404-
return groupsByCase;
414+
415+
// Cascade aggregations in such a way:
416+
// RangeAggregation
417+
// ...Any other range aggregations
418+
// Metric Aggregation comes at last
419+
// Note that a composite aggregation cannot be a sub-aggregation of a range aggregation,
420+
// but a range aggregation can be a sub-aggregation of a composite aggregation.
421+
AggregationBuilder rangeAggregationBuilder = null;
422+
if (!groupsByCase.isEmpty()) {
423+
for (int i = 0; i < groupsByCase.size(); i++) {
424+
Pair<Integer, RangeAggregationBuilder> p = groupsByCase.get(i);
425+
if (i == 0) {
426+
rangeAggregationBuilder = p.getRight();
427+
} else {
428+
groupsByCase.get(i - 1).getRight().subAggregation(p.getRight());
429+
}
430+
}
431+
groupsByCase.getLast().getRight().subAggregations(metricBuilder);
432+
}
433+
434+
Set<Integer> aggPushedAsRanges =
435+
groupsByCase.stream().map(Pair::getLeft).collect(Collectors.toSet());
436+
return Pair.of(aggPushedAsRanges, rangeAggregationBuilder);
405437
}
406438

407439
private static Pair<Builder, List<MetricParser>> processAggregateCalls(

0 commit comments

Comments
 (0)