250250import org .apache .calcite .sql .type .SqlOperandTypeChecker ;
251251import org .apache .calcite .sql .type .SqlTypeFamily ;
252252import org .apache .calcite .sql .type .SqlTypeName ;
253+ import org .apache .calcite .sql .validate .SqlUserDefinedAggFunction ;
253254import org .apache .calcite .sql .validate .SqlUserDefinedFunction ;
254255import org .apache .calcite .tools .RelBuilder ;
255256import org .apache .commons .lang3 .tuple .Pair ;
@@ -337,14 +338,16 @@ default RexNode resolve(RexBuilder builder, RexNode... args) {
337338 * implementations are independent of any specific data storage, should be registered here
338339 * internally.
339340 */
340- private final ImmutableMap <BuiltinFunctionName , AggHandler > aggFunctionRegistry ;
341+ private final ImmutableMap <BuiltinFunctionName , Pair <CalciteFuncSignature , AggHandler >>
342+ aggFunctionRegistry ;
341343
342344 /**
343345 * The external agg function registry. Agg Functions whose implementations depend on a specific
344346 * data engine should be registered here. This reduces coupling between the core module and
345347 * particular storage backends.
346348 */
347- private final Map <BuiltinFunctionName , AggHandler > aggExternalFunctionRegistry ;
349+ private final Map <BuiltinFunctionName , Pair <CalciteFuncSignature , AggHandler >>
350+ aggExternalFunctionRegistry ;
348351
349352 private PPLFuncImpTable (Builder builder , AggBuilder aggBuilder ) {
350353 final ImmutableMap .Builder <BuiltinFunctionName , List <Pair <CalciteFuncSignature , FunctionImp >>>
@@ -353,8 +356,8 @@ private PPLFuncImpTable(Builder builder, AggBuilder aggBuilder) {
353356 this .functionRegistry = ImmutableMap .copyOf (mapBuilder .build ());
354357 this .externalFunctionRegistry = new ConcurrentHashMap <>();
355358
356- final ImmutableMap .Builder <BuiltinFunctionName , AggHandler > aggMapBuilder =
357- ImmutableMap .builder ();
359+ final ImmutableMap .Builder <BuiltinFunctionName , Pair < CalciteFuncSignature , AggHandler >>
360+ aggMapBuilder = ImmutableMap .builder ();
358361 aggBuilder .map .forEach (aggMapBuilder ::put );
359362 this .aggFunctionRegistry = ImmutableMap .copyOf (aggMapBuilder .build ());
360363 this .aggExternalFunctionRegistry = new ConcurrentHashMap <>();
@@ -370,7 +373,7 @@ public void registerExternalOperator(BuiltinFunctionName functionName, SqlOperat
370373 PPLTypeChecker typeChecker =
371374 wrapSqlOperandTypeChecker (
372375 operator .getOperandTypeChecker (),
373- operator . getName (),
376+ functionName . name (),
374377 operator instanceof SqlUserDefinedFunction );
375378 CalciteFuncSignature signature = new CalciteFuncSignature (functionName .getName (), typeChecker );
376379 externalFunctionRegistry .compute (
@@ -384,14 +387,22 @@ public void registerExternalOperator(BuiltinFunctionName functionName, SqlOperat
384387 }
385388
386389 /**
387- * Register a function implementation from external services dynamically.
390+ * Register an external aggregate operator dynamically.
388391 *
389392 * @param functionName the name of the function, has to be defined in BuiltinFunctionName
390- * @param functionImp the implementation of the agg function
393+ * @param aggFunction a SqlUserDefinedAggFunction representing the aggregate function
394+ * implementation
391395 */
392- public void registerExternalAggFunction (
393- BuiltinFunctionName functionName , AggHandler functionImp ) {
394- aggExternalFunctionRegistry .put (functionName , functionImp );
396+ public void registerExternalAggOperator (
397+ BuiltinFunctionName functionName , SqlUserDefinedAggFunction aggFunction ) {
398+ PPLTypeChecker typeChecker =
399+ wrapSqlOperandTypeChecker (aggFunction .getOperandTypeChecker (), functionName .name (), true );
400+ CalciteFuncSignature signature = new CalciteFuncSignature (functionName .getName (), typeChecker );
401+ AggHandler handler =
402+ (distinct , field , argList , ctx ) ->
403+ UserDefinedFunctionUtils .convertUDAFToAggCall (
404+ aggFunction , List .of (field ), argList , ctx .relBuilder );
405+ aggExternalFunctionRegistry .put (functionName , Pair .of (signature , handler ));
395406 }
396407
397408 public RelBuilder .AggCall resolveAgg (
@@ -400,13 +411,24 @@ public RelBuilder.AggCall resolveAgg(
400411 RexNode field ,
401412 List <RexNode > argList ,
402413 CalcitePlanContext context ) {
403- AggHandler handler = aggExternalFunctionRegistry .get (functionName );
404- if (handler == null ) {
405- handler = aggFunctionRegistry .get (functionName );
414+ var implementation = aggExternalFunctionRegistry .get (functionName );
415+ if (implementation == null ) {
416+ implementation = aggFunctionRegistry .get (functionName );
406417 }
407- if (handler == null ) {
418+ if (implementation == null ) {
408419 throw new IllegalStateException (String .format ("Cannot resolve function: %s" , functionName ));
409420 }
421+ CalciteFuncSignature signature = implementation .getKey ();
422+ RelDataType fieldType = field .getType ();
423+ if (!signature .match (functionName .getName (), List .of (fieldType ))) {
424+ throw new ExpressionEvaluationException (
425+ String .format (
426+ "Aggregation function %s expects field type {%s}, but got %s" ,
427+ functionName ,
428+ signature .typeChecker ().getAllowedSignatures (),
429+ getActualSignature (List .of (fieldType ))));
430+ }
431+ var handler = implementation .getValue ();
410432 return handler .apply (distinct , field , argList , context );
411433 }
412434
@@ -1037,43 +1059,66 @@ void register(
10371059 }
10381060
10391061 private static class AggBuilder {
1040- private final Map <BuiltinFunctionName , AggHandler > map = new HashMap <>();
1062+ private final Map <BuiltinFunctionName , Pair <CalciteFuncSignature , AggHandler >> map =
1063+ new HashMap <>();
10411064
1042- void register (BuiltinFunctionName functionName , AggHandler aggHandler ) {
1043- map .put (functionName , aggHandler );
1065+ void register (
1066+ BuiltinFunctionName functionName , AggHandler aggHandler , PPLTypeChecker typeChecker ) {
1067+ CalciteFuncSignature signature =
1068+ new CalciteFuncSignature (functionName .getName (), typeChecker );
1069+ map .put (functionName , Pair .of (signature , aggHandler ));
1070+ }
1071+
1072+ void registerOperator (BuiltinFunctionName functionName , SqlUserDefinedAggFunction aggFunction ) {
1073+ PPLTypeChecker typeChecker =
1074+ wrapSqlOperandTypeChecker (aggFunction .getOperandTypeChecker (), functionName .name (), true );
1075+ AggHandler handler =
1076+ (distinct , field , argList , ctx ) ->
1077+ UserDefinedFunctionUtils .convertUDAFToAggCall (
1078+ aggFunction , List .of (field ), argList , ctx .relBuilder );
1079+ register (functionName , handler , typeChecker );
10441080 }
10451081
10461082 void populate () {
1047- register (MAX , (distinct , field , argList , ctx ) -> ctx .relBuilder .max (field ));
1048- register (MIN , (distinct , field , argList , ctx ) -> ctx .relBuilder .min (field ));
1083+ register (MAX , (distinct , field , argList , ctx ) -> ctx .relBuilder .max (field ), null );
1084+ register (MIN , (distinct , field , argList , ctx ) -> ctx .relBuilder .min (field ), null );
10491085
1050- register (AVG , (distinct , field , argList , ctx ) -> ctx .relBuilder .avg (distinct , null , field ));
1086+ register (
1087+ AVG , (distinct , field , argList , ctx ) -> ctx .relBuilder .avg (distinct , null , field ), null );
10511088
10521089 register (
10531090 COUNT ,
10541091 (distinct , field , argList , ctx ) ->
10551092 ctx .relBuilder .count (
1056- distinct , null , field == null ? ImmutableList .of () : ImmutableList .of (field )));
1057- register (SUM , (distinct , field , argList , ctx ) -> ctx .relBuilder .sum (distinct , null , field ));
1093+ distinct , null , field == null ? ImmutableList .of () : ImmutableList .of (field )),
1094+ null );
1095+ register (
1096+ SUM ,
1097+ (distinct , field , argList , ctx ) ->
1098+ ctx .relBuilder .aggregateCall (SqlStdOperatorTable .SUM , field ),
1099+ null );
10581100
10591101 register (
10601102 VARSAMP ,
1061- (distinct , field , argList , ctx ) ->
1062- ctx . relBuilder . aggregateCall ( VAR_SAMP_NULLABLE , field ) );
1103+ (distinct , field , argList , ctx ) -> ctx . relBuilder . aggregateCall ( VAR_SAMP_NULLABLE , field ),
1104+ null );
10631105
10641106 register (
10651107 VARPOP ,
1066- (distinct , field , argList , ctx ) -> ctx .relBuilder .aggregateCall (VAR_POP_NULLABLE , field ));
1108+ (distinct , field , argList , ctx ) -> ctx .relBuilder .aggregateCall (VAR_POP_NULLABLE , field ),
1109+ null );
10671110
10681111 register (
10691112 STDDEV_SAMP ,
10701113 (distinct , field , argList , ctx ) ->
1071- ctx .relBuilder .aggregateCall (STDDEV_SAMP_NULLABLE , field ));
1114+ ctx .relBuilder .aggregateCall (STDDEV_SAMP_NULLABLE , field ),
1115+ null );
10721116
10731117 register (
10741118 STDDEV_POP ,
10751119 (distinct , field , argList , ctx ) ->
1076- ctx .relBuilder .aggregateCall (STDDEV_POP_NULLABLE , field ));
1120+ ctx .relBuilder .aggregateCall (STDDEV_POP_NULLABLE , field ),
1121+ null );
10771122
10781123 register (
10791124 TAKE ,
@@ -1087,7 +1132,8 @@ void populate() {
10871132 List .of (field ),
10881133 newArgList ,
10891134 ctx .relBuilder );
1090- });
1135+ },
1136+ null );
10911137
10921138 register (
10931139 PERCENTILE_APPROX ,
@@ -1102,7 +1148,8 @@ void populate() {
11021148 List .of (field ),
11031149 newArgList ,
11041150 ctx .relBuilder );
1105- });
1151+ },
1152+ null );
11061153
11071154 register (
11081155 INTERNAL_PATTERN ,
@@ -1113,7 +1160,8 @@ void populate() {
11131160 ReturnTypes .explicit (UserDefinedFunctionUtils .nullablePatternAggList ),
11141161 List .of (field ),
11151162 argList ,
1116- ctx .relBuilder ));
1163+ ctx .relBuilder ),
1164+ null );
11171165 }
11181166 }
11191167}
0 commit comments