Skip to content

Commit cbcb25a

Browse files
committed
Fix unit tests
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent f416dee commit cbcb25a

4 files changed

Lines changed: 22 additions & 18 deletions

File tree

opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ public static Pair<List<AggregationBuilder>, OpenSearchAggregationResponseParser
204204
// Find group by fields derived from CASE functions and convert them to range queries
205205
List<Pair<Integer, RangeAggregationBuilder>> groupsByCase =
206206
groupList.stream()
207+
.filter(i -> project != null && i < project.getProjects().size())
207208
.map(i -> Pair.of(i, project.getNamedProjects().get(i)))
208209
.filter(
209210
p ->

opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/BucketAggregationParser.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
import java.util.List;
1111
import java.util.Map;
1212
import java.util.stream.Collectors;
13+
import lombok.Getter;
1314
import org.opensearch.search.aggregations.Aggregation;
1415
import org.opensearch.search.aggregations.Aggregations;
1516
import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation;
1617
import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation;
1718

1819
public class BucketAggregationParser implements OpenSearchAggregationResponseParser {
19-
private final OpenSearchAggregationResponseParser subAggParser;
20+
@Getter private final OpenSearchAggregationResponseParser subAggParser;
2021

2122
public BucketAggregationParser(OpenSearchAggregationResponseParser subAggParser) {
2223
this.subAggParser = subAggParser;

opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/LeafBucketAggregationParser.java

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77

88
import java.util.Arrays;
99
import java.util.HashMap;
10+
import java.util.LinkedHashMap;
1011
import java.util.List;
1112
import java.util.Map;
1213
import java.util.Objects;
1314
import java.util.stream.Collectors;
1415
import lombok.EqualsAndHashCode;
16+
import lombok.Getter;
1517
import org.opensearch.search.aggregations.Aggregation;
1618
import org.opensearch.search.aggregations.Aggregations;
1719
import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation;
@@ -24,7 +26,7 @@
2426
*/
2527
@EqualsAndHashCode
2628
public class LeafBucketAggregationParser implements OpenSearchAggregationResponseParser {
27-
private final MetricParserHelper metricsParser;
29+
@Getter private final MetricParserHelper metricsParser;
2830

2931
public LeafBucketAggregationParser(MetricParser... metricParserList) {
3032
metricsParser = new MetricParserHelper(Arrays.asList(metricParserList));
@@ -47,25 +49,19 @@ public List<Map<String, Object>> parse(Aggregations aggregations) {
4749
private Map<String, Object> parse(MultiBucketsAggregation.Bucket bucket, String name) {
4850
if (bucket instanceof CompositeAggregation.Bucket compositeBucket) {
4951
return parse(compositeBucket);
50-
} else if (bucket instanceof Range.Bucket rangeBucket) {
51-
return parse(rangeBucket, name);
5252
}
53-
return metricsParser.parse(bucket.getAggregations());
54-
}
55-
56-
private Map<String, Object> parse(CompositeAggregation.Bucket bucket) {
57-
Map<String, Object> resultMap = new HashMap<>();
58-
resultMap.putAll(bucket.getKey());
53+
if (bucket instanceof Range.Bucket && bucket.getDocCount() == 0) {
54+
return null;
55+
}
56+
Map<String, Object> resultMap = new LinkedHashMap<>();
57+
resultMap.put(name, bucket.getKey());
5958
resultMap.putAll(metricsParser.parse(bucket.getAggregations()));
6059
return resultMap;
6160
}
6261

63-
private Map<String, Object> parse(Range.Bucket bucket, String name) {
64-
if (bucket.getDocCount() == 0) {
65-
return null;
66-
}
62+
private Map<String, Object> parse(CompositeAggregation.Bucket bucket) {
6763
Map<String, Object> resultMap = new HashMap<>();
68-
resultMap.put(name, bucket.getKey());
64+
resultMap.putAll(bucket.getKey());
6965
resultMap.putAll(metricsParser.parse(bucket.getAggregations()));
7066
return resultMap;
7167
}

opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@
4646
import org.opensearch.sql.opensearch.data.type.OpenSearchDataType;
4747
import org.opensearch.sql.opensearch.data.type.OpenSearchDataType.MappingType;
4848
import org.opensearch.sql.opensearch.request.AggregateAnalyzer.ExpressionNotAnalyzableException;
49-
import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser;
49+
import org.opensearch.sql.opensearch.response.agg.BucketAggregationParser;
5050
import org.opensearch.sql.opensearch.response.agg.FilterParser;
51+
import org.opensearch.sql.opensearch.response.agg.LeafBucketAggregationParser;
5152
import org.opensearch.sql.opensearch.response.agg.MetricParserHelper;
5253
import org.opensearch.sql.opensearch.response.agg.NoBucketAggregationParser;
5354
import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser;
@@ -281,9 +282,11 @@ void analyze_groupBy() throws ExpressionNotAnalyzableException {
281282
+ "{\"b\":{\"terms\":{\"field\":\"b.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},"
282283
+ "\"aggregations\":{\"cnt\":{\"value_count\":{\"field\":\"_index\"}}}}}]",
283284
result.getLeft().toString());
284-
assertInstanceOf(CompositeAggregationParser.class, result.getRight());
285+
assertInstanceOf(BucketAggregationParser.class, result.getRight());
285286
MetricParserHelper metricsParser =
286-
((CompositeAggregationParser) result.getRight()).getMetricsParser();
287+
((LeafBucketAggregationParser)
288+
((BucketAggregationParser) result.getRight()).getSubAggParser())
289+
.getMetricsParser();
287290
assertEquals(1, metricsParser.getMetricParserMap().size());
288291
metricsParser
289292
.getMetricParserMap()
@@ -592,8 +595,11 @@ private Project createMockProject(List<Integer> refIndex) {
592595
when(ref.getType()).thenReturn(typeFactory.createSqlType(SqlTypeName.INTEGER));
593596
rexNodes.add(ref);
594597
}
598+
List<org.apache.calcite.util.Pair<RexNode, String>> namedProjects =
599+
rexNodes.stream().map(n -> org.apache.calcite.util.Pair.of(n, n.toString())).toList();
595600
when(project.getProjects()).thenReturn(rexNodes);
596601
when(project.getRowType()).thenReturn(rowType);
602+
when(project.getNamedProjects()).thenReturn(namedProjects);
597603
return project;
598604
}
599605

0 commit comments

Comments
 (0)