Skip to content

Commit 54f1287

Browse files
authored
Implement type checking for aggregation functions with Calcite (opensearch-project#4024) (opensearch-project#4054)
* Remove getTypeChecker from FunctionImp interface * Refactor registerExternalFunction to registerExternalOperator * Do not register GEOIP function if got incompatible client * Create scaffold for type checking of aggregation functions * Add type checkers for aggregation functions * Test type checking for aggregation functions --------- (cherry picked from commit d758163) Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent d4101e9 commit 54f1287

5 files changed

Lines changed: 404 additions & 318 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/udf/udaf/TakeAggFunction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public Object result(TakeAccumulator accumulator) {
2424
@Override
2525
public TakeAccumulator add(TakeAccumulator acc, Object... values) {
2626
Object candidateValue = values[0];
27-
int size = 0;
27+
int size;
2828
if (values.length > 1) {
2929
size = (int) values[1];
3030
} else {

core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.apache.calcite.rex.RexCall;
3333
import org.apache.calcite.rex.RexNode;
3434
import org.apache.calcite.schema.impl.AggregateFunctionImpl;
35+
import org.apache.calcite.sql.SqlAggFunction;
3536
import org.apache.calcite.sql.SqlIdentifier;
3637
import org.apache.calcite.sql.SqlKind;
3738
import org.apache.calcite.sql.parser.SqlParserPos;
@@ -79,27 +80,71 @@ public class UserDefinedFunctionUtils {
7980
public static Set<String> MULTI_FIELDS_RELEVANCE_FUNCTION_SET =
8081
ImmutableSet.of("simple_query_string", "query_string", "multi_match");
8182

82-
public static RelBuilder.AggCall TransferUserDefinedAggFunction(
83-
Class<? extends UserDefinedAggFunction> UDAF,
83+
/**
84+
* Creates a SqlUserDefinedAggFunction that wraps a Java class implementing an aggregate function.
85+
*
86+
* @param udafClass The Java class that implements the UserDefinedAggFunction interface
87+
* @param functionName The name of the function to be used in SQL statements
88+
* @param returnType A SqlReturnTypeInference that determines the return type of the function
89+
* @return A SqlUserDefinedAggFunction that can be used in SQL queries
90+
*/
91+
public static SqlUserDefinedAggFunction createUserDefinedAggFunction(
92+
Class<? extends UserDefinedAggFunction<?>> udafClass,
8493
String functionName,
85-
SqlReturnTypeInference returnType,
94+
SqlReturnTypeInference returnType) {
95+
return new SqlUserDefinedAggFunction(
96+
new SqlIdentifier(functionName, SqlParserPos.ZERO),
97+
SqlKind.OTHER_FUNCTION,
98+
returnType,
99+
null,
100+
null,
101+
AggregateFunctionImpl.create(udafClass),
102+
false,
103+
false,
104+
Optionality.FORBIDDEN);
105+
}
106+
107+
/**
108+
* Creates an aggregate call using the provided SqlAggFunction and arguments.
109+
*
110+
* @param aggFunction The aggregate function to call
111+
* @param fields The primary fields to aggregate
112+
* @param argList Additional arguments for the aggregate function
113+
* @param relBuilder The RelBuilder instance used for building relational expressions
114+
* @return An AggCall object representing the aggregate function call
115+
*/
116+
public static RelBuilder.AggCall makeAggregateCall(
117+
SqlAggFunction aggFunction,
86118
List<RexNode> fields,
87119
List<RexNode> argList,
88120
RelBuilder relBuilder) {
89-
SqlUserDefinedAggFunction sqlUDAF =
90-
new SqlUserDefinedAggFunction(
91-
new SqlIdentifier(functionName, SqlParserPos.ZERO),
92-
SqlKind.OTHER_FUNCTION,
93-
returnType,
94-
null,
95-
null,
96-
AggregateFunctionImpl.create(UDAF),
97-
false,
98-
false,
99-
Optionality.FORBIDDEN);
100121
List<RexNode> addArgList = new ArrayList<>(fields);
101122
addArgList.addAll(argList);
102-
return relBuilder.aggregateCall(sqlUDAF, addArgList);
123+
return relBuilder.aggregateCall(aggFunction, addArgList);
124+
}
125+
126+
/**
127+
* Creates and registers a User Defined Aggregate Function (UDAF) and returns an AggCall that can
128+
* be used in query plans.
129+
*
130+
* @param udafClass The class implementing the aggregate function behavior
131+
* @param functionName The name of the aggregate function
132+
* @param returnType The return type inference for determining the result type
133+
* @param fields The primary fields to aggregate
134+
* @param argList Additional arguments for the aggregate function
135+
* @param relBuilder The RelBuilder instance used for building relational expressions
136+
* @return An AggCall object representing the aggregate function call
137+
*/
138+
public static RelBuilder.AggCall createAggregateFunction(
139+
Class<? extends UserDefinedAggFunction<?>> udafClass,
140+
String functionName,
141+
SqlReturnTypeInference returnType,
142+
List<RexNode> fields,
143+
List<RexNode> argList,
144+
RelBuilder relBuilder) {
145+
SqlUserDefinedAggFunction udaf =
146+
createUserDefinedAggFunction(udafClass, functionName, returnType);
147+
return makeAggregateCall(udaf, fields, argList, relBuilder);
103148
}
104149

105150
public static SqlReturnTypeInference getReturnTypeInferenceForArray() {

0 commit comments

Comments
 (0)