Skip to content

Commit 4bbb0f2

Browse files
committed
Create scaffold for type checking of aggregation functions
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent ed4b3ed commit 4bbb0f2

3 files changed

Lines changed: 119 additions & 60 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -77,27 +77,42 @@ public class UserDefinedFunctionUtils {
7777
public static Set<String> MULTI_FIELDS_RELEVANCE_FUNCTION_SET =
7878
ImmutableSet.of("simple_query_string", "query_string", "multi_match");
7979

80-
public static RelBuilder.AggCall TransferUserDefinedAggFunction(
81-
Class<? extends UserDefinedAggFunction> UDAF,
80+
public static SqlUserDefinedAggFunction createUserDefinedAggFunction(
81+
Class<? extends UserDefinedAggFunction<?>> udafClass,
8282
String functionName,
83-
SqlReturnTypeInference returnType,
83+
SqlReturnTypeInference returnType) {
84+
return new SqlUserDefinedAggFunction(
85+
new SqlIdentifier(functionName, SqlParserPos.ZERO),
86+
SqlKind.OTHER_FUNCTION,
87+
returnType,
88+
null,
89+
null,
90+
AggregateFunctionImpl.create(udafClass),
91+
false,
92+
false,
93+
Optionality.FORBIDDEN);
94+
}
95+
96+
public static RelBuilder.AggCall convertUDAFToAggCall(
97+
SqlUserDefinedAggFunction udaf,
8498
List<RexNode> fields,
8599
List<RexNode> argList,
86100
RelBuilder relBuilder) {
87-
SqlUserDefinedAggFunction sqlUDAF =
88-
new SqlUserDefinedAggFunction(
89-
new SqlIdentifier(functionName, SqlParserPos.ZERO),
90-
SqlKind.OTHER_FUNCTION,
91-
returnType,
92-
null,
93-
null,
94-
AggregateFunctionImpl.create(UDAF),
95-
false,
96-
false,
97-
Optionality.FORBIDDEN);
98101
List<RexNode> addArgList = new ArrayList<>(fields);
99102
addArgList.addAll(argList);
100-
return relBuilder.aggregateCall(sqlUDAF, addArgList);
103+
return relBuilder.aggregateCall(udaf, addArgList);
104+
}
105+
106+
public static RelBuilder.AggCall TransferUserDefinedAggFunction(
107+
Class<? extends UserDefinedAggFunction<?>> udafClass,
108+
String functionName,
109+
SqlReturnTypeInference returnType,
110+
List<RexNode> fields,
111+
List<RexNode> argList,
112+
RelBuilder relBuilder) {
113+
SqlUserDefinedAggFunction udaf =
114+
createUserDefinedAggFunction(udafClass, functionName, returnType);
115+
return convertUDAFToAggCall(udaf, fields, argList, relBuilder);
101116
}
102117

103118
public static SqlReturnTypeInference getReturnTypeInferenceForArray() {

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 78 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@
250250
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
251251
import org.apache.calcite.sql.type.SqlTypeFamily;
252252
import org.apache.calcite.sql.type.SqlTypeName;
253+
import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
253254
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
254255
import org.apache.calcite.tools.RelBuilder;
255256
import org.apache.commons.lang3.tuple.Pair;
@@ -337,14 +338,16 @@ default RexNode resolve(RexBuilder builder, RexNode... args) {
337338
* implementations are independent of any specific data storage, should be registered here
338339
* internally.
339340
*/
340-
private final ImmutableMap<BuiltinFunctionName, AggHandler> aggFunctionRegistry;
341+
private final ImmutableMap<BuiltinFunctionName, Pair<CalciteFuncSignature, AggHandler>>
342+
aggFunctionRegistry;
341343

342344
/**
343345
* The external agg function registry. Agg Functions whose implementations depend on a specific
344346
* data engine should be registered here. This reduces coupling between the core module and
345347
* particular storage backends.
346348
*/
347-
private final Map<BuiltinFunctionName, AggHandler> aggExternalFunctionRegistry;
349+
private final Map<BuiltinFunctionName, Pair<CalciteFuncSignature, AggHandler>>
350+
aggExternalFunctionRegistry;
348351

349352
private PPLFuncImpTable(Builder builder, AggBuilder aggBuilder) {
350353
final ImmutableMap.Builder<BuiltinFunctionName, List<Pair<CalciteFuncSignature, FunctionImp>>>
@@ -353,8 +356,8 @@ private PPLFuncImpTable(Builder builder, AggBuilder aggBuilder) {
353356
this.functionRegistry = ImmutableMap.copyOf(mapBuilder.build());
354357
this.externalFunctionRegistry = new ConcurrentHashMap<>();
355358

356-
final ImmutableMap.Builder<BuiltinFunctionName, AggHandler> aggMapBuilder =
357-
ImmutableMap.builder();
359+
final ImmutableMap.Builder<BuiltinFunctionName, Pair<CalciteFuncSignature, AggHandler>>
360+
aggMapBuilder = ImmutableMap.builder();
358361
aggBuilder.map.forEach(aggMapBuilder::put);
359362
this.aggFunctionRegistry = ImmutableMap.copyOf(aggMapBuilder.build());
360363
this.aggExternalFunctionRegistry = new ConcurrentHashMap<>();
@@ -370,7 +373,7 @@ public void registerExternalOperator(BuiltinFunctionName functionName, SqlOperat
370373
PPLTypeChecker typeChecker =
371374
wrapSqlOperandTypeChecker(
372375
operator.getOperandTypeChecker(),
373-
operator.getName(),
376+
functionName.name(),
374377
operator instanceof SqlUserDefinedFunction);
375378
CalciteFuncSignature signature = new CalciteFuncSignature(functionName.getName(), typeChecker);
376379
externalFunctionRegistry.compute(
@@ -384,14 +387,22 @@ public void registerExternalOperator(BuiltinFunctionName functionName, SqlOperat
384387
}
385388

386389
/**
387-
* Register a function implementation from external services dynamically.
390+
* Register an external aggregate operator dynamically.
388391
*
389392
* @param functionName the name of the function, has to be defined in BuiltinFunctionName
390-
* @param functionImp the implementation of the agg function
393+
* @param aggFunction a SqlUserDefinedAggFunction representing the aggregate function
394+
* implementation
391395
*/
392-
public void registerExternalAggFunction(
393-
BuiltinFunctionName functionName, AggHandler functionImp) {
394-
aggExternalFunctionRegistry.put(functionName, functionImp);
396+
public void registerExternalAggOperator(
397+
BuiltinFunctionName functionName, SqlUserDefinedAggFunction aggFunction) {
398+
PPLTypeChecker typeChecker =
399+
wrapSqlOperandTypeChecker(aggFunction.getOperandTypeChecker(), functionName.name(), true);
400+
CalciteFuncSignature signature = new CalciteFuncSignature(functionName.getName(), typeChecker);
401+
AggHandler handler =
402+
(distinct, field, argList, ctx) ->
403+
UserDefinedFunctionUtils.convertUDAFToAggCall(
404+
aggFunction, List.of(field), argList, ctx.relBuilder);
405+
aggExternalFunctionRegistry.put(functionName, Pair.of(signature, handler));
395406
}
396407

397408
public RelBuilder.AggCall resolveAgg(
@@ -400,13 +411,24 @@ public RelBuilder.AggCall resolveAgg(
400411
RexNode field,
401412
List<RexNode> argList,
402413
CalcitePlanContext context) {
403-
AggHandler handler = aggExternalFunctionRegistry.get(functionName);
404-
if (handler == null) {
405-
handler = aggFunctionRegistry.get(functionName);
414+
var implementation = aggExternalFunctionRegistry.get(functionName);
415+
if (implementation == null) {
416+
implementation = aggFunctionRegistry.get(functionName);
406417
}
407-
if (handler == null) {
418+
if (implementation == null) {
408419
throw new IllegalStateException(String.format("Cannot resolve function: %s", functionName));
409420
}
421+
CalciteFuncSignature signature = implementation.getKey();
422+
RelDataType fieldType = field.getType();
423+
if (!signature.match(functionName.getName(), List.of(fieldType))) {
424+
throw new ExpressionEvaluationException(
425+
String.format(
426+
"Aggregation function %s expects field type {%s}, but got %s",
427+
functionName,
428+
signature.typeChecker().getAllowedSignatures(),
429+
getActualSignature(List.of(fieldType))));
430+
}
431+
var handler = implementation.getValue();
410432
return handler.apply(distinct, field, argList, context);
411433
}
412434

@@ -1037,43 +1059,66 @@ void register(
10371059
}
10381060

10391061
private static class AggBuilder {
1040-
private final Map<BuiltinFunctionName, AggHandler> map = new HashMap<>();
1062+
private final Map<BuiltinFunctionName, Pair<CalciteFuncSignature, AggHandler>> map =
1063+
new HashMap<>();
10411064

1042-
void register(BuiltinFunctionName functionName, AggHandler aggHandler) {
1043-
map.put(functionName, aggHandler);
1065+
void register(
1066+
BuiltinFunctionName functionName, AggHandler aggHandler, PPLTypeChecker typeChecker) {
1067+
CalciteFuncSignature signature =
1068+
new CalciteFuncSignature(functionName.getName(), typeChecker);
1069+
map.put(functionName, Pair.of(signature, aggHandler));
1070+
}
1071+
1072+
void registerOperator(BuiltinFunctionName functionName, SqlUserDefinedAggFunction aggFunction) {
1073+
PPLTypeChecker typeChecker =
1074+
wrapSqlOperandTypeChecker(aggFunction.getOperandTypeChecker(), functionName.name(), true);
1075+
AggHandler handler =
1076+
(distinct, field, argList, ctx) ->
1077+
UserDefinedFunctionUtils.convertUDAFToAggCall(
1078+
aggFunction, List.of(field), argList, ctx.relBuilder);
1079+
register(functionName, handler, typeChecker);
10441080
}
10451081

10461082
void populate() {
1047-
register(MAX, (distinct, field, argList, ctx) -> ctx.relBuilder.max(field));
1048-
register(MIN, (distinct, field, argList, ctx) -> ctx.relBuilder.min(field));
1083+
register(MAX, (distinct, field, argList, ctx) -> ctx.relBuilder.max(field), null);
1084+
register(MIN, (distinct, field, argList, ctx) -> ctx.relBuilder.min(field), null);
10491085

1050-
register(AVG, (distinct, field, argList, ctx) -> ctx.relBuilder.avg(distinct, null, field));
1086+
register(
1087+
AVG, (distinct, field, argList, ctx) -> ctx.relBuilder.avg(distinct, null, field), null);
10511088

10521089
register(
10531090
COUNT,
10541091
(distinct, field, argList, ctx) ->
10551092
ctx.relBuilder.count(
1056-
distinct, null, field == null ? ImmutableList.of() : ImmutableList.of(field)));
1057-
register(SUM, (distinct, field, argList, ctx) -> ctx.relBuilder.sum(distinct, null, field));
1093+
distinct, null, field == null ? ImmutableList.of() : ImmutableList.of(field)),
1094+
null);
1095+
register(
1096+
SUM,
1097+
(distinct, field, argList, ctx) ->
1098+
ctx.relBuilder.aggregateCall(SqlStdOperatorTable.SUM, field),
1099+
null);
10581100

10591101
register(
10601102
VARSAMP,
1061-
(distinct, field, argList, ctx) ->
1062-
ctx.relBuilder.aggregateCall(VAR_SAMP_NULLABLE, field));
1103+
(distinct, field, argList, ctx) -> ctx.relBuilder.aggregateCall(VAR_SAMP_NULLABLE, field),
1104+
null);
10631105

10641106
register(
10651107
VARPOP,
1066-
(distinct, field, argList, ctx) -> ctx.relBuilder.aggregateCall(VAR_POP_NULLABLE, field));
1108+
(distinct, field, argList, ctx) -> ctx.relBuilder.aggregateCall(VAR_POP_NULLABLE, field),
1109+
null);
10671110

10681111
register(
10691112
STDDEV_SAMP,
10701113
(distinct, field, argList, ctx) ->
1071-
ctx.relBuilder.aggregateCall(STDDEV_SAMP_NULLABLE, field));
1114+
ctx.relBuilder.aggregateCall(STDDEV_SAMP_NULLABLE, field),
1115+
null);
10721116

10731117
register(
10741118
STDDEV_POP,
10751119
(distinct, field, argList, ctx) ->
1076-
ctx.relBuilder.aggregateCall(STDDEV_POP_NULLABLE, field));
1120+
ctx.relBuilder.aggregateCall(STDDEV_POP_NULLABLE, field),
1121+
null);
10771122

10781123
register(
10791124
TAKE,
@@ -1087,7 +1132,8 @@ void populate() {
10871132
List.of(field),
10881133
newArgList,
10891134
ctx.relBuilder);
1090-
});
1135+
},
1136+
null);
10911137

10921138
register(
10931139
PERCENTILE_APPROX,
@@ -1102,7 +1148,8 @@ void populate() {
11021148
List.of(field),
11031149
newArgList,
11041150
ctx.relBuilder);
1105-
});
1151+
},
1152+
null);
11061153

11071154
register(
11081155
INTERNAL_PATTERN,
@@ -1113,7 +1160,8 @@ void populate() {
11131160
ReturnTypes.explicit(UserDefinedFunctionUtils.nullablePatternAggList),
11141161
List.of(field),
11151162
argList,
1116-
ctx.relBuilder));
1163+
ctx.relBuilder),
1164+
null);
11171165
}
11181166
}
11191167
}

opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,6 @@
55

66
package org.opensearch.sql.opensearch.executor;
77

8-
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.convertRelDataTypeToExprType;
9-
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.TransferUserDefinedAggFunction;
10-
import static org.opensearch.sql.expression.function.BuiltinFunctionName.DISTINCT_COUNT_APPROX;
11-
128
import java.security.AccessController;
139
import java.security.PrivilegedAction;
1410
import java.sql.PreparedStatement;
@@ -29,12 +25,15 @@
2925
import org.apache.calcite.sql.SqlExplainLevel;
3026
import org.apache.calcite.sql.type.ReturnTypes;
3127
import org.apache.calcite.sql.type.SqlTypeName;
28+
import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
3229
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
3330
import org.apache.logging.log4j.LogManager;
3431
import org.apache.logging.log4j.Logger;
3532
import org.opensearch.sql.ast.statement.Explain.ExplainFormat;
3633
import org.opensearch.sql.calcite.CalcitePlanContext;
3734
import org.opensearch.sql.calcite.utils.CalciteToolsHelper.OpenSearchRelRunners;
35+
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
36+
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
3837
import org.opensearch.sql.common.response.ResponseListener;
3938
import org.opensearch.sql.data.model.ExprTupleValue;
4039
import org.opensearch.sql.data.model.ExprValue;
@@ -255,7 +254,7 @@ private void buildResultSet(
255254
exprType = ExprCoreType.UNDEFINED;
256255
}
257256
} else {
258-
exprType = convertRelDataTypeToExprType(fieldType);
257+
exprType = OpenSearchTypeFactory.convertRelDataTypeToExprType(fieldType);
259258
}
260259
columns.add(new Column(columnName, null, exprType));
261260
}
@@ -276,15 +275,12 @@ private void registerOpenSearchFunctions() {
276275
client.getClass().getName());
277276
}
278277

279-
PPLFuncImpTable.INSTANCE.registerExternalAggFunction(
280-
DISTINCT_COUNT_APPROX,
281-
(distinct, field, argList, ctx) ->
282-
TransferUserDefinedAggFunction(
283-
DistinctCountApproxAggFunction.class,
284-
"APPROX_DISTINCT_COUNT",
285-
ReturnTypes.BIGINT_FORCE_NULLABLE,
286-
List.of(field),
287-
argList,
288-
ctx.relBuilder));
278+
SqlUserDefinedAggFunction approxDistinctCountFunction =
279+
UserDefinedFunctionUtils.createUserDefinedAggFunction(
280+
DistinctCountApproxAggFunction.class,
281+
"APPROX_DISTINCT_COUNT",
282+
ReturnTypes.BIGINT_FORCE_NULLABLE);
283+
PPLFuncImpTable.INSTANCE.registerExternalAggOperator(
284+
BuiltinFunctionName.DISTINCT_COUNT_APPROX, approxDistinctCountFunction);
289285
}
290286
}

0 commit comments

Comments
 (0)