1212import static org .opensearch .sql .calcite .utils .CalciteToolsHelper .VAR_SAMP_NULLABLE ;
1313import static org .opensearch .sql .calcite .utils .OpenSearchTypeFactory .TYPE_FACTORY ;
1414import static org .opensearch .sql .calcite .utils .OpenSearchTypeFactory .getLegacyTypeName ;
15- import static org .opensearch .sql .calcite .utils .UserDefinedFunctionUtils .TransferUserDefinedAggFunction ;
15+ import static org .opensearch .sql .calcite .utils .UserDefinedFunctionUtils .createAggregateFunction ;
1616import static org .opensearch .sql .expression .function .BuiltinFunctionName .ABS ;
1717import static org .opensearch .sql .expression .function .BuiltinFunctionName .ACOS ;
1818import static org .opensearch .sql .expression .function .BuiltinFunctionName .ADD ;
238238import org .apache .calcite .rex .RexBuilder ;
239239import org .apache .calcite .rex .RexLambda ;
240240import org .apache .calcite .rex .RexNode ;
241+ import org .apache .calcite .sql .SqlAggFunction ;
241242import org .apache .calcite .sql .SqlOperator ;
242243import org .apache .calcite .sql .fun .SqlLibraryOperators ;
243244import org .apache .calcite .sql .fun .SqlStdOperatorTable ;
@@ -400,7 +401,7 @@ public void registerExternalAggOperator(
400401 CalciteFuncSignature signature = new CalciteFuncSignature (functionName .getName (), typeChecker );
401402 AggHandler handler =
402403 (distinct , field , argList , ctx ) ->
403- UserDefinedFunctionUtils .convertUDAFToAggCall (
404+ UserDefinedFunctionUtils .makeAggregateCall (
404405 aggFunction , List .of (field ), argList , ctx .relBuilder );
405406 aggExternalFunctionRegistry .put (functionName , Pair .of (signature , handler ));
406407 }
@@ -419,14 +420,27 @@ public RelBuilder.AggCall resolveAgg(
419420 throw new IllegalStateException (String .format ("Cannot resolve function: %s" , functionName ));
420421 }
421422 CalciteFuncSignature signature = implementation .getKey ();
422- RelDataType fieldType = field .getType ();
423- if (!signature .match (functionName .getName (), List .of (fieldType ))) {
423+ List <RelDataType > argTypes = new ArrayList <>();
424+ if (field != null ) {
425+ argTypes .add (field .getType ());
426+ }
427+ // Currently only PERCENTILE_APPROX and TAKE have additional arguments.
428+ // Their additional arguments will always come as a map of <argName, value>
429+ List <RelDataType > additionalArgTypes =
430+ argList .stream ().map (PlanUtils ::derefMapCall ).map (RexNode ::getType ).toList ();
431+ argTypes .addAll (additionalArgTypes );
432+ if (!signature .match (functionName .getName (), argTypes )) {
433+ String errorMessagePattern =
434+ argTypes .size () <= 1
435+ ? "Aggregation function %s expects field type {%s}, but got %s"
436+ : "Aggregation function %s expects field type and additional arguments {%s}, but got"
437+ + " %s" ;
424438 throw new ExpressionEvaluationException (
425439 String .format (
426- "Aggregation function %s expects field type {%s}, but got %s" ,
440+ errorMessagePattern ,
427441 functionName ,
428442 signature .typeChecker ().getAllowedSignatures (),
429- getActualSignature (List . of ( fieldType ) )));
443+ getActualSignature (argTypes )));
430444 }
431445 var handler = implementation .getValue ();
432446 return handler .apply (distinct , field , argList , context );
@@ -1069,92 +1083,105 @@ void register(
10691083 map .put (functionName , Pair .of (signature , aggHandler ));
10701084 }
10711085
1072- void registerOperator (BuiltinFunctionName functionName , SqlUserDefinedAggFunction aggFunction ) {
1086+ void registerOperator (BuiltinFunctionName functionName , SqlAggFunction aggFunction ) {
10731087 PPLTypeChecker typeChecker =
10741088 wrapSqlOperandTypeChecker (aggFunction .getOperandTypeChecker (), functionName .name (), true );
10751089 AggHandler handler =
10761090 (distinct , field , argList , ctx ) ->
1077- UserDefinedFunctionUtils .convertUDAFToAggCall (
1091+ UserDefinedFunctionUtils .makeAggregateCall (
10781092 aggFunction , List .of (field ), argList , ctx .relBuilder );
10791093 register (functionName , handler , typeChecker );
10801094 }
10811095
10821096 void populate () {
1083- register (MAX , (distinct , field , argList , ctx ) -> ctx .relBuilder .max (field ), null );
1084- register (MIN , (distinct , field , argList , ctx ) -> ctx .relBuilder .min (field ), null );
1097+ registerOperator (MAX , SqlStdOperatorTable .MAX );
1098+ registerOperator (MIN , SqlStdOperatorTable .MIN );
1099+ registerOperator (SUM , SqlStdOperatorTable .SUM );
10851100
10861101 register (
1087- AVG , (distinct , field , argList , ctx ) -> ctx .relBuilder .avg (distinct , null , field ), null );
1102+ AVG ,
1103+ (distinct , field , argList , ctx ) -> ctx .relBuilder .avg (distinct , null , field ),
1104+ wrapSqlOperandTypeChecker (
1105+ SqlStdOperatorTable .AVG .getOperandTypeChecker (), AVG .name (), false ));
10881106
10891107 register (
10901108 COUNT ,
10911109 (distinct , field , argList , ctx ) ->
10921110 ctx .relBuilder .count (
10931111 distinct , null , field == null ? ImmutableList .of () : ImmutableList .of (field )),
1094- null );
1095- register (
1096- SUM ,
1097- (distinct , field , argList , ctx ) ->
1098- ctx .relBuilder .aggregateCall (SqlStdOperatorTable .SUM , field ),
1099- null );
1112+ wrapSqlOperandTypeChecker (
1113+ SqlStdOperatorTable .COUNT .getOperandTypeChecker (), COUNT .name (), false ));
11001114
11011115 register (
11021116 VARSAMP ,
11031117 (distinct , field , argList , ctx ) -> ctx .relBuilder .aggregateCall (VAR_SAMP_NULLABLE , field ),
1104- null );
1118+ wrapSqlOperandTypeChecker (
1119+ SqlStdOperatorTable .VAR_SAMP .getOperandTypeChecker (), VARSAMP .name (), false ));
11051120
11061121 register (
11071122 VARPOP ,
11081123 (distinct , field , argList , ctx ) -> ctx .relBuilder .aggregateCall (VAR_POP_NULLABLE , field ),
1109- null );
1124+ wrapSqlOperandTypeChecker (
1125+ SqlStdOperatorTable .VAR_POP .getOperandTypeChecker (), VARPOP .name (), false ));
11101126
11111127 register (
11121128 STDDEV_SAMP ,
11131129 (distinct , field , argList , ctx ) ->
11141130 ctx .relBuilder .aggregateCall (STDDEV_SAMP_NULLABLE , field ),
1115- null );
1131+ wrapSqlOperandTypeChecker (
1132+ SqlStdOperatorTable .STDDEV_SAMP .getOperandTypeChecker (), STDDEV_SAMP .name (), false ));
11161133
11171134 register (
11181135 STDDEV_POP ,
11191136 (distinct , field , argList , ctx ) ->
11201137 ctx .relBuilder .aggregateCall (STDDEV_POP_NULLABLE , field ),
1121- null );
1138+ wrapSqlOperandTypeChecker (
1139+ SqlStdOperatorTable .STDDEV_POP .getOperandTypeChecker (), STDDEV_POP .name (), false ));
11221140
11231141 register (
11241142 TAKE ,
11251143 (distinct , field , argList , ctx ) -> {
11261144 List <RexNode > newArgList =
11271145 argList .stream ().map (PlanUtils ::derefMapCall ).collect (Collectors .toList ());
1128- return TransferUserDefinedAggFunction (
1146+ return createAggregateFunction (
11291147 TakeAggFunction .class ,
11301148 "TAKE" ,
11311149 UserDefinedFunctionUtils .getReturnTypeInferenceForArray (),
11321150 List .of (field ),
11331151 newArgList ,
11341152 ctx .relBuilder );
11351153 },
1136- null );
1154+ PPLTypeChecker .wrapComposite (
1155+ (CompositeOperandTypeChecker )
1156+ OperandTypes .ANY .or (
1157+ OperandTypes .family (SqlTypeFamily .ANY , SqlTypeFamily .INTEGER )),
1158+ false ));
11371159
11381160 register (
11391161 PERCENTILE_APPROX ,
11401162 (distinct , field , argList , ctx ) -> {
11411163 List <RexNode > newArgList =
11421164 argList .stream ().map (PlanUtils ::derefMapCall ).collect (Collectors .toList ());
11431165 newArgList .add (ctx .rexBuilder .makeFlag (field .getType ().getSqlTypeName ()));
1144- return TransferUserDefinedAggFunction (
1166+ return createAggregateFunction (
11451167 PercentileApproxFunction .class ,
11461168 "percentile_approx" ,
11471169 ReturnTypes .ARG0_FORCE_NULLABLE ,
11481170 List .of (field ),
11491171 newArgList ,
11501172 ctx .relBuilder );
11511173 },
1152- null );
1174+ PPLTypeChecker .wrapComposite (
1175+ (CompositeOperandTypeChecker )
1176+ OperandTypes .NUMERIC_NUMERIC .or (
1177+ OperandTypes .family (
1178+ SqlTypeFamily .NUMERIC , SqlTypeFamily .NUMERIC , SqlTypeFamily .NUMERIC )),
1179+ false ));
11531180
11541181 register (
11551182 INTERNAL_PATTERN ,
11561183 (distinct , field , argList , ctx ) ->
1157- TransferUserDefinedAggFunction (
1184+ createAggregateFunction (
11581185 LogPatternAggFunction .class ,
11591186 "pattern" ,
11601187 ReturnTypes .explicit (UserDefinedFunctionUtils .nullablePatternAggList ),
0 commit comments