Skip to content

Commit 7b34ba0

Browse files
committed
Refactor: simplify registration of user-defined aggregation functions
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent be38740 commit 7b34ba0

6 files changed

Lines changed: 94 additions & 146 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@
7777
import org.apache.calcite.runtime.Hook;
7878
import org.apache.calcite.schema.SchemaPlus;
7979
import org.apache.calcite.server.CalciteServerStatement;
80-
import org.apache.calcite.sql.SqlAggFunction;
8180
import org.apache.calcite.sql.SqlKind;
8281
import org.apache.calcite.sql.parser.SqlParserPos;
8382
import org.apache.calcite.sql2rel.SqlRexConvertletTable;
@@ -89,7 +88,7 @@
8988
import org.apache.calcite.util.Util;
9089
import org.opensearch.sql.calcite.CalcitePlanContext;
9190
import org.opensearch.sql.calcite.plan.Scannable;
92-
import org.opensearch.sql.calcite.udf.udaf.NullableSqlAvgAggFunction;
91+
import org.opensearch.sql.expression.function.PPLBuiltinOperators;
9392

9493
/**
9594
* Calcite Tools Helper. This class is used to create customized: 1. Connection 2. JavaTypeFactory
@@ -185,7 +184,7 @@ public OpenSearchRelBuilder(Context context, RelOptCluster cluster, RelOptSchema
185184
public AggCall avg(boolean distinct, String alias, RexNode operand) {
186185
return aggregateCall(
187186
SqlParserPos.ZERO,
188-
AVG_NULLABLE,
187+
PPLBuiltinOperators.AVG_NULLABLE,
189188
distinct,
190189
false,
191190
false,
@@ -198,16 +197,6 @@ public AggCall avg(boolean distinct, String alias, RexNode operand) {
198197
}
199198
}
200199

201-
public static final SqlAggFunction AVG_NULLABLE = new NullableSqlAvgAggFunction(SqlKind.AVG);
202-
public static final SqlAggFunction STDDEV_POP_NULLABLE =
203-
new NullableSqlAvgAggFunction(SqlKind.STDDEV_POP);
204-
public static final SqlAggFunction STDDEV_SAMP_NULLABLE =
205-
new NullableSqlAvgAggFunction(SqlKind.STDDEV_SAMP);
206-
public static final SqlAggFunction VAR_POP_NULLABLE =
207-
new NullableSqlAvgAggFunction(SqlKind.VAR_POP);
208-
public static final SqlAggFunction VAR_SAMP_NULLABLE =
209-
new NullableSqlAvgAggFunction(SqlKind.VAR_SAMP);
210-
211200
public static class OpenSearchPrepareImpl extends CalcitePrepareImpl {
212201
/**
213202
* Similar to {@link CalcitePrepareImpl#perform(CalciteServerStatement, FrameworkConfig,

core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ private PPLOperandTypes() {}
3939
OperandTypes.NUMERIC.or(
4040
OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER)));
4141

42+
public static final UDFOperandMetadata ANY_OPTIONAL_INTEGER =
43+
UDFOperandMetadata.wrap(
44+
(CompositeOperandTypeChecker)
45+
OperandTypes.ANY.or(OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER)));
4246
public static final UDFOperandMetadata INTEGER_INTEGER =
4347
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.INTEGER_INTEGER);
4448
public static final UDFOperandMetadata STRING_STRING =
@@ -48,6 +52,12 @@ private PPLOperandTypes() {}
4852
public static final UDFOperandMetadata STRING_INTEGER =
4953
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER));
5054

55+
public static final UDFOperandMetadata NUMERIC_NUMERIC_OPTIONAL_NUMERIC =
56+
UDFOperandMetadata.wrap(
57+
(CompositeOperandTypeChecker)
58+
OperandTypes.NUMERIC_NUMERIC.or(
59+
OperandTypes.family(
60+
SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)));
5161
public static final UDFOperandMetadata NUMERIC_NUMERIC_NUMERIC =
5262
UDFOperandMetadata.wrap(
5363
OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC));

core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,14 @@ public class UserDefinedFunctionUtils {
9191
public static SqlUserDefinedAggFunction createUserDefinedAggFunction(
9292
Class<? extends UserDefinedAggFunction<?>> udafClass,
9393
String functionName,
94-
SqlReturnTypeInference returnType) {
94+
SqlReturnTypeInference returnType,
95+
@Nullable UDFOperandMetadata operandMetadata) {
9596
return new SqlUserDefinedAggFunction(
9697
new SqlIdentifier(functionName, SqlParserPos.ZERO),
9798
SqlKind.OTHER_FUNCTION,
9899
returnType,
99100
null,
100-
null,
101+
operandMetadata,
101102
AggregateFunctionImpl.create(udafClass),
102103
false,
103104
false,
@@ -123,31 +124,7 @@ public static RelBuilder.AggCall makeAggregateCall(
123124
return relBuilder.aggregateCall(aggFunction, addArgList);
124125
}
125126

126-
/**
127-
* Creates and registers a User Defined Aggregate Function (UDAF) and returns an AggCall that can
128-
* be used in query plans.
129-
*
130-
* @param udafClass The class implementing the aggregate function behavior
131-
* @param functionName The name of the aggregate function
132-
* @param returnType The return type inference for determining the result type
133-
* @param fields The primary fields to aggregate
134-
* @param argList Additional arguments for the aggregate function
135-
* @param relBuilder The RelBuilder instance used for building relational expressions
136-
* @return An AggCall object representing the aggregate function call
137-
*/
138-
public static RelBuilder.AggCall createAggregateFunction(
139-
Class<? extends UserDefinedAggFunction<?>> udafClass,
140-
String functionName,
141-
SqlReturnTypeInference returnType,
142-
List<RexNode> fields,
143-
List<RexNode> argList,
144-
RelBuilder relBuilder) {
145-
SqlUserDefinedAggFunction udaf =
146-
createUserDefinedAggFunction(udafClass, functionName, returnType);
147-
return makeAggregateCall(udaf, fields, argList, relBuilder);
148-
}
149-
150-
public static SqlReturnTypeInference getReturnTypeInferenceForArray() {
127+
public static SqlReturnTypeInference createReturnTypeInferenceForArray() {
151128
return opBinding -> {
152129
RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
153130

core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.adaptExprMethodToUDF;
99
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.adaptExprMethodWithPropertiesToUDF;
1010
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.adaptMathFunctionToUDF;
11+
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.createUserDefinedAggFunction;
1112

1213
import com.google.common.base.Suppliers;
1314
import java.lang.reflect.InvocationTargetException;
@@ -21,11 +22,17 @@
2122
import org.apache.calcite.avatica.util.TimeUnit;
2223
import org.apache.calcite.linq4j.tree.Expression;
2324
import org.apache.calcite.rex.RexCall;
25+
import org.apache.calcite.sql.SqlAggFunction;
26+
import org.apache.calcite.sql.SqlKind;
2427
import org.apache.calcite.sql.SqlOperator;
2528
import org.apache.calcite.sql.type.ReturnTypes;
2629
import org.apache.calcite.sql.type.SqlTypeTransforms;
2730
import org.apache.calcite.sql.util.ReflectiveSqlOperatorTable;
2831
import org.apache.calcite.util.BuiltInMethod;
32+
import org.opensearch.sql.calcite.udf.udaf.LogPatternAggFunction;
33+
import org.opensearch.sql.calcite.udf.udaf.NullableSqlAvgAggFunction;
34+
import org.opensearch.sql.calcite.udf.udaf.PercentileApproxFunction;
35+
import org.opensearch.sql.calcite.udf.udaf.TakeAggFunction;
2936
import org.opensearch.sql.calcite.utils.PPLOperandTypes;
3037
import org.opensearch.sql.calcite.utils.PPLReturnTypes;
3138
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
@@ -381,6 +388,35 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
381388
public static final SqlOperator NUMBER_TO_STRING =
382389
new NumberToStringFunction().toUDF("NUMBER_TO_STRING");
383390

391+
// Aggregation functions
392+
public static final SqlAggFunction AVG_NULLABLE = new NullableSqlAvgAggFunction(SqlKind.AVG);
393+
public static final SqlAggFunction STDDEV_POP_NULLABLE =
394+
new NullableSqlAvgAggFunction(SqlKind.STDDEV_POP);
395+
public static final SqlAggFunction STDDEV_SAMP_NULLABLE =
396+
new NullableSqlAvgAggFunction(SqlKind.STDDEV_SAMP);
397+
public static final SqlAggFunction VAR_POP_NULLABLE =
398+
new NullableSqlAvgAggFunction(SqlKind.VAR_POP);
399+
public static final SqlAggFunction VAR_SAMP_NULLABLE =
400+
new NullableSqlAvgAggFunction(SqlKind.VAR_SAMP);
401+
public static final SqlAggFunction TAKE =
402+
createUserDefinedAggFunction(
403+
TakeAggFunction.class,
404+
"TAKE",
405+
UserDefinedFunctionUtils.createReturnTypeInferenceForArray(),
406+
PPLOperandTypes.ANY_OPTIONAL_INTEGER);
407+
public static final SqlAggFunction PERCENTILE_APPROX =
408+
createUserDefinedAggFunction(
409+
PercentileApproxFunction.class,
410+
"percentile_approx",
411+
ReturnTypes.ARG0_FORCE_NULLABLE,
412+
PPLOperandTypes.NUMERIC_NUMERIC_OPTIONAL_NUMERIC);
413+
public static final SqlAggFunction INTERNAL_PATTERN =
414+
createUserDefinedAggFunction(
415+
LogPatternAggFunction.class,
416+
"pattern",
417+
ReturnTypes.explicit(UserDefinedFunctionUtils.nullablePatternAggList),
418+
null);
419+
384420
/**
385421
* Returns the PPL specific operator table, creating it if necessary.
386422
*

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 40 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,8 @@
66
package org.opensearch.sql.expression.function;
77

88
import static org.apache.calcite.sql.SqlJsonConstructorNullClause.NULL_ON_NULL;
9-
import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.STDDEV_POP_NULLABLE;
10-
import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.STDDEV_SAMP_NULLABLE;
11-
import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.VAR_POP_NULLABLE;
12-
import static org.opensearch.sql.calcite.utils.CalciteToolsHelper.VAR_SAMP_NULLABLE;
139
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;
1410
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.getLegacyTypeName;
15-
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.createAggregateFunction;
1611
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ABS;
1712
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ACOS;
1813
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD;
@@ -246,7 +241,6 @@
246241
import org.apache.calcite.sql.type.CompositeOperandTypeChecker;
247242
import org.apache.calcite.sql.type.ImplicitCastOperandTypeChecker;
248243
import org.apache.calcite.sql.type.OperandTypes;
249-
import org.apache.calcite.sql.type.ReturnTypes;
250244
import org.apache.calcite.sql.type.SameOperandTypeChecker;
251245
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
252246
import org.apache.calcite.sql.type.SqlTypeFamily;
@@ -258,9 +252,6 @@
258252
import org.apache.logging.log4j.LogManager;
259253
import org.apache.logging.log4j.Logger;
260254
import org.opensearch.sql.calcite.CalcitePlanContext;
261-
import org.opensearch.sql.calcite.udf.udaf.LogPatternAggFunction;
262-
import org.opensearch.sql.calcite.udf.udaf.PercentileApproxFunction;
263-
import org.opensearch.sql.calcite.udf.udaf.TakeAggFunction;
264255
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
265256
import org.opensearch.sql.calcite.utils.PlanUtils;
266257
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
@@ -631,6 +622,30 @@ private static PPLTypeChecker wrapSqlOperandTypeChecker(
631622
return pplTypeChecker;
632623
}
633624

625+
/**
626+
* Extracts the underlying {@link SqlOperandTypeChecker} from a {@link SqlOperator}.
627+
*
628+
* <p>For user-defined functions (UDFs) and user-defined aggregate functions (UDAFs), the {@link
629+
* SqlOperandTypeChecker} is typically wrapped in a {@link UDFOperandMetadata}, which contains the
630+
* actual type checker used for operand validation. Most of these wrapped type checkers are
631+
* defined in {@link org.opensearch.sql.calcite.utils.PPLOperandTypes}. This method retrieves the
632+
* inner type checker from {@link UDFOperandMetadata} if present.
633+
*
634+
* <p>For Calcite's built-in operators, the operator's own type checker is returned directly.
635+
*
636+
* @param operator the {@link SqlOperator}, which may be a Calcite built-in operator, a
637+
* user-defined function, or a user-defined aggregation function
638+
* @return the underlying {@link SqlOperandTypeChecker} instance, or {@code null} if not available
639+
*/
640+
private static SqlOperandTypeChecker extractTypeCheckerFromUDF(SqlOperator operator) {
641+
SqlOperandTypeChecker typeChecker = operator.getOperandTypeChecker();
642+
if (typeChecker instanceof UDFOperandMetadata) {
643+
UDFOperandMetadata udfOperandMetadata = (UDFOperandMetadata) typeChecker;
644+
return udfOperandMetadata.getInnerTypeChecker();
645+
}
646+
return typeChecker;
647+
}
648+
634649
@SuppressWarnings({"UnusedReturnValue", "SameParameterValue"})
635650
private abstract static class AbstractBuilder {
636651

@@ -652,13 +667,7 @@ abstract void register(
652667
*/
653668
public void registerOperator(BuiltinFunctionName functionName, SqlOperator... operators) {
654669
for (SqlOperator operator : operators) {
655-
SqlOperandTypeChecker typeChecker;
656-
if (operator instanceof SqlUserDefinedFunction udfOperator) {
657-
typeChecker = extractTypeCheckerFromUDF(udfOperator);
658-
} else {
659-
typeChecker = operator.getOperandTypeChecker();
660-
}
661-
670+
SqlOperandTypeChecker typeChecker = extractTypeCheckerFromUDF(operator);
662671
PPLTypeChecker pplTypeChecker =
663672
wrapSqlOperandTypeChecker(
664673
typeChecker, operator.getName(), operator instanceof SqlUserDefinedFunction);
@@ -669,13 +678,6 @@ public void registerOperator(BuiltinFunctionName functionName, SqlOperator... op
669678
}
670679
}
671680

672-
private static SqlOperandTypeChecker extractTypeCheckerFromUDF(
673-
SqlUserDefinedFunction udfOperator) {
674-
UDFOperandMetadata udfOperandMetadata =
675-
(UDFOperandMetadata) udfOperator.getOperandTypeChecker();
676-
return (udfOperandMetadata == null) ? null : udfOperandMetadata.getInnerTypeChecker();
677-
}
678-
679681
void populate() {
680682
// register operators for comparison
681683
registerOperator(NOTEQUAL, PPLBuiltinOperators.NOT_EQUALS_IP, SqlStdOperatorTable.NOT_EQUALS);
@@ -1094,19 +1096,30 @@ void register(
10941096
}
10951097

10961098
void registerOperator(BuiltinFunctionName functionName, SqlAggFunction aggFunction) {
1099+
SqlOperandTypeChecker innerTypeChecker = extractTypeCheckerFromUDF(aggFunction);
10971100
PPLTypeChecker typeChecker =
1098-
wrapSqlOperandTypeChecker(aggFunction.getOperandTypeChecker(), functionName.name(), true);
1101+
wrapSqlOperandTypeChecker(innerTypeChecker, functionName.name(), true);
10991102
AggHandler handler =
1100-
(distinct, field, argList, ctx) ->
1101-
UserDefinedFunctionUtils.makeAggregateCall(
1102-
aggFunction, List.of(field), argList, ctx.relBuilder);
1103+
(distinct, field, argList, ctx) -> {
1104+
List<RexNode> newArgList =
1105+
argList.stream().map(PlanUtils::derefMapCall).collect(Collectors.toList());
1106+
return UserDefinedFunctionUtils.makeAggregateCall(
1107+
aggFunction, List.of(field), newArgList, ctx.relBuilder);
1108+
};
11031109
register(functionName, handler, typeChecker);
11041110
}
11051111

11061112
void populate() {
11071113
registerOperator(MAX, SqlStdOperatorTable.MAX);
11081114
registerOperator(MIN, SqlStdOperatorTable.MIN);
11091115
registerOperator(SUM, SqlStdOperatorTable.SUM);
1116+
registerOperator(VARSAMP, PPLBuiltinOperators.VAR_SAMP_NULLABLE);
1117+
registerOperator(VARPOP, PPLBuiltinOperators.VAR_POP_NULLABLE);
1118+
registerOperator(STDDEV_SAMP, PPLBuiltinOperators.STDDEV_SAMP_NULLABLE);
1119+
registerOperator(STDDEV_POP, PPLBuiltinOperators.STDDEV_POP_NULLABLE);
1120+
registerOperator(TAKE, PPLBuiltinOperators.TAKE);
1121+
registerOperator(PERCENTILE_APPROX, PPLBuiltinOperators.PERCENTILE_APPROX);
1122+
registerOperator(INTERNAL_PATTERN, PPLBuiltinOperators.INTERNAL_PATTERN);
11101123

11111124
register(
11121125
AVG,
@@ -1121,84 +1134,6 @@ void populate() {
11211134
distinct, null, field == null ? ImmutableList.of() : ImmutableList.of(field)),
11221135
wrapSqlOperandTypeChecker(
11231136
SqlStdOperatorTable.COUNT.getOperandTypeChecker(), COUNT.name(), false));
1124-
1125-
register(
1126-
VARSAMP,
1127-
(distinct, field, argList, ctx) -> ctx.relBuilder.aggregateCall(VAR_SAMP_NULLABLE, field),
1128-
wrapSqlOperandTypeChecker(
1129-
SqlStdOperatorTable.VAR_SAMP.getOperandTypeChecker(), VARSAMP.name(), false));
1130-
1131-
register(
1132-
VARPOP,
1133-
(distinct, field, argList, ctx) -> ctx.relBuilder.aggregateCall(VAR_POP_NULLABLE, field),
1134-
wrapSqlOperandTypeChecker(
1135-
SqlStdOperatorTable.VAR_POP.getOperandTypeChecker(), VARPOP.name(), false));
1136-
1137-
register(
1138-
STDDEV_SAMP,
1139-
(distinct, field, argList, ctx) ->
1140-
ctx.relBuilder.aggregateCall(STDDEV_SAMP_NULLABLE, field),
1141-
wrapSqlOperandTypeChecker(
1142-
SqlStdOperatorTable.STDDEV_SAMP.getOperandTypeChecker(), STDDEV_SAMP.name(), false));
1143-
1144-
register(
1145-
STDDEV_POP,
1146-
(distinct, field, argList, ctx) ->
1147-
ctx.relBuilder.aggregateCall(STDDEV_POP_NULLABLE, field),
1148-
wrapSqlOperandTypeChecker(
1149-
SqlStdOperatorTable.STDDEV_POP.getOperandTypeChecker(), STDDEV_POP.name(), false));
1150-
1151-
register(
1152-
TAKE,
1153-
(distinct, field, argList, ctx) -> {
1154-
List<RexNode> newArgList =
1155-
argList.stream().map(PlanUtils::derefMapCall).collect(Collectors.toList());
1156-
return createAggregateFunction(
1157-
TakeAggFunction.class,
1158-
"TAKE",
1159-
UserDefinedFunctionUtils.getReturnTypeInferenceForArray(),
1160-
List.of(field),
1161-
newArgList,
1162-
ctx.relBuilder);
1163-
},
1164-
PPLTypeChecker.wrapComposite(
1165-
(CompositeOperandTypeChecker)
1166-
OperandTypes.ANY.or(
1167-
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER)),
1168-
false));
1169-
1170-
register(
1171-
PERCENTILE_APPROX,
1172-
(distinct, field, argList, ctx) -> {
1173-
List<RexNode> newArgList =
1174-
argList.stream().map(PlanUtils::derefMapCall).collect(Collectors.toList());
1175-
newArgList.add(ctx.rexBuilder.makeFlag(field.getType().getSqlTypeName()));
1176-
return createAggregateFunction(
1177-
PercentileApproxFunction.class,
1178-
"percentile_approx",
1179-
ReturnTypes.ARG0_FORCE_NULLABLE,
1180-
List.of(field),
1181-
newArgList,
1182-
ctx.relBuilder);
1183-
},
1184-
PPLTypeChecker.wrapComposite(
1185-
(CompositeOperandTypeChecker)
1186-
OperandTypes.NUMERIC_NUMERIC.or(
1187-
OperandTypes.family(
1188-
SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)),
1189-
false));
1190-
1191-
register(
1192-
INTERNAL_PATTERN,
1193-
(distinct, field, argList, ctx) ->
1194-
createAggregateFunction(
1195-
LogPatternAggFunction.class,
1196-
"pattern",
1197-
ReturnTypes.explicit(UserDefinedFunctionUtils.nullablePatternAggList),
1198-
List.of(field),
1199-
argList,
1200-
ctx.relBuilder),
1201-
null);
12021137
}
12031138
}
12041139
}

0 commit comments

Comments
 (0)