Skip to content

Commit f690add

Browse files
committed
Add integration tests for case in aggregation
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 0493629 commit f690add

5 files changed

Lines changed: 198 additions & 103 deletions

File tree

integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,7 @@ public void testCasePushdownAsRangeQueryExplain() throws IOException {
944944
TEST_INDEX_BANK)));
945945

946946
// CASE 2: Composite - Range - Metric
947-
// 2.1 Composite(1 field) - Range - Metric
947+
// 2.1 Composite (term) - Range - Metric
948948
assertYamlEqualsJsonIgnoreId(
949949
loadExpectedPlan("agg_composite_range_metric_push.yaml"),
950950
explainQueryToString(
@@ -953,7 +953,15 @@ public void testCasePushdownAsRangeQueryExplain() throws IOException {
953953
+ " by state, age_range",
954954
TEST_INDEX_BANK)));
955955

956-
// 2.2 Composite(2 fields) - Range - Metric (with count)
956+
// 2.2 Composite (date histogram) - Range - Metric
957+
assertYamlEqualsJsonIgnoreId(
958+
loadExpectedPlan("agg_composite_date_range_push.yaml"),
959+
explainQueryToString(
960+
"source=opensearch-sql_test_index_time_data | eval value_range = case(value < 7000,"
961+
+ " 'small' else 'large') | stats avg(value) by value_range, span(@timestamp,"
962+
+ " 1h)"));
963+
964+
// 2.3 Composite (2 fields) - Range - Metric (with count)
957965
assertYamlEqualsJsonIgnoreId(
958966
loadExpectedPlan("agg_composite2_range_count_push.yaml"),
959967
explainQueryToString(
@@ -962,7 +970,7 @@ public void testCasePushdownAsRangeQueryExplain() throws IOException {
962970
+ " avg(balance), count() by age_range, state, gender",
963971
TEST_INDEX_BANK)));
964972

965-
// 2.3 Composite (2 fields) - Range - Range - Metric (with count)
973+
// 2.4 Composite (2 fields) - Range - Range - Metric (with count)
966974
assertYamlEqualsJsonIgnoreId(
967975
loadExpectedPlan("agg_composite2_range_range_count_push.yaml"),
968976
explainQueryToString(
@@ -972,7 +980,7 @@ public void testCasePushdownAsRangeQueryExplain() throws IOException {
972980
+ " avg_balance by age_range, balance_range, state",
973981
TEST_INDEX_BANK)));
974982

975-
// 2.4 Should not be pushed because case result expression is not constant
983+
// 2.5 Should not be pushed because case result expression is not constant
976984
assertYamlEqualsJsonIgnoreId(
977985
loadExpectedPlan("agg_case_composite_cannot_push.yaml"),
978986
explainQueryToString(

integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLCaseFunctionIT.java

Lines changed: 157 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
package org.opensearch.sql.calcite.remote;
77

8-
import static org.junit.jupiter.api.Assertions.assertEquals;
98
import static org.junit.jupiter.api.Assertions.assertTrue;
9+
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK;
1010
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_WEBLOGS;
1111
import static org.opensearch.sql.util.MatcherUtils.rows;
1212
import static org.opensearch.sql.util.MatcherUtils.schema;
@@ -27,6 +27,9 @@ public void init() throws Exception {
2727
enableCalcite();
2828

2929
loadIndex(Index.WEBLOG);
30+
loadIndex(Index.TIME_TEST_DATA);
31+
loadIndex(Index.TIME_TEST_DATA_WITH_NULL);
32+
loadIndex(Index.BANK);
3033
appendDataForBadResponse();
3134
}
3235

@@ -250,122 +253,180 @@ public void testCaseWhenInSubquery() throws IOException {
250253
}
251254

252255
@Test
253-
public void testCaseRangeAggregationPushdown() throws IOException {
254-
// Test CASE expression that can be optimized to range aggregation
255-
// Note: This has an implicit ELSE NULL, so it won't be optimized
256-
// But it should still work correctly
257-
JSONObject actual =
256+
public void testCaseCanBePushedDownAsRangeQuery() throws IOException {
257+
// CASE 1: Range - Metric
258+
// 1.1 Range - Metric
259+
JSONObject actual1 =
258260
executeQuery(
259261
String.format(
260-
"source=%s | eval range_bucket = case("
261-
+ " cast(bytes as int) < 1000, 'small',"
262-
+ " cast(bytes as int) >= 1000 AND cast(bytes as int) < 5000, 'medium',"
263-
+ " cast(bytes as int) >= 5000, 'large'"
264-
+ ") | stats count() as total by range_bucket | sort range_bucket",
265-
TEST_INDEX_WEBLOGS));
262+
"source=%s | eval age_range = case(age < 30, 'u30', age < 40, 'u40' else 'u100') |"
263+
+ " stats avg(age) as avg_age by age_range",
264+
TEST_INDEX_BANK));
265+
verifySchema(actual1, schema("avg_age", "double"), schema("age_range", "string"));
266+
verifyDataRows(actual1, rows(28.0, "u30"), rows(35.0, "u40"));
266267

267-
verifySchema(actual, schema("range_bucket", "string"), schema("total", "bigint"));
268-
269-
// This should work but won't be optimized due to implicit NULL bucket
270-
assertTrue(actual.getJSONArray("datarows").length() > 0);
271-
}
272-
273-
@Test
274-
public void testCaseRangeAggregationWithMetrics() throws IOException {
275-
// Test CASE-to-range with additional aggregations
276-
JSONObject actual =
268+
// 1.2 Range - Metric (COUNT)
269+
JSONObject actual2 =
277270
executeQuery(
278271
String.format(
279-
"source=%s | eval size_category = case( cast(bytes as int) < 2000, 'small', "
280-
+ " cast(bytes as int) >= 2000 AND cast(bytes as int) < 5000, 'medium', "
281-
+ " cast(bytes as int) >= 5000, 'large') | stats count() as total,"
282-
+ " avg(cast(bytes as int)) as avg_bytes by size_category | sort size_category",
283-
TEST_INDEX_WEBLOGS));
272+
"source=%s | eval age_range = case(age < 30, 'u30', age >= 30 and age < 40, 'u40'"
273+
+ " else 'u100') | stats avg(age) by age_range",
274+
TEST_INDEX_BANK));
275+
verifySchema(actual2, schema("avg(age)", "double"), schema("age_range", "string"));
276+
verifyDataRows(actual2, rows(28.0, "u30"), rows(35.0, "u40"));
284277

278+
// 1.3 Range - Range - Metric
279+
JSONObject actual3 =
280+
executeQuery(
281+
String.format(
282+
"source=%s | eval age_range = case(age < 30, 'u30', age < 40, 'u40' else 'u100'),"
283+
+ " balance_range = case(balance < 20000, 'medium' else 'high') | stats"
284+
+ " avg(balance) as avg_balance by age_range, balance_range",
285+
TEST_INDEX_BANK));
285286
verifySchema(
286-
actual,
287-
schema("size_category", "string"),
288-
schema("total", "bigint"),
289-
schema("avg_bytes", "double"));
290-
291-
// Verify we get results for each category
292-
// The exact values may vary based on test data, but structure should be correct
293-
assertEquals(3, actual.getJSONArray("datarows").length());
294-
}
287+
actual3,
288+
schema("avg_balance", "double"),
289+
schema("age_range", "string"),
290+
schema("balance_range", "string"));
291+
verifyDataRows(
292+
actual3,
293+
rows(32838.0, "u30", "high"),
294+
rows(8761.333333333334, "u40", "medium"),
295+
rows(42617.0, "u40", "high"));
295296

296-
@Test
297-
public void testCaseRangeAggregationWithElse() throws IOException {
298-
// Test CASE with explicit ELSE clause
299-
JSONObject actual =
297+
// 1.4 Range - Metric (With null & discontinuous ranges)
298+
JSONObject actual4 =
300299
executeQuery(
301300
String.format(
302-
"source=%s | eval status_category = case( cast(response as int) < 300, 'success', "
303-
+ " cast(response as int) >= 300 AND cast(response as int) < 400, 'redirect', "
304-
+ " cast(response as int) >= 400 AND cast(response as int) < 500,"
305-
+ " 'client_error', cast(response as int) >= 500, 'server_error' else"
306-
+ " 'unknown') | stats count() by status_category | sort status_category",
307-
TEST_INDEX_WEBLOGS));
308-
309-
verifySchema(actual, schema("status_category", "string"), schema("count()", "bigint"));
310-
311-
// Should handle the ELSE case for null/non-numeric responses
312-
assertTrue(actual.getJSONArray("datarows").length() > 0);
313-
}
301+
"source=%s | eval age_range = case(age < 30, 'u30', (age >= 35 and age < 40) or age"
302+
+ " >= 80, '30-40 or >=80') | stats avg(balance) by age_range",
303+
TEST_INDEX_BANK));
304+
verifySchema(actual4, schema("avg(balance)", "double"), schema("age_range", "string"));
305+
verifyDataRows(
306+
actual4,
307+
rows(32838.0, "u30"),
308+
rows(30497.0, "null"),
309+
rows(20881.333333333332, "30-40 or >=80"));
314310

315-
@Test
316-
public void testNonOptimizableCaseExpression() throws IOException {
317-
// Test CASE that cannot be optimized (different fields)
318-
JSONObject actual =
311+
// 1.5 Should not be pushed because the range is not closed-open
312+
JSONObject actual5 =
319313
executeQuery(
320314
String.format(
321-
"source=%s | eval mixed_condition = case("
322-
+ " cast(bytes as int) < 1000, 'small_bytes',"
323-
+ " cast(response as int) >= 400, 'error_response'"
324-
+ " else 'other'"
325-
+ ") | stats count() by mixed_condition",
326-
TEST_INDEX_WEBLOGS));
327-
328-
verifySchema(actual, schema("mixed_condition", "string"), schema("count()", "bigint"));
315+
"source=%s | eval age_range = case(age < 30, 'u30', age >= 30 and age <= 40, 'u40'"
316+
+ " else 'u100') | stats avg(age) as avg_age by age_range",
317+
TEST_INDEX_BANK));
318+
verifySchema(actual5, schema("avg_age", "double"), schema("age_range", "string"));
319+
verifyDataRows(actual5, rows(35.0, "u40"), rows(28.0, "u30"));
329320

330-
// This should work but won't be optimized
331-
assertTrue(actual.getJSONArray("datarows").length() > 0);
332-
}
333-
334-
@Test
335-
public void testCaseWithNonLiteralResult() throws IOException {
336-
// Test CASE that cannot be optimized (non-literal results)
337-
JSONObject actual =
321+
// CASE 2: Composite - Range - Metric
322+
// 2.1 Composite (term) - Range - Metric
323+
JSONObject actual6 =
338324
executeQuery(
339325
String.format(
340-
"source=%s | eval computed_result = case("
341-
+ " cast(bytes as int) < 1000, concat('small_', host),"
342-
+ " cast(bytes as int) >= 1000, concat('large_', host)"
343-
+ ") | stats count() by computed_result | head 3",
344-
TEST_INDEX_WEBLOGS));
345-
346-
verifySchema(actual, schema("computed_result", "string"), schema("count()", "bigint"));
326+
"source=%s | eval age_range = case(age < 30, 'u30' else 'a30') | stats avg(balance)"
327+
+ " by state, age_range",
328+
TEST_INDEX_BANK));
329+
verifySchema(
330+
actual6,
331+
schema("avg(balance)", "double"),
332+
schema("state", "string"),
333+
schema("age_range", "string"));
334+
verifyDataRows(
335+
actual6,
336+
rows(39225.0, "IL", "a30"),
337+
rows(48086.0, "IN", "a30"),
338+
rows(4180.0, "MD", "a30"),
339+
rows(40540.0, "PA", "a30"),
340+
rows(5686.0, "TN", "a30"),
341+
rows(32838.0, "VA", "u30"),
342+
rows(16418.0, "WA", "a30"));
347343

348-
// This should work but won't be optimized to range aggregation
349-
assertTrue(actual.getJSONArray("datarows").length() > 0);
350-
}
344+
// 2.2 Composite (date histogram) - Range - Metric
345+
JSONObject actual7 =
346+
executeQuery(
347+
"source=opensearch-sql_test_index_time_data | eval value_range = case(value < 7000,"
348+
+ " 'small' else 'large') | stats avg(value) by value_range, span(@timestamp,"
349+
+ " 1h)");
350+
verifySchema(
351+
actual7,
352+
schema("avg(value)", "double"),
353+
schema("span(@timestamp,1h)", "timestamp"),
354+
schema("value_range", "string"));
355+
// Verify we have results with both small and large ranges and timestamps
356+
assertTrue(actual7.getJSONArray("datarows").length() == 100);
357+
// Verify some sample rows to check data correctness
358+
String resultStr = actual7.toString();
359+
assertTrue(resultStr.contains("small") && resultStr.contains("large"));
360+
assertTrue(resultStr.contains("2025-07-28") && resultStr.contains("2025-07-29"));
351361

352-
@Test
353-
public void testOptimizableCaseRangeAggregation() throws IOException {
354-
// Test CASE that could be optimized if all ranges are covered with explicit ELSE
355-
JSONObject actual =
362+
// 2.3 Composite (2 fields) - Range - Metric (with count)
363+
JSONObject actual8 =
356364
executeQuery(
357365
String.format(
358-
"source=%s | eval size_bucket = case("
359-
+ " cast(bytes as int) < 2000, 'small',"
360-
+ " cast(bytes as int) >= 2000 AND cast(bytes as int) < 5000, 'medium',"
361-
+ " cast(bytes as int) >= 5000, 'large'"
362-
+ " else 'unknown'"
363-
+ ") | stats count() by size_bucket | sort size_bucket",
364-
TEST_INDEX_WEBLOGS));
366+
"source=%s | eval age_range = case(age < 30, 'u30' else 'a30') | stats"
367+
+ " avg(balance), count() by age_range, state, gender",
368+
TEST_INDEX_BANK));
369+
verifySchema(
370+
actual8,
371+
schema("avg(balance)", "double"),
372+
schema("count()", "bigint"),
373+
schema("age_range", "string"),
374+
schema("state", "string"),
375+
schema("gender", "string"));
376+
verifyDataRows(
377+
actual8,
378+
rows(5686.0, 1, "a30", "TN", "M"),
379+
rows(16418.0, 1, "a30", "WA", "M"),
380+
rows(40540.0, 1, "a30", "PA", "F"),
381+
rows(4180.0, 1, "a30", "MD", "M"),
382+
rows(32838.0, 1, "u30", "VA", "F"),
383+
rows(39225.0, 1, "a30", "IL", "M"),
384+
rows(48086.0, 1, "a30", "IN", "F"));
365385

366-
verifySchema(actual, schema("size_bucket", "string"), schema("count()", "bigint"));
386+
// 2.4 Composite (2 fields) - Range - Range - Metric (with count)
387+
JSONObject actual9 =
388+
executeQuery(
389+
String.format(
390+
"source=%s | eval age_range = case(age < 35, 'u35' else 'a35'), balance_range ="
391+
+ " case(balance < 20000, 'medium' else 'high') | stats avg(balance) as"
392+
+ " avg_balance by age_range, balance_range, state",
393+
TEST_INDEX_BANK));
394+
verifySchema(
395+
actual9,
396+
schema("avg_balance", "double"),
397+
schema("age_range", "string"),
398+
schema("balance_range", "string"),
399+
schema("state", "string"));
400+
verifyDataRows(
401+
actual9,
402+
rows(39225.0, "u35", "high", "IL"),
403+
rows(48086.0, "u35", "high", "IN"),
404+
rows(4180.0, "u35", "medium", "MD"),
405+
rows(40540.0, "a35", "high", "PA"),
406+
rows(5686.0, "a35", "medium", "TN"),
407+
rows(32838.0, "u35", "high", "VA"),
408+
rows(16418.0, "a35", "medium", "WA"));
367409

368-
// This should work - the explicit ELSE makes it potentially optimizable
369-
assertTrue(actual.getJSONArray("datarows").length() > 0);
410+
// 2.5 Should not be pushed because case result expression is not constant
411+
JSONObject actual10 =
412+
executeQuery(
413+
String.format(
414+
"source=%s | eval age_range = case(age < 35, 'u35' else email) | stats avg(balance)"
415+
+ " as avg_balance by age_range, state",
416+
TEST_INDEX_BANK));
417+
verifySchema(
418+
actual10,
419+
schema("avg_balance", "double"),
420+
schema("age_range", "string"),
421+
schema("state", "string"));
422+
verifyDataRows(
423+
actual10,
424+
rows(32838.0, "u35", "VA"),
425+
rows(4180.0, "u35", "MD"),
426+
rows(48086.0, "u35", "IN"),
427+
rows(40540.0, "virginiaayala@filodyne.com", "PA"),
428+
rows(39225.0, "u35", "IL"),
429+
rows(5686.0, "hattiebond@netagy.com", "TN"),
430+
rows(16418.0, "elinorratliff@scentric.com", "WA"));
370431
}
371432
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
calcite:
2+
logical: |
3+
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
4+
LogicalProject(avg(value)=[$2], span(@timestamp,1h)=[$1], value_range=[$0])
5+
LogicalAggregate(group=[{0, 2}], avg(value)=[AVG($1)])
6+
LogicalProject(value_range=[$10], value=[$2], span(@timestamp,1h)=[SPAN($0, 1, 'h')])
7+
LogicalFilter(condition=[IS NOT NULL($0)])
8+
LogicalProject(@timestamp=[$0], category=[$1], value=[$2], timestamp=[$3], _id=[$4], _index=[$5], _score=[$6], _maxscore=[$7], _sort=[$8], _routing=[$9], value_range=[CASE(<($2, 7000), 'small':VARCHAR, 'large':VARCHAR)])
9+
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]])
10+
physical: |
11+
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]], PushDownContext=[[FILTER->IS NOT NULL($0), AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0, 2},avg(value)=AVG($1)), PROJECT->[avg(value), span(@timestamp,1h), value_range], LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","query":{"exists":{"field":"@timestamp","boost":1.0}},"sort":[],"aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"span(@timestamp,1h)":{"date_histogram":{"field":"@timestamp","missing_bucket":false,"order":"asc","fixed_interval":"1h"}}}]},"aggregations":{"value_range":{"range":{"field":"value","ranges":[{"key":"small","to":7000.0},{"key":"large","from":7000.0}],"keyed":true},"aggregations":{"avg(value)":{"avg":{"field":"value"}}}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
calcite:
2+
logical: |
3+
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
4+
LogicalProject(avg(value)=[$2], span(@timestamp,1h)=[$1], value_range=[$0])
5+
LogicalAggregate(group=[{0, 2}], avg(value)=[AVG($1)])
6+
LogicalProject(value_range=[$10], value=[$2], span(@timestamp,1h)=[SPAN($0, 1, 'h')])
7+
LogicalFilter(condition=[IS NOT NULL($0)])
8+
LogicalProject(@timestamp=[$0], category=[$1], value=[$2], timestamp=[$3], _id=[$4], _index=[$5], _score=[$6], _maxscore=[$7], _sort=[$8], _routing=[$9], value_range=[CASE(<($2, 7000), 'small':VARCHAR, 'large':VARCHAR)])
9+
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]])
10+
physical: |
11+
EnumerableLimit(fetch=[10000])
12+
EnumerableCalc(expr#0..3=[{inputs}], expr#4=[0], expr#5=[=($t3, $t4)], expr#6=[null:BIGINT], expr#7=[CASE($t5, $t6, $t2)], expr#8=[CAST($t7):DOUBLE], expr#9=[/($t8, $t3)], avg(value)=[$t9], span(@timestamp,1h)=[$t1], value_range=[$t0])
13+
EnumerableAggregate(group=[{0, 2}], agg#0=[$SUM0($1)], agg#1=[COUNT($1)])
14+
EnumerableCalc(expr#0..9=[{inputs}], expr#10=[7000], expr#11=[<($t2, $t10)], expr#12=['small':VARCHAR], expr#13=['large':VARCHAR], expr#14=[CASE($t11, $t12, $t13)], expr#15=[1], expr#16=['h'], expr#17=[SPAN($t0, $t15, $t16)], expr#18=[IS NOT NULL($t0)], value_range=[$t14], value=[$t2], span(@timestamp,1h)=[$t17], $condition=[$t18])
15+
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_time_data]])

opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/LeafBucketAggregationParser.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
import org.opensearch.search.aggregations.bucket.range.Range;
2020

2121
/**
22-
* Use BucketAggregationParser only when there is a single group-by key, it returns multiple
23-
* buckets. {@link CompositeAggregationParser} is used for multiple group by keys
22+
* Use LeafBucketAggregationParser only when there is a single group-by key; it returns multiple
23+
* buckets. {@link BucketAggregationParser} is used for multiple group-by keys.
2424
*/
2525
@EqualsAndHashCode
2626
public class LeafBucketAggregationParser implements OpenSearchAggregationResponseParser {
@@ -36,7 +36,7 @@ public LeafBucketAggregationParser(List<MetricParser> metricParserList) {
3636
metricsParser = new MetricParserHelper(metricParserList);
3737
}
3838

39-
/** CompositeAggregationParser with count aggregation name list, used in v3 */
39+
/** LeafBucketAggregationParser with count aggregation name list, used in v3 */
4040
public LeafBucketAggregationParser(
4141
List<MetricParser> metricParserList, List<String> countAggNameList) {
4242
metricsParser = new MetricParserHelper(metricParserList);

0 commit comments

Comments
 (0)