Skip to content

Commit f820c4f

Browse files
committed
Refactor: simplify registration of user-defined aggregation functions
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent be38740 commit f820c4f

7 files changed

Lines changed: 120 additions & 161 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@
7777
import org.apache.calcite.runtime.Hook;
7878
import org.apache.calcite.schema.SchemaPlus;
7979
import org.apache.calcite.server.CalciteServerStatement;
80-
import org.apache.calcite.sql.SqlAggFunction;
8180
import org.apache.calcite.sql.SqlKind;
8281
import org.apache.calcite.sql.parser.SqlParserPos;
8382
import org.apache.calcite.sql2rel.SqlRexConvertletTable;
@@ -89,7 +88,7 @@
8988
import org.apache.calcite.util.Util;
9089
import org.opensearch.sql.calcite.CalcitePlanContext;
9190
import org.opensearch.sql.calcite.plan.Scannable;
92-
import org.opensearch.sql.calcite.udf.udaf.NullableSqlAvgAggFunction;
91+
import org.opensearch.sql.expression.function.PPLBuiltinOperators;
9392

9493
/**
9594
* Calcite Tools Helper. This class is used to create customized: 1. Connection 2. JavaTypeFactory
@@ -185,7 +184,7 @@ public OpenSearchRelBuilder(Context context, RelOptCluster cluster, RelOptSchema
185184
public AggCall avg(boolean distinct, String alias, RexNode operand) {
186185
return aggregateCall(
187186
SqlParserPos.ZERO,
188-
AVG_NULLABLE,
187+
PPLBuiltinOperators.AVG_NULLABLE,
189188
distinct,
190189
false,
191190
false,
@@ -198,16 +197,6 @@ public AggCall avg(boolean distinct, String alias, RexNode operand) {
198197
}
199198
}
200199

201-
public static final SqlAggFunction AVG_NULLABLE = new NullableSqlAvgAggFunction(SqlKind.AVG);
202-
public static final SqlAggFunction STDDEV_POP_NULLABLE =
203-
new NullableSqlAvgAggFunction(SqlKind.STDDEV_POP);
204-
public static final SqlAggFunction STDDEV_SAMP_NULLABLE =
205-
new NullableSqlAvgAggFunction(SqlKind.STDDEV_SAMP);
206-
public static final SqlAggFunction VAR_POP_NULLABLE =
207-
new NullableSqlAvgAggFunction(SqlKind.VAR_POP);
208-
public static final SqlAggFunction VAR_SAMP_NULLABLE =
209-
new NullableSqlAvgAggFunction(SqlKind.VAR_SAMP);
210-
211200
public static class OpenSearchPrepareImpl extends CalcitePrepareImpl {
212201
/**
213202
* Similar to {@link CalcitePrepareImpl#perform(CalciteServerStatement, FrameworkConfig,

core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ private PPLOperandTypes() {}
3939
OperandTypes.NUMERIC.or(
4040
OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER)));
4141

42+
public static final UDFOperandMetadata ANY_OPTIONAL_INTEGER =
43+
UDFOperandMetadata.wrap(
44+
(CompositeOperandTypeChecker)
45+
OperandTypes.ANY.or(OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER)));
4246
public static final UDFOperandMetadata INTEGER_INTEGER =
4347
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.INTEGER_INTEGER);
4448
public static final UDFOperandMetadata STRING_STRING =
@@ -48,6 +52,12 @@ private PPLOperandTypes() {}
4852
public static final UDFOperandMetadata STRING_INTEGER =
4953
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER));
5054

55+
public static final UDFOperandMetadata NUMERIC_NUMERIC_OPTIONAL_NUMERIC =
56+
UDFOperandMetadata.wrap(
57+
(CompositeOperandTypeChecker)
58+
OperandTypes.NUMERIC_NUMERIC.or(
59+
OperandTypes.family(
60+
SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)));
5161
public static final UDFOperandMetadata NUMERIC_NUMERIC_NUMERIC =
5262
UDFOperandMetadata.wrap(
5363
OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC));

core/src/main/java/org/opensearch/sql/calcite/utils/PPLReturnTypes.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55

66
package org.opensearch.sql.calcite.utils;
77

8+
import java.util.List;
89
import org.apache.calcite.rel.type.RelDataType;
10+
import org.apache.calcite.rel.type.RelDataTypeFactory;
911
import org.apache.calcite.sql.type.ReturnTypes;
1012
import org.apache.calcite.sql.type.SqlReturnTypeInference;
1113
import org.apache.calcite.sql.type.SqlTypeTransforms;
14+
import org.apache.calcite.sql.type.SqlTypeUtil;
1215
import org.opensearch.sql.data.type.ExprCoreType;
1316

1417
/**
@@ -39,4 +42,17 @@ private PPLReturnTypes() {}
3942
}
4043
return UserDefinedFunctionUtils.NULLABLE_TIMESTAMP_UDT;
4144
};
45+
public static SqlReturnTypeInference ARG0_ARRAY =
46+
opBinding -> {
47+
RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
48+
49+
// Get argument types
50+
List<RelDataType> argTypes = opBinding.collectOperandTypes();
51+
52+
if (argTypes.isEmpty()) {
53+
throw new IllegalArgumentException("Function requires at least one argument.");
54+
}
55+
RelDataType firstArgType = argTypes.getFirst();
56+
return SqlTypeUtil.createArrayType(typeFactory, firstArgType, true);
57+
};
4258
}

core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java

Lines changed: 5 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import java.util.ArrayList;
1818
import java.util.Collections;
1919
import java.util.List;
20+
import java.util.Objects;
2021
import java.util.Set;
2122
import javax.annotation.Nullable;
2223
import org.apache.calcite.DataContext;
@@ -26,7 +27,6 @@
2627
import org.apache.calcite.linq4j.tree.Expression;
2728
import org.apache.calcite.linq4j.tree.Expressions;
2829
import org.apache.calcite.rel.type.RelDataType;
29-
import org.apache.calcite.rel.type.RelDataTypeFactory;
3030
import org.apache.calcite.rex.RexCall;
3131
import org.apache.calcite.rex.RexNode;
3232
import org.apache.calcite.schema.impl.AggregateFunctionImpl;
@@ -91,14 +91,15 @@ public class UserDefinedFunctionUtils {
9191
public static SqlUserDefinedAggFunction createUserDefinedAggFunction(
9292
Class<? extends UserDefinedAggFunction<?>> udafClass,
9393
String functionName,
94-
SqlReturnTypeInference returnType) {
94+
SqlReturnTypeInference returnType,
95+
@Nullable UDFOperandMetadata operandMetadata) {
9596
return new SqlUserDefinedAggFunction(
9697
new SqlIdentifier(functionName, SqlParserPos.ZERO),
9798
SqlKind.OTHER_FUNCTION,
9899
returnType,
99100
null,
100-
null,
101-
AggregateFunctionImpl.create(udafClass),
101+
operandMetadata,
102+
Objects.requireNonNull(AggregateFunctionImpl.create(udafClass)),
102103
false,
103104
false,
104105
Optionality.FORBIDDEN);
@@ -123,45 +124,6 @@ public static RelBuilder.AggCall makeAggregateCall(
123124
return relBuilder.aggregateCall(aggFunction, addArgList);
124125
}
125126

126-
/**
127-
* Creates and registers a User Defined Aggregate Function (UDAF) and returns an AggCall that can
128-
* be used in query plans.
129-
*
130-
* @param udafClass The class implementing the aggregate function behavior
131-
* @param functionName The name of the aggregate function
132-
* @param returnType The return type inference for determining the result type
133-
* @param fields The primary fields to aggregate
134-
* @param argList Additional arguments for the aggregate function
135-
* @param relBuilder The RelBuilder instance used for building relational expressions
136-
* @return An AggCall object representing the aggregate function call
137-
*/
138-
public static RelBuilder.AggCall createAggregateFunction(
139-
Class<? extends UserDefinedAggFunction<?>> udafClass,
140-
String functionName,
141-
SqlReturnTypeInference returnType,
142-
List<RexNode> fields,
143-
List<RexNode> argList,
144-
RelBuilder relBuilder) {
145-
SqlUserDefinedAggFunction udaf =
146-
createUserDefinedAggFunction(udafClass, functionName, returnType);
147-
return makeAggregateCall(udaf, fields, argList, relBuilder);
148-
}
149-
150-
public static SqlReturnTypeInference getReturnTypeInferenceForArray() {
151-
return opBinding -> {
152-
RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
153-
154-
// Get argument types
155-
List<RelDataType> argTypes = opBinding.collectOperandTypes();
156-
157-
if (argTypes.isEmpty()) {
158-
throw new IllegalArgumentException("Function requires at least one argument.");
159-
}
160-
RelDataType firstArgType = argTypes.getFirst();
161-
return createArrayType(typeFactory, firstArgType, true);
162-
};
163-
}
164-
165127
public static SqlTypeName convertRelDataTypeToSqlTypeName(RelDataType type) {
166128
if (type instanceof AbstractExprRelDataType<?> exprType) {
167129
return switch (exprType.getUdt()) {

core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.adaptExprMethodToUDF;
99
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.adaptExprMethodWithPropertiesToUDF;
1010
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.adaptMathFunctionToUDF;
11+
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.createUserDefinedAggFunction;
1112

1213
import com.google.common.base.Suppliers;
1314
import java.lang.reflect.InvocationTargetException;
@@ -21,11 +22,17 @@
2122
import org.apache.calcite.avatica.util.TimeUnit;
2223
import org.apache.calcite.linq4j.tree.Expression;
2324
import org.apache.calcite.rex.RexCall;
25+
import org.apache.calcite.sql.SqlAggFunction;
26+
import org.apache.calcite.sql.SqlKind;
2427
import org.apache.calcite.sql.SqlOperator;
2528
import org.apache.calcite.sql.type.ReturnTypes;
2629
import org.apache.calcite.sql.type.SqlTypeTransforms;
2730
import org.apache.calcite.sql.util.ReflectiveSqlOperatorTable;
2831
import org.apache.calcite.util.BuiltInMethod;
32+
import org.opensearch.sql.calcite.udf.udaf.LogPatternAggFunction;
33+
import org.opensearch.sql.calcite.udf.udaf.NullableSqlAvgAggFunction;
34+
import org.opensearch.sql.calcite.udf.udaf.PercentileApproxFunction;
35+
import org.opensearch.sql.calcite.udf.udaf.TakeAggFunction;
2936
import org.opensearch.sql.calcite.utils.PPLOperandTypes;
3037
import org.opensearch.sql.calcite.utils.PPLReturnTypes;
3138
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
@@ -381,6 +388,35 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
381388
public static final SqlOperator NUMBER_TO_STRING =
382389
new NumberToStringFunction().toUDF("NUMBER_TO_STRING");
383390

391+
// Aggregation functions
392+
public static final SqlAggFunction AVG_NULLABLE = new NullableSqlAvgAggFunction(SqlKind.AVG);
393+
public static final SqlAggFunction STDDEV_POP_NULLABLE =
394+
new NullableSqlAvgAggFunction(SqlKind.STDDEV_POP);
395+
public static final SqlAggFunction STDDEV_SAMP_NULLABLE =
396+
new NullableSqlAvgAggFunction(SqlKind.STDDEV_SAMP);
397+
public static final SqlAggFunction VAR_POP_NULLABLE =
398+
new NullableSqlAvgAggFunction(SqlKind.VAR_POP);
399+
public static final SqlAggFunction VAR_SAMP_NULLABLE =
400+
new NullableSqlAvgAggFunction(SqlKind.VAR_SAMP);
401+
public static final SqlAggFunction TAKE =
402+
createUserDefinedAggFunction(
403+
TakeAggFunction.class,
404+
"TAKE",
405+
PPLReturnTypes.ARG0_ARRAY,
406+
PPLOperandTypes.ANY_OPTIONAL_INTEGER);
407+
public static final SqlAggFunction PERCENTILE_APPROX =
408+
createUserDefinedAggFunction(
409+
PercentileApproxFunction.class,
410+
"percentile_approx",
411+
ReturnTypes.ARG0_FORCE_NULLABLE,
412+
PPLOperandTypes.NUMERIC_NUMERIC_OPTIONAL_NUMERIC);
413+
public static final SqlAggFunction INTERNAL_PATTERN =
414+
createUserDefinedAggFunction(
415+
LogPatternAggFunction.class,
416+
"pattern",
417+
ReturnTypes.explicit(UserDefinedFunctionUtils.nullablePatternAggList),
418+
null);
419+
384420
/**
385421
* Returns the PPL specific operator table, creating it if necessary.
386422
*

0 commit comments

Comments
 (0)