Skip to content

Commit 7f26f9d

Browse files
committed
Add type checkers for aggregation functions
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 4bbb0f2 commit 7f26f9d

3 files changed

Lines changed: 89 additions & 32 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/udf/udaf/TakeAggFunction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public Object result(TakeAccumulator accumulator) {
2424
@Override
2525
public TakeAccumulator add(TakeAccumulator acc, Object... values) {
2626
Object candidateValue = values[0];
27-
int size = 0;
27+
int size;
2828
if (values.length > 1) {
2929
size = (int) values[1];
3030
} else {

core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.calcite.rex.RexCall;
3131
import org.apache.calcite.rex.RexNode;
3232
import org.apache.calcite.schema.impl.AggregateFunctionImpl;
33+
import org.apache.calcite.sql.SqlAggFunction;
3334
import org.apache.calcite.sql.SqlIdentifier;
3435
import org.apache.calcite.sql.SqlKind;
3536
import org.apache.calcite.sql.parser.SqlParserPos;
@@ -77,6 +78,14 @@ public class UserDefinedFunctionUtils {
7778
public static Set<String> MULTI_FIELDS_RELEVANCE_FUNCTION_SET =
7879
ImmutableSet.of("simple_query_string", "query_string", "multi_match");
7980

81+
/**
82+
* Creates a SqlUserDefinedAggFunction that wraps a Java class implementing an aggregate function.
83+
*
84+
* @param udafClass The Java class that implements the UserDefinedAggFunction interface
85+
* @param functionName The name of the function to be used in SQL statements
86+
* @param returnType A SqlReturnTypeInference that determines the return type of the function
87+
* @return A SqlUserDefinedAggFunction that can be used in SQL queries
88+
*/
8089
public static SqlUserDefinedAggFunction createUserDefinedAggFunction(
8190
Class<? extends UserDefinedAggFunction<?>> udafClass,
8291
String functionName,
@@ -93,17 +102,38 @@ public static SqlUserDefinedAggFunction createUserDefinedAggFunction(
93102
Optionality.FORBIDDEN);
94103
}
95104

96-
public static RelBuilder.AggCall convertUDAFToAggCall(
97-
SqlUserDefinedAggFunction udaf,
105+
/**
106+
* Creates an aggregate call using the provided SqlAggFunction and arguments.
107+
*
108+
* @param aggFunction The aggregate function to call
109+
* @param fields The primary fields to aggregate
110+
* @param argList Additional arguments for the aggregate function
111+
* @param relBuilder The RelBuilder instance used for building relational expressions
112+
* @return An AggCall object representing the aggregate function call
113+
*/
114+
public static RelBuilder.AggCall makeAggregateCall(
115+
SqlAggFunction aggFunction,
98116
List<RexNode> fields,
99117
List<RexNode> argList,
100118
RelBuilder relBuilder) {
101119
List<RexNode> addArgList = new ArrayList<>(fields);
102120
addArgList.addAll(argList);
103-
return relBuilder.aggregateCall(udaf, addArgList);
121+
return relBuilder.aggregateCall(aggFunction, addArgList);
104122
}
105123

106-
public static RelBuilder.AggCall TransferUserDefinedAggFunction(
124+
/**
125+
* Creates and registers a User Defined Aggregate Function (UDAF) and returns an AggCall that can
126+
* be used in query plans.
127+
*
128+
* @param udafClass The class implementing the aggregate function behavior
129+
* @param functionName The name of the aggregate function
130+
* @param returnType The return type inference for determining the result type
131+
* @param fields The primary fields to aggregate
132+
* @param argList Additional arguments for the aggregate function
133+
* @param relBuilder The RelBuilder instance used for building relational expressions
134+
* @return An AggCall object representing the aggregate function call
135+
*/
136+
public static RelBuilder.AggCall createAggregateFunction(
107137
Class<? extends UserDefinedAggFunction<?>> udafClass,
108138
String functionName,
109139
SqlReturnTypeInference returnType,
@@ -112,7 +142,7 @@ public static RelBuilder.AggCall TransferUserDefinedAggFunction(
112142
RelBuilder relBuilder) {
113143
SqlUserDefinedAggFunction udaf =
114144
createUserDefinedAggFunction(udafClass, functionName, returnType);
115-
return convertUDAFToAggCall(udaf, fields, argList, relBuilder);
145+
return makeAggregateCall(udaf, fields, argList, relBuilder);
116146
}
117147

118148
public static SqlReturnTypeInference getReturnTypeInferenceForArray() {

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.VAR_SAMP_NULLABLE;
1313
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;
1414
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.getLegacyTypeName;
15-
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.TransferUserDefinedAggFunction;
15+
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.createAggregateFunction;
1616
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ABS;
1717
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ACOS;
1818
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD;
@@ -238,6 +238,7 @@
238238
import org.apache.calcite.rex.RexBuilder;
239239
import org.apache.calcite.rex.RexLambda;
240240
import org.apache.calcite.rex.RexNode;
241+
import org.apache.calcite.sql.SqlAggFunction;
241242
import org.apache.calcite.sql.SqlOperator;
242243
import org.apache.calcite.sql.fun.SqlLibraryOperators;
243244
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
@@ -400,7 +401,7 @@ public void registerExternalAggOperator(
400401
CalciteFuncSignature signature = new CalciteFuncSignature(functionName.getName(), typeChecker);
401402
AggHandler handler =
402403
(distinct, field, argList, ctx) ->
403-
UserDefinedFunctionUtils.convertUDAFToAggCall(
404+
UserDefinedFunctionUtils.makeAggregateCall(
404405
aggFunction, List.of(field), argList, ctx.relBuilder);
405406
aggExternalFunctionRegistry.put(functionName, Pair.of(signature, handler));
406407
}
@@ -419,14 +420,27 @@ public RelBuilder.AggCall resolveAgg(
419420
throw new IllegalStateException(String.format("Cannot resolve function: %s", functionName));
420421
}
421422
CalciteFuncSignature signature = implementation.getKey();
422-
RelDataType fieldType = field.getType();
423-
if (!signature.match(functionName.getName(), List.of(fieldType))) {
423+
List<RelDataType> argTypes = new ArrayList<>();
424+
if (field != null) {
425+
argTypes.add(field.getType());
426+
}
427+
// Currently only PERCENTILE_APPROX and TAKE have additional arguments.
428+
// Their additional arguments will always come as a map of <argName, value>
429+
List<RelDataType> additionalArgTypes =
430+
argList.stream().map(PlanUtils::derefMapCall).map(RexNode::getType).toList();
431+
argTypes.addAll(additionalArgTypes);
432+
if (!signature.match(functionName.getName(), argTypes)) {
433+
String errorMessagePattern =
434+
argTypes.size() <= 1
435+
? "Aggregation function %s expects field type {%s}, but got %s"
436+
: "Aggregation function %s expects field type and additional arguments {%s}, but got"
437+
+ " %s";
424438
throw new ExpressionEvaluationException(
425439
String.format(
426-
"Aggregation function %s expects field type {%s}, but got %s",
440+
errorMessagePattern,
427441
functionName,
428442
signature.typeChecker().getAllowedSignatures(),
429-
getActualSignature(List.of(fieldType))));
443+
getActualSignature(argTypes)));
430444
}
431445
var handler = implementation.getValue();
432446
return handler.apply(distinct, field, argList, context);
@@ -1069,92 +1083,105 @@ void register(
10691083
map.put(functionName, Pair.of(signature, aggHandler));
10701084
}
10711085

1072-
void registerOperator(BuiltinFunctionName functionName, SqlUserDefinedAggFunction aggFunction) {
1086+
void registerOperator(BuiltinFunctionName functionName, SqlAggFunction aggFunction) {
10731087
PPLTypeChecker typeChecker =
10741088
wrapSqlOperandTypeChecker(aggFunction.getOperandTypeChecker(), functionName.name(), true);
10751089
AggHandler handler =
10761090
(distinct, field, argList, ctx) ->
1077-
UserDefinedFunctionUtils.convertUDAFToAggCall(
1091+
UserDefinedFunctionUtils.makeAggregateCall(
10781092
aggFunction, List.of(field), argList, ctx.relBuilder);
10791093
register(functionName, handler, typeChecker);
10801094
}
10811095

10821096
void populate() {
1083-
register(MAX, (distinct, field, argList, ctx) -> ctx.relBuilder.max(field), null);
1084-
register(MIN, (distinct, field, argList, ctx) -> ctx.relBuilder.min(field), null);
1097+
registerOperator(MAX, SqlStdOperatorTable.MAX);
1098+
registerOperator(MIN, SqlStdOperatorTable.MIN);
1099+
registerOperator(SUM, SqlStdOperatorTable.SUM);
10851100

10861101
register(
1087-
AVG, (distinct, field, argList, ctx) -> ctx.relBuilder.avg(distinct, null, field), null);
1102+
AVG,
1103+
(distinct, field, argList, ctx) -> ctx.relBuilder.avg(distinct, null, field),
1104+
wrapSqlOperandTypeChecker(
1105+
SqlStdOperatorTable.AVG.getOperandTypeChecker(), AVG.name(), false));
10881106

10891107
register(
10901108
COUNT,
10911109
(distinct, field, argList, ctx) ->
10921110
ctx.relBuilder.count(
10931111
distinct, null, field == null ? ImmutableList.of() : ImmutableList.of(field)),
1094-
null);
1095-
register(
1096-
SUM,
1097-
(distinct, field, argList, ctx) ->
1098-
ctx.relBuilder.aggregateCall(SqlStdOperatorTable.SUM, field),
1099-
null);
1112+
wrapSqlOperandTypeChecker(
1113+
SqlStdOperatorTable.COUNT.getOperandTypeChecker(), COUNT.name(), false));
11001114

11011115
register(
11021116
VARSAMP,
11031117
(distinct, field, argList, ctx) -> ctx.relBuilder.aggregateCall(VAR_SAMP_NULLABLE, field),
1104-
null);
1118+
wrapSqlOperandTypeChecker(
1119+
SqlStdOperatorTable.VAR_SAMP.getOperandTypeChecker(), VARSAMP.name(), false));
11051120

11061121
register(
11071122
VARPOP,
11081123
(distinct, field, argList, ctx) -> ctx.relBuilder.aggregateCall(VAR_POP_NULLABLE, field),
1109-
null);
1124+
wrapSqlOperandTypeChecker(
1125+
SqlStdOperatorTable.VAR_POP.getOperandTypeChecker(), VARPOP.name(), false));
11101126

11111127
register(
11121128
STDDEV_SAMP,
11131129
(distinct, field, argList, ctx) ->
11141130
ctx.relBuilder.aggregateCall(STDDEV_SAMP_NULLABLE, field),
1115-
null);
1131+
wrapSqlOperandTypeChecker(
1132+
SqlStdOperatorTable.STDDEV_SAMP.getOperandTypeChecker(), STDDEV_SAMP.name(), false));
11161133

11171134
register(
11181135
STDDEV_POP,
11191136
(distinct, field, argList, ctx) ->
11201137
ctx.relBuilder.aggregateCall(STDDEV_POP_NULLABLE, field),
1121-
null);
1138+
wrapSqlOperandTypeChecker(
1139+
SqlStdOperatorTable.STDDEV_POP.getOperandTypeChecker(), STDDEV_POP.name(), false));
11221140

11231141
register(
11241142
TAKE,
11251143
(distinct, field, argList, ctx) -> {
11261144
List<RexNode> newArgList =
11271145
argList.stream().map(PlanUtils::derefMapCall).collect(Collectors.toList());
1128-
return TransferUserDefinedAggFunction(
1146+
return createAggregateFunction(
11291147
TakeAggFunction.class,
11301148
"TAKE",
11311149
UserDefinedFunctionUtils.getReturnTypeInferenceForArray(),
11321150
List.of(field),
11331151
newArgList,
11341152
ctx.relBuilder);
11351153
},
1136-
null);
1154+
PPLTypeChecker.wrapComposite(
1155+
(CompositeOperandTypeChecker)
1156+
OperandTypes.ANY.or(
1157+
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER)),
1158+
false));
11371159

11381160
register(
11391161
PERCENTILE_APPROX,
11401162
(distinct, field, argList, ctx) -> {
11411163
List<RexNode> newArgList =
11421164
argList.stream().map(PlanUtils::derefMapCall).collect(Collectors.toList());
11431165
newArgList.add(ctx.rexBuilder.makeFlag(field.getType().getSqlTypeName()));
1144-
return TransferUserDefinedAggFunction(
1166+
return createAggregateFunction(
11451167
PercentileApproxFunction.class,
11461168
"percentile_approx",
11471169
ReturnTypes.ARG0_FORCE_NULLABLE,
11481170
List.of(field),
11491171
newArgList,
11501172
ctx.relBuilder);
11511173
},
1152-
null);
1174+
PPLTypeChecker.wrapComposite(
1175+
(CompositeOperandTypeChecker)
1176+
OperandTypes.NUMERIC_NUMERIC.or(
1177+
OperandTypes.family(
1178+
SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)),
1179+
false));
11531180

11541181
register(
11551182
INTERNAL_PATTERN,
11561183
(distinct, field, argList, ctx) ->
1157-
TransferUserDefinedAggFunction(
1184+
createAggregateFunction(
11581185
LogPatternAggFunction.class,
11591186
"pattern",
11601187
ReturnTypes.explicit(UserDefinedFunctionUtils.nullablePatternAggList),

0 commit comments

Comments
 (0)