Skip to content

Commit 45b49f2

Browse files
committed
pushdown approx distinct count
Signed-off-by: xinyual <xinyual@amazon.com>
1 parent 2c41580 commit 45b49f2

6 files changed

Lines changed: 103 additions & 75 deletions

File tree

integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,19 @@ public void testExplainOnAggregationWithSumEnhancement() throws IOException {
620620
TEST_INDEX_BANK)));
621621
}
622622

623+
@Test
624+
public void testStatsDistinctCountApproxFunctionExplainWithPushDown() throws IOException {
625+
enabledOnlyWhenPushdownIsEnabled();
626+
String query =
627+
"source=opensearch-sql_test_index_account | stats distinct_count_approx(state) as"
628+
+ " distinct_states by gender";
629+
var result = explainQueryToString(query);
630+
String expected =
631+
loadFromFile(
632+
"expectedOutput/calcite/explain_agg_with_distinct_count_approx_enhancement.json");
633+
assertJsonEqualsIgnoreId(expected, result);
634+
}
635+
623636
@Test
624637
public void testExplainRegexMatchInWhereWithScriptPushdown() throws IOException {
625638
enabledOnlyWhenPushdownIsEnabled();
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"calcite":{
3+
"logical":"LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(distinct_states=[$1], gender=[$0])\n LogicalAggregate(group=[{0}], distinct_states=[DISTINCT_COUNT_APPROX($1)])\n LogicalProject(gender=[$4], state=[$7])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n",
4+
"physical":"CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},distinct_states=DISTINCT_COUNT_APPROX($1)), PROJECT->[distinct_states, gender], LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"aggregations\":{\"composite_buckets\":{\"composite\":{\"size\":1000,\"sources\":[{\"gender\":{\"terms\":{\"field\":\"gender.keyword\",\"missing_bucket\":true,\"missing_order\":\"first\",\"order\":\"asc\"}}}]},\"aggregations\":{\"distinct_states\":{\"cardinality\":{\"field\":\"state.keyword\"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n"
5+
}
6+
}
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"calcite": {
3-
"logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], distinct_states=[APPROX_DISTINCT_COUNT($7) OVER ()])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n",
4-
"physical": "EnumerableLimit(fetch=[10000])\n EnumerableWindow(window#0=[window(aggs [APPROX_DISTINCT_COUNT($7)])])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[account_number, firstname, address, balance, gender, city, employer, state, age, email, lastname]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"account_number\",\"firstname\",\"address\",\"balance\",\"gender\",\"city\",\"employer\",\"state\",\"age\",\"email\",\"lastname\"],\"excludes\":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n"
3+
"logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], distinct_states=[DISTINCT_COUNT_APPROX($7) OVER ()])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n",
4+
"physical": "EnumerableLimit(fetch=[10000])\n EnumerableWindow(window#0=[window(aggs [DISTINCT_COUNT_APPROX($7)])])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[account_number, firstname, address, balance, gender, city, employer, state, age, email, lastname]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"account_number\",\"firstname\",\"address\",\"balance\",\"gender\",\"city\",\"employer\",\"state\",\"age\",\"email\",\"lastname\"],\"excludes\":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n"
55
}
66
}
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"calcite": {
3-
"logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], distinct_states=[APPROX_DISTINCT_COUNT($7) OVER (PARTITION BY $4)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n",
4-
"physical": "EnumerableLimit(fetch=[10000])\n EnumerableWindow(window#0=[window(partition {4} aggs [APPROX_DISTINCT_COUNT($7)])])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[account_number, firstname, address, balance, gender, city, employer, state, age, email, lastname]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"account_number\",\"firstname\",\"address\",\"balance\",\"gender\",\"city\",\"employer\",\"state\",\"age\",\"email\",\"lastname\"],\"excludes\":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n"
3+
"logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], distinct_states=[DISTINCT_COUNT_APPROX($7) OVER (PARTITION BY $4)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n",
4+
"physical": "EnumerableLimit(fetch=[10000])\n EnumerableWindow(window#0=[window(partition {4} aggs [DISTINCT_COUNT_APPROX($7)])])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[account_number, firstname, address, balance, gender, city, employer, state, age, email, lastname]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"account_number\",\"firstname\",\"address\",\"balance\",\"gender\",\"city\",\"employer\",\"state\",\"age\",\"email\",\"lastname\"],\"excludes\":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n"
55
}
66
}

opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
package org.opensearch.sql.opensearch.executor;
77

8+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.DISTINCT_COUNT_APPROX;
9+
810
import java.security.AccessController;
911
import java.security.PrivilegedAction;
1012
import java.sql.PreparedStatement;
@@ -292,10 +294,10 @@ private void registerOpenSearchFunctions() {
292294
SqlUserDefinedAggFunction approxDistinctCountFunction =
293295
UserDefinedFunctionUtils.createUserDefinedAggFunction(
294296
DistinctCountApproxAggFunction.class,
295-
"APPROX_DISTINCT_COUNT",
297+
DISTINCT_COUNT_APPROX.toString(),
296298
ReturnTypes.BIGINT_FORCE_NULLABLE,
297299
null);
298300
PPLFuncImpTable.INSTANCE.registerExternalAggOperator(
299-
BuiltinFunctionName.DISTINCT_COUNT_APPROX, approxDistinctCountFunction);
301+
DISTINCT_COUNT_APPROX, approxDistinctCountFunction);
300302
}
301303
}

opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java

Lines changed: 76 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -382,141 +382,148 @@ private static Pair<AggregationBuilder, MetricParser> createRegularAggregation(
382382
switch (aggCall.getAggregation().kind) {
383383
case AVG:
384384
return Pair.of(
385-
helper.build(args.get(0), AggregationBuilders.avg(aggFieldName)),
386-
new SingleValueParser(aggFieldName));
385+
helper.build(args.get(0), AggregationBuilders.avg(aggFieldName)),
386+
new SingleValueParser(aggFieldName));
387387
case SUM:
388388
// 1. Only case SUM, skip SUM0 / COUNT since calling avg() in DSL should be faster.
389389
// 2. To align with databases, SUM0 is not preferred now.
390390
return Pair.of(
391-
helper.build(args.get(0), AggregationBuilders.sum(aggFieldName)),
392-
new SingleValueParser(aggFieldName));
391+
helper.build(args.get(0), AggregationBuilders.sum(aggFieldName)),
392+
new SingleValueParser(aggFieldName));
393393
case COUNT:
394394
return Pair.of(
395-
helper.build(
396-
!args.isEmpty() ? args.get(0) : null, AggregationBuilders.count(aggFieldName)),
397-
new SingleValueParser(aggFieldName));
395+
helper.build(
396+
!args.isEmpty() ? args.get(0) : null, AggregationBuilders.count(aggFieldName)),
397+
new SingleValueParser(aggFieldName));
398398
case MIN: {
399399
ExprType fieldType =
400-
OpenSearchTypeFactory.convertRelDataTypeToExprType(args.get(0).getType());
400+
OpenSearchTypeFactory.convertRelDataTypeToExprType(args.get(0).getType());
401401
if (supportsMaxMinAggregation(fieldType)) {
402402
return Pair.of(
403-
helper.build(args.get(0), AggregationBuilders.min(aggFieldName)),
404-
new SingleValueParser(aggFieldName));
403+
helper.build(args.get(0), AggregationBuilders.min(aggFieldName)),
404+
new SingleValueParser(aggFieldName));
405405
} else {
406406
return Pair.of(
407-
AggregationBuilders.topHits(aggFieldName)
408-
.fetchSource(helper.inferNamedField(args.get(0)).getRootName(), null)
409-
.size(1)
410-
.from(0)
411-
.sort(
412-
helper.inferNamedField(args.get(0)).getReferenceForTermQuery(),
413-
SortOrder.ASC),
414-
new TopHitsParser(aggFieldName, true));
407+
AggregationBuilders.topHits(aggFieldName)
408+
.fetchSource(helper.inferNamedField(args.get(0)).getRootName(), null)
409+
.size(1)
410+
.from(0)
411+
.sort(
412+
helper.inferNamedField(args.get(0)).getReferenceForTermQuery(),
413+
SortOrder.ASC),
414+
new TopHitsParser(aggFieldName, true));
415415
}
416416
}
417417
case MAX: {
418418
ExprType fieldType =
419-
OpenSearchTypeFactory.convertRelDataTypeToExprType(args.get(0).getType());
419+
OpenSearchTypeFactory.convertRelDataTypeToExprType(args.get(0).getType());
420420
if (supportsMaxMinAggregation(fieldType)) {
421421
return Pair.of(
422-
helper.build(args.get(0), AggregationBuilders.max(aggFieldName)),
423-
new SingleValueParser(aggFieldName));
422+
helper.build(args.get(0), AggregationBuilders.max(aggFieldName)),
423+
new SingleValueParser(aggFieldName));
424424
} else {
425425
return Pair.of(
426-
AggregationBuilders.topHits(aggFieldName)
427-
.fetchSource(helper.inferNamedField(args.get(0)).getRootName(), null)
428-
.size(1)
429-
.from(0)
430-
.sort(
431-
helper.inferNamedField(args.get(0)).getReferenceForTermQuery(),
432-
SortOrder.DESC),
433-
new TopHitsParser(aggFieldName, true));
426+
AggregationBuilders.topHits(aggFieldName)
427+
.fetchSource(helper.inferNamedField(args.get(0)).getRootName(), null)
428+
.size(1)
429+
.from(0)
430+
.sort(
431+
helper.inferNamedField(args.get(0)).getReferenceForTermQuery(),
432+
SortOrder.DESC),
433+
new TopHitsParser(aggFieldName, true));
434434
}
435435
}
436436
case VAR_SAMP:
437437
return Pair.of(
438-
helper.build(args.get(0), AggregationBuilders.extendedStats(aggFieldName)),
439-
new StatsParser(ExtendedStats::getVarianceSampling, aggFieldName));
438+
helper.build(args.get(0), AggregationBuilders.extendedStats(aggFieldName)),
439+
new StatsParser(ExtendedStats::getVarianceSampling, aggFieldName));
440440
case VAR_POP:
441441
return Pair.of(
442-
helper.build(args.get(0), AggregationBuilders.extendedStats(aggFieldName)),
443-
new StatsParser(ExtendedStats::getVariancePopulation, aggFieldName));
442+
helper.build(args.get(0), AggregationBuilders.extendedStats(aggFieldName)),
443+
new StatsParser(ExtendedStats::getVariancePopulation, aggFieldName));
444444
case STDDEV_SAMP:
445445
return Pair.of(
446-
helper.build(args.get(0), AggregationBuilders.extendedStats(aggFieldName)),
447-
new StatsParser(ExtendedStats::getStdDeviationSampling, aggFieldName));
446+
helper.build(args.get(0), AggregationBuilders.extendedStats(aggFieldName)),
447+
new StatsParser(ExtendedStats::getStdDeviationSampling, aggFieldName));
448448
case STDDEV_POP:
449449
return Pair.of(
450-
helper.build(args.get(0), AggregationBuilders.extendedStats(aggFieldName)),
451-
new StatsParser(ExtendedStats::getStdDeviationPopulation, aggFieldName));
450+
helper.build(args.get(0), AggregationBuilders.extendedStats(aggFieldName)),
451+
new StatsParser(ExtendedStats::getStdDeviationPopulation, aggFieldName));
452452
case ARG_MAX:
453453
return Pair.of(
454-
AggregationBuilders.topHits(aggFieldName)
455-
.fetchSource(helper.inferNamedField(args.get(0)).getRootName(), null)
456-
.size(1)
457-
.from(0)
458-
.sort(
459-
helper.inferNamedField(args.get(1)).getRootName(),
460-
org.opensearch.search.sort.SortOrder.DESC),
461-
new ArgMaxMinParser(aggFieldName));
454+
AggregationBuilders.topHits(aggFieldName)
455+
.fetchSource(helper.inferNamedField(args.get(0)).getRootName(), null)
456+
.size(1)
457+
.from(0)
458+
.sort(
459+
helper.inferNamedField(args.get(1)).getRootName(),
460+
org.opensearch.search.sort.SortOrder.DESC),
461+
new ArgMaxMinParser(aggFieldName));
462462
case ARG_MIN:
463463
return Pair.of(
464-
AggregationBuilders.topHits(aggFieldName)
465-
.fetchSource(helper.inferNamedField(args.get(0)).getRootName(), null)
466-
.size(1)
467-
.from(0)
468-
.sort(
469-
helper.inferNamedField(args.get(1)).getRootName(),
470-
org.opensearch.search.sort.SortOrder.ASC),
471-
new ArgMaxMinParser(aggFieldName));
464+
AggregationBuilders.topHits(aggFieldName)
465+
.fetchSource(helper.inferNamedField(args.get(0)).getRootName(), null)
466+
.size(1)
467+
.from(0)
468+
.sort(
469+
helper.inferNamedField(args.get(1)).getRootName(),
470+
org.opensearch.search.sort.SortOrder.ASC),
471+
new ArgMaxMinParser(aggFieldName));
472472
case OTHER_FUNCTION:
473473
BuiltinFunctionName functionName =
474-
BuiltinFunctionName.ofAggregation(aggCall.getAggregation().getName()).get();
474+
BuiltinFunctionName.ofAggregation(aggCall.getAggregation().getName()).get();
475475
switch (functionName) {
476476
case TAKE:
477477
return Pair.of(
478-
AggregationBuilders.topHits(aggFieldName)
479-
.fetchSource(helper.inferNamedField(args.get(0)).getRootName(), null)
480-
.size(helper.inferValue(args.get(1), Integer.class))
481-
.from(0),
482-
new TopHitsParser(aggFieldName));
478+
AggregationBuilders.topHits(aggFieldName)
479+
.fetchSource(helper.inferNamedField(args.get(0)).getRootName(), null)
480+
.size(helper.inferValue(args.get(1), Integer.class))
481+
.from(0),
482+
new TopHitsParser(aggFieldName));
483483
case FIRST:
484484
TopHitsAggregationBuilder firstBuilder =
485-
AggregationBuilders.topHits(aggFieldName).size(1).from(0);
485+
AggregationBuilders.topHits(aggFieldName).size(1).from(0);
486486
if (!args.isEmpty()) {
487487
firstBuilder.fetchSource(helper.inferNamedField(args.get(0)).getRootName(), null);
488488
}
489489
return Pair.of(firstBuilder, new TopHitsParser(aggFieldName, true));
490490
case LAST:
491491
TopHitsAggregationBuilder lastBuilder =
492-
AggregationBuilders.topHits(aggFieldName)
493-
.size(1)
494-
.from(0)
495-
.sort("_doc", org.opensearch.search.sort.SortOrder.DESC);
492+
AggregationBuilders.topHits(aggFieldName)
493+
.size(1)
494+
.from(0)
495+
.sort("_doc", org.opensearch.search.sort.SortOrder.DESC);
496496
if (!args.isEmpty()) {
497497
lastBuilder.fetchSource(helper.inferNamedField(args.get(0)).getRootName(), null);
498498
}
499499
return Pair.of(lastBuilder, new TopHitsParser(aggFieldName, true));
500500
case PERCENTILE_APPROX:
501501
PercentilesAggregationBuilder aggBuilder =
502-
helper
503-
.build(args.get(0), AggregationBuilders.percentiles(aggFieldName))
504-
.percentiles(helper.inferValue(args.get(1), Double.class));
502+
helper
503+
.build(args.get(0), AggregationBuilders.percentiles(aggFieldName))
504+
.percentiles(helper.inferValue(args.get(1), Double.class));
505505
/* See {@link PercentileApproxFunction}, PERCENTILE_APPROX accepts args of [FIELD, PERCENTILE, TYPE, COMPRESSION(optional)] */
506506
if (args.size() > 3) {
507507
aggBuilder.compression(helper.inferValue(args.get(3), Double.class));
508508
}
509509
return Pair.of(aggBuilder, new SinglePercentileParser(aggFieldName));
510+
case DISTINCT_COUNT_APPROX:
511+
return Pair.of(
512+
helper.build(
513+
!args.isEmpty() ? args.getFirst() : null,
514+
AggregationBuilders.cardinality(aggFieldName)),
515+
new SingleValueParser(aggFieldName));
510516
default:
511517
throw new AggregateAnalyzer.AggregateAnalyzerException(
512-
String.format("Unsupported push-down aggregator %s", aggCall.getAggregation()));
518+
String.format("Unsupported push-down aggregator %s", aggCall.getAggregation()));
513519
}
514520
default:
515521
throw new AggregateAnalyzer.AggregateAnalyzerException(
516-
String.format("unsupported aggregator %s", aggCall.getAggregation()));
522+
String.format("unsupported aggregator %s", aggCall.getAggregation()));
517523
}
518524
}
519525

526+
520527
private static boolean supportsMaxMinAggregation(ExprType fieldType) {
521528
ExprType coreType =
522529
(fieldType instanceof OpenSearchDataType)

0 commit comments

Comments
 (0)