Skip to content

Commit f416dee

Browse files
committed
Create bucket aggregation parsers that supports parsing nested sub aggregations
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 6afdcb6 commit f416dee

7 files changed

Lines changed: 176 additions & 150 deletions

File tree

integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLCaseFunctionIT.java

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
package org.opensearch.sql.calcite.remote;
77

8-
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_WEBLOGS;
98
import static org.junit.jupiter.api.Assertions.assertEquals;
109
import static org.junit.jupiter.api.Assertions.assertTrue;
10+
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_WEBLOGS;
1111
import static org.opensearch.sql.util.MatcherUtils.rows;
1212
import static org.opensearch.sql.util.MatcherUtils.schema;
1313
import static org.opensearch.sql.util.MatcherUtils.verifyDataRows;
@@ -264,10 +264,7 @@ public void testCaseRangeAggregationPushdown() throws IOException {
264264
+ ") | stats count() as total by range_bucket | sort range_bucket",
265265
TEST_INDEX_WEBLOGS));
266266

267-
verifySchema(
268-
actual,
269-
schema("range_bucket", "string"),
270-
schema("total", "long"));
267+
verifySchema(actual, schema("range_bucket", "string"), schema("total", "long"));
271268

272269
// This should work but won't be optimized due to implicit NULL bucket
273270
assertTrue(actual.getJSONArray("datarows").length() > 0);
@@ -279,12 +276,10 @@ public void testCaseRangeAggregationWithMetrics() throws IOException {
279276
JSONObject actual =
280277
executeQuery(
281278
String.format(
282-
"source=%s | eval size_category = case("
283-
+ " cast(bytes as int) < 2000, 'small',"
284-
+ " cast(bytes as int) >= 2000 AND cast(bytes as int) < 5000, 'medium',"
285-
+ " cast(bytes as int) >= 5000, 'large'"
286-
+ ") | stats count() as total, avg(cast(bytes as int)) as avg_bytes by size_category"
287-
+ " | sort size_category",
279+
"source=%s | eval size_category = case( cast(bytes as int) < 2000, 'small', "
280+
+ " cast(bytes as int) >= 2000 AND cast(bytes as int) < 5000, 'medium', "
281+
+ " cast(bytes as int) >= 5000, 'large') | stats count() as total,"
282+
+ " avg(cast(bytes as int)) as avg_bytes by size_category | sort size_category",
288283
TEST_INDEX_WEBLOGS));
289284

290285
verifySchema(
@@ -304,19 +299,14 @@ public void testCaseRangeAggregationWithElse() throws IOException {
304299
JSONObject actual =
305300
executeQuery(
306301
String.format(
307-
"source=%s | eval status_category = case("
308-
+ " cast(response as int) < 300, 'success',"
309-
+ " cast(response as int) >= 300 AND cast(response as int) < 400, 'redirect',"
310-
+ " cast(response as int) >= 400 AND cast(response as int) < 500, 'client_error',"
311-
+ " cast(response as int) >= 500, 'server_error'"
312-
+ " else 'unknown'"
313-
+ ") | stats count() by status_category | sort status_category",
302+
"source=%s | eval status_category = case( cast(response as int) < 300, 'success', "
303+
+ " cast(response as int) >= 300 AND cast(response as int) < 400, 'redirect', "
304+
+ " cast(response as int) >= 400 AND cast(response as int) < 500,"
305+
+ " 'client_error', cast(response as int) >= 500, 'server_error' else"
306+
+ " 'unknown') | stats count() by status_category | sort status_category",
314307
TEST_INDEX_WEBLOGS));
315308

316-
verifySchema(
317-
actual,
318-
schema("status_category", "string"),
319-
schema("count()", "long"));
309+
verifySchema(actual, schema("status_category", "string"), schema("count()", "long"));
320310

321311
// Should handle the ELSE case for null/non-numeric responses
322312
assertTrue(actual.getJSONArray("datarows").length() > 0);
@@ -335,10 +325,7 @@ public void testNonOptimizableCaseExpression() throws IOException {
335325
+ ") | stats count() by mixed_condition",
336326
TEST_INDEX_WEBLOGS));
337327

338-
verifySchema(
339-
actual,
340-
schema("mixed_condition", "string"),
341-
schema("count()", "long"));
328+
verifySchema(actual, schema("mixed_condition", "string"), schema("count()", "long"));
342329

343330
// This should work but won't be optimized
344331
assertTrue(actual.getJSONArray("datarows").length() > 0);
@@ -356,10 +343,7 @@ public void testCaseWithNonLiteralResult() throws IOException {
356343
+ ") | stats count() by computed_result | head 3",
357344
TEST_INDEX_WEBLOGS));
358345

359-
verifySchema(
360-
actual,
361-
schema("computed_result", "string"),
362-
schema("count()", "long"));
346+
verifySchema(actual, schema("computed_result", "string"), schema("count()", "long"));
363347

364348
// This should work but won't be optimized to range aggregation
365349
assertTrue(actual.getJSONArray("datarows").length() > 0);
@@ -379,10 +363,7 @@ public void testOptimizableCaseRangeAggregation() throws IOException {
379363
+ ") | stats count() by size_bucket | sort size_bucket",
380364
TEST_INDEX_WEBLOGS));
381365

382-
verifySchema(
383-
actual,
384-
schema("size_bucket", "string"),
385-
schema("count()", "long"));
366+
verifySchema(actual, schema("size_bucket", "string"), schema("count()", "long"));
386367

387368
// This should work - the explicit ELSE makes it potentially optimizable
388369
assertTrue(actual.getJSONArray("datarows").length() > 0);

opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
import java.util.Set;
4242
import java.util.function.Function;
4343
import java.util.stream.Collectors;
44-
4544
import lombok.RequiredArgsConstructor;
4645
import org.apache.calcite.plan.RelOptCluster;
4746
import org.apache.calcite.rel.core.Aggregate;
@@ -60,10 +59,10 @@
6059
import org.opensearch.search.aggregations.AggregatorFactories;
6160
import org.opensearch.search.aggregations.AggregatorFactories.Builder;
6261
import org.opensearch.search.aggregations.BucketOrder;
63-
import org.opensearch.search.aggregations.bucket.range.RangeAggregationBuilder;
6462
import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceBuilder;
6563
import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder;
6664
import org.opensearch.search.aggregations.bucket.missing.MissingOrder;
65+
import org.opensearch.search.aggregations.bucket.range.RangeAggregationBuilder;
6766
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
6867
import org.opensearch.search.aggregations.metrics.ExtendedStats;
6968
import org.opensearch.search.aggregations.metrics.PercentilesAggregationBuilder;
@@ -81,11 +80,10 @@
8180
import org.opensearch.sql.opensearch.request.PredicateAnalyzer.NamedFieldExpression;
8281
import org.opensearch.sql.opensearch.response.agg.ArgMaxMinParser;
8382
import org.opensearch.sql.opensearch.response.agg.BucketAggregationParser;
84-
import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser;
83+
import org.opensearch.sql.opensearch.response.agg.LeafBucketAggregationParser;
8584
import org.opensearch.sql.opensearch.response.agg.MetricParser;
8685
import org.opensearch.sql.opensearch.response.agg.NoBucketAggregationParser;
8786
import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser;
88-
import org.opensearch.sql.opensearch.response.agg.RangeParser;
8987
import org.opensearch.sql.opensearch.response.agg.SinglePercentileParser;
9088
import org.opensearch.sql.opensearch.response.agg.SingleValueParser;
9189
import org.opensearch.sql.opensearch.response.agg.StatsParser;
@@ -204,11 +202,19 @@ public static Pair<List<AggregationBuilder>, OpenSearchAggregationResponseParser
204202
Builder metricBuilder = builderAndParser.getLeft();
205203
List<MetricParser> metricParsers = builderAndParser.getRight();
206204
// Find group by fields derived from CASE functions and convert them to range queries
207-
CaseRangeAnalyzer rangeAnalyzer = CaseRangeAnalyzer.create(rowType);
208-
List<Pair<Integer, RangeAggregationBuilder>> groupsByCase = groupList.stream()
209-
.map(i -> Pair.of(i, project.getProjects().get(i)))
210-
.filter(p -> p.getRight() instanceof RexCall rexCall && rexCall.getKind() == SqlKind.CASE)
211-
.map(p -> Pair.of(p.getLeft(), rangeAnalyzer.analyze((RexCall) p.getRight())))
205+
List<Pair<Integer, RangeAggregationBuilder>> groupsByCase =
206+
groupList.stream()
207+
.map(i -> Pair.of(i, project.getNamedProjects().get(i)))
208+
.filter(
209+
p ->
210+
p.getRight().getKey() instanceof RexCall rexCall
211+
&& rexCall.getKind() == SqlKind.CASE)
212+
.map(
213+
p ->
214+
Pair.of(
215+
p.getLeft(),
216+
CaseRangeAnalyzer.create(p.getRight().getValue(), rowType)
217+
.analyze((RexCall) p.getRight().getKey())))
212218
.filter(p -> p.getRight().isPresent())
213219
.map(p -> Pair.of(p.getLeft(), p.getRight().get()))
214220
.toList();
@@ -220,28 +226,32 @@ public static Pair<List<AggregationBuilder>, OpenSearchAggregationResponseParser
220226
// Note that but a composite aggregation can not be a sub aggregation of range aggregation,
221227
// but range aggregation can be a sub aggregation of a composite aggregation.
222228
AggregationBuilder rangeAggregationBuilder = null;
229+
BucketAggregationParser bucketAggregationParser = null;
223230
if (!groupsByCase.isEmpty()) {
224231
for (int i = 0; i < groupsByCase.size(); i++) {
225232
Pair<Integer, RangeAggregationBuilder> pair = groupsByCase.get(i);
226233
if (i == 0) {
227234
rangeAggregationBuilder = pair.getRight();
235+
bucketAggregationParser =
236+
new BucketAggregationParser(new LeafBucketAggregationParser(metricParsers));
228237
} else {
229238
groupsByCase.get(i - 1).getRight().subAggregation(pair.getRight());
239+
bucketAggregationParser = new BucketAggregationParser(bucketAggregationParser);
230240
}
231241
}
232242
groupsByCase.getLast().getRight().subAggregations(metricBuilder);
233-
metricParsers.add(new RangeParser("case_range"));
234243
}
235244

236245
// Remove groups that are converted to ranges from groupList
237246
Set<Integer> toRemove = groupsByCase.stream().map(Pair::getLeft).collect(Collectors.toSet());
238-
List<Integer> filteredGroupList = groupList.stream().filter(i -> !toRemove.contains(i)).toList();
247+
List<Integer> filteredGroupList =
248+
groupList.stream().filter(i -> !toRemove.contains(i)).toList();
239249

240250
// The top-level query is a range query: stats count() by range_field
241251
// RangeAgg
242252
// Metric
243253
if (!groupsByCase.isEmpty() && filteredGroupList.isEmpty()) {
244-
return Pair.of(List.of(rangeAggregationBuilder), new BucketAggregationParser(metricParsers));
254+
return Pair.of(List.of(rangeAggregationBuilder), bucketAggregationParser);
245255
}
246256
// No parent composite aggregation or range aggregation is attached: stats count()
247257
// Metric
@@ -250,21 +260,23 @@ else if (aggregate.getGroupSet().isEmpty() && filteredGroupList.isEmpty()) {
250260
ImmutableList.copyOf(metricBuilder.getAggregatorFactories()),
251261
new NoBucketAggregationParser(metricParsers));
252262
}
253-
// It has both composite aggregation and range aggregation: stats count() by range_field, non_range_field
263+
// It has both composite aggregation and range aggregation: stats count() by range_field,
264+
// non_range_field
254265
// CompositeAgg
255266
// RangeAgg
256267
// Metric
257268
else if (!groupsByCase.isEmpty()) {
258269
List<CompositeValuesSourceBuilder<?>> buckets =
259-
createCompositeBuckets(filteredGroupList, project, helper);
270+
createCompositeBuckets(filteredGroupList, project, helper);
260271
return Pair.of(
261-
Collections.singletonList(
262-
AggregationBuilders.composite("composite_buckets", buckets)
263-
.subAggregation(rangeAggregationBuilder)
264-
.size(AGGREGATION_BUCKET_SIZE)),
265-
new CompositeAggregationParser(metricParsers));
272+
Collections.singletonList(
273+
AggregationBuilders.composite("composite_buckets", buckets)
274+
.subAggregation(rangeAggregationBuilder)
275+
.size(AGGREGATION_BUCKET_SIZE)),
276+
new BucketAggregationParser(bucketAggregationParser));
266277
}
267-
// It does not have range aggregation, but has composite aggregation: stats count() by non_range_field
278+
// It does not have range aggregation, but has composite aggregation: stats count() by
279+
// non_range_field
268280
// CompositeAgg
269281
// Metric
270282
else {
@@ -275,7 +287,7 @@ else if (!groupsByCase.isEmpty()) {
275287
AggregationBuilders.composite("composite_buckets", buckets)
276288
.subAggregations(metricBuilder)
277289
.size(AGGREGATION_BUCKET_SIZE)),
278-
new CompositeAggregationParser(metricParsers));
290+
new BucketAggregationParser(new LeafBucketAggregationParser(metricParsers)));
279291
}
280292
} catch (Throwable e) {
281293
Throwables.throwIfInstanceOf(e, UnsupportedOperationException.class);

opensearch/src/main/java/org/opensearch/sql/opensearch/request/CaseRangeAnalyzer.java

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,14 @@ public class CaseRangeAnalyzer {
4040
/** The default key to use if there isn't a key specified for the else case */
4141
public static final String DEFAULT_ELSE_KEY = "null";
4242

43-
/** The name for the range aggregation */
44-
public static final String NAME = "case_range";
45-
4643
private final RelDataType rowType;
4744
private final RangeSet<Double> takenRange;
4845
private final RangeAggregationBuilder builder;
4946

50-
public CaseRangeAnalyzer(RelDataType rowType) {
47+
public CaseRangeAnalyzer(String name, RelDataType rowType) {
5148
this.rowType = rowType;
5249
this.takenRange = TreeRangeSet.create();
53-
this.builder = AggregationBuilders.range(NAME);
50+
this.builder = AggregationBuilders.range(name).keyed(true);
5451
}
5552

5653
/**
@@ -59,8 +56,8 @@ public CaseRangeAnalyzer(RelDataType rowType) {
5956
* @param rowType the row type information for field resolution
6057
* @return a new CaseRangeAnalyzer instance
6158
*/
62-
public static CaseRangeAnalyzer create(RelDataType rowType) {
63-
return new CaseRangeAnalyzer(rowType);
59+
public static CaseRangeAnalyzer create(String name, RelDataType rowType) {
60+
return new CaseRangeAnalyzer(name, rowType);
6461
}
6562

6663
/**

opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/BucketAggregationParser.java

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,67 @@
55

66
package org.opensearch.sql.opensearch.response.agg;
77

8-
import java.util.Arrays;
9-
import java.util.LinkedHashMap;
8+
import java.util.ArrayList;
9+
import java.util.HashMap;
1010
import java.util.List;
1111
import java.util.Map;
1212
import java.util.stream.Collectors;
13-
import lombok.EqualsAndHashCode;
1413
import org.opensearch.search.aggregations.Aggregation;
1514
import org.opensearch.search.aggregations.Aggregations;
1615
import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation;
16+
import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation;
1717

18-
/**
19-
* Use BucketAggregationParser only when there is a single group-by key, it returns multiple
20-
* buckets. {@link CompositeAggregationParser} is used for multiple group by keys
21-
*/
22-
@EqualsAndHashCode
2318
public class BucketAggregationParser implements OpenSearchAggregationResponseParser {
24-
private final MetricParserHelper metricsParser;
25-
26-
public BucketAggregationParser(MetricParser... metricParserList) {
27-
metricsParser = new MetricParserHelper(Arrays.asList(metricParserList));
28-
}
19+
private final OpenSearchAggregationResponseParser subAggParser;
2920

30-
public BucketAggregationParser(List<MetricParser> metricParserList) {
31-
metricsParser = new MetricParserHelper(metricParserList);
21+
public BucketAggregationParser(OpenSearchAggregationResponseParser subAggParser) {
22+
this.subAggParser = subAggParser;
3223
}
3324

3425
@Override
3526
public List<Map<String, Object>> parse(Aggregations aggregations) {
36-
Aggregation agg = aggregations.asList().getFirst();
37-
return ((MultiBucketsAggregation) agg)
38-
.getBuckets().stream().map(b -> parse(b, agg.getName())).collect(Collectors.toList());
27+
if (subAggParser instanceof BucketAggregationParser) {
28+
return aggregations.asList().stream()
29+
.map(
30+
aggregation -> {
31+
if (aggregation instanceof CompositeAggregation) {
32+
return (CompositeAggregation) aggregation;
33+
} else {
34+
return (MultiBucketsAggregation) aggregation;
35+
}
36+
})
37+
.map(MultiBucketsAggregation::getBuckets)
38+
.flatMap(List::stream)
39+
.map(this::parse)
40+
.flatMap(List::stream)
41+
.collect(Collectors.toList());
42+
} else if (subAggParser instanceof LeafBucketAggregationParser) {
43+
return subAggParser.parse(aggregations);
44+
} else {
45+
throw new IllegalStateException(
46+
"Sub parsers of a BucketAggregationParser can only be either BucketAggregationParser or"
47+
+ " LeafBucketAggregationParser");
48+
}
49+
}
50+
51+
private List<Map<String, Object>> parse(MultiBucketsAggregation.Bucket bucket) {
52+
if (bucket instanceof CompositeAggregation.Bucket compositeBucket) {
53+
return parse(compositeBucket);
54+
}
55+
List<Map<String, Object>> results = new ArrayList<>();
56+
for (Aggregation subAgg : bucket.getAggregations()) {
57+
var sub = (Aggregations) subAgg;
58+
results.addAll(subAggParser.parse(sub));
59+
}
60+
return results;
3961
}
4062

41-
private Map<String, Object> parse(MultiBucketsAggregation.Bucket bucket, String keyName) {
42-
Map<String, Object> resultMap = new LinkedHashMap<>();
43-
resultMap.put(keyName, bucket.getKey());
44-
resultMap.putAll(metricsParser.parse(bucket.getAggregations()));
45-
return resultMap;
63+
private List<Map<String, Object>> parse(CompositeAggregation.Bucket bucket) {
64+
Map<String, Object> common = new HashMap<>(bucket.getKey());
65+
List<Map<String, Object>> results = subAggParser.parse(bucket.getAggregations());
66+
for (Map<String, Object> r : results) {
67+
r.putAll(common);
68+
}
69+
return results;
4670
}
4771
}

0 commit comments

Comments
 (0)