Skip to content

Commit a9d1ba1

Browse files
committed
Add interface registerOperator(BuiltinFunctionName, SqlOperator, PPLTypeChecker) to allow registering an operator with a designated type chekcer
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent ffefc08 commit a9d1ba1

1 file changed

Lines changed: 141 additions & 127 deletions

File tree

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 141 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -564,73 +564,6 @@ private void compulsoryCast(
564564
return null;
565565
}
566566

567-
/**
568-
* Get a string representation of the argument types expressed in ExprType for error messages.
569-
*
570-
* @param argTypes the list of argument types as {@link RelDataType}
571-
* @return a string in the format [type1,type2,...] representing the argument types
572-
*/
573-
private static String getActualSignature(List<RelDataType> argTypes) {
574-
return "["
575-
+ argTypes.stream()
576-
.map(OpenSearchTypeFactory::convertRelDataTypeToExprType)
577-
.map(Objects::toString)
578-
.collect(Collectors.joining(","))
579-
+ "]";
580-
}
581-
582-
/**
583-
* Wraps a {@link SqlOperandTypeChecker} into a {@link PPLTypeChecker} for use in function
584-
* signature validation.
585-
*
586-
* @param typeChecker the original SQL operand type checker
587-
* @param functionName the name of the function for error reporting
588-
* @param isUserDefinedFunction true if the function is user-defined, false otherwise
589-
* @return a {@link PPLTypeChecker} that delegates to the provided {@code typeChecker}
590-
*/
591-
private static PPLTypeChecker wrapSqlOperandTypeChecker(
592-
SqlOperandTypeChecker typeChecker, String functionName, boolean isUserDefinedFunction) {
593-
PPLTypeChecker pplTypeChecker;
594-
// Only the composite operand type checker for UDFs are concerned here.
595-
if (isUserDefinedFunction
596-
&& typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
597-
// UDFs implement their own composite type checkers, which always use OR logic for
598-
// argument
599-
// types. Verifying the composition type would require accessing a protected field in
600-
// CompositeOperandTypeChecker. If access to this field is not allowed, type checking will
601-
// be skipped, so we avoid checking the composition type here.
602-
pplTypeChecker = PPLTypeChecker.wrapComposite(compositeTypeChecker, false);
603-
} else if (typeChecker instanceof ImplicitCastOperandTypeChecker implicitCastTypeChecker) {
604-
pplTypeChecker = PPLTypeChecker.wrapFamily(implicitCastTypeChecker);
605-
} else if (typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
606-
// If compositeTypeChecker contains operand checkers other than family type checkers or
607-
// other than OR compositions, the function with be registered with a null type checker,
608-
// which means the function will not be type checked.
609-
try {
610-
pplTypeChecker = PPLTypeChecker.wrapComposite(compositeTypeChecker, true);
611-
} catch (IllegalArgumentException | UnsupportedOperationException e) {
612-
logger.debug(
613-
String.format(
614-
"Failed to create composite type checker for operator: %s. Will skip its type"
615-
+ " checking",
616-
functionName),
617-
e);
618-
pplTypeChecker = null;
619-
}
620-
} else if (typeChecker instanceof SameOperandTypeChecker comparableTypeChecker) {
621-
// Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
622-
// SameOperandTypeCheckers like COALESCE, IFNULL, etc.
623-
pplTypeChecker = PPLTypeChecker.wrapComparable(comparableTypeChecker);
624-
} else if (typeChecker instanceof UDFOperandMetadata.UDTOperandMetadata udtOperandMetadata) {
625-
pplTypeChecker = PPLTypeChecker.wrapUDT(udtOperandMetadata.allowedParamTypes());
626-
} else {
627-
logger.info(
628-
"Cannot create type checker for function: {}. Will skip its type checking", functionName);
629-
pplTypeChecker = null;
630-
}
631-
return pplTypeChecker;
632-
}
633-
634567
@SuppressWarnings({"UnusedReturnValue", "SameParameterValue"})
635568
private abstract static class AbstractBuilder {
636569

@@ -650,7 +583,7 @@ abstract void register(
650583
* @param operators the operators to associate with this function name, tried in sequence until
651584
* one matches the argument types during resolution
652585
*/
653-
public void registerOperator(BuiltinFunctionName functionName, SqlOperator... operators) {
586+
protected void registerOperator(BuiltinFunctionName functionName, SqlOperator... operators) {
654587
for (SqlOperator operator : operators) {
655588
SqlOperandTypeChecker typeChecker;
656589
if (operator instanceof SqlUserDefinedFunction udfOperator) {
@@ -662,18 +595,24 @@ public void registerOperator(BuiltinFunctionName functionName, SqlOperator... op
662595
PPLTypeChecker pplTypeChecker =
663596
wrapSqlOperandTypeChecker(
664597
typeChecker, operator.getName(), operator instanceof SqlUserDefinedFunction);
665-
register(
666-
functionName,
667-
(RexBuilder builder, RexNode... args) -> builder.makeCall(operator, args),
668-
pplTypeChecker);
598+
registerOperator(functionName, operator, pplTypeChecker);
669599
}
670600
}
671601

672-
private static SqlOperandTypeChecker extractTypeCheckerFromUDF(
673-
SqlUserDefinedFunction udfOperator) {
674-
UDFOperandMetadata udfOperandMetadata =
675-
(UDFOperandMetadata) udfOperator.getOperandTypeChecker();
676-
return (udfOperandMetadata == null) ? null : udfOperandMetadata.getInnerTypeChecker();
602+
/**
603+
* Registers an operator for a built-in function name with a specified {@link PPLTypeChecker}.
604+
* This allows custom type checking logic to be associated with the operator.
605+
*
606+
* @param functionName the built-in function name
607+
* @param operator the SQL operator to register
608+
* @param typeChecker the type checker to use for validating argument types
609+
*/
610+
protected void registerOperator(
611+
BuiltinFunctionName functionName, SqlOperator operator, PPLTypeChecker typeChecker) {
612+
register(
613+
functionName,
614+
(RexBuilder builder, RexNode... args) -> builder.makeCall(operator, args),
615+
typeChecker);
677616
}
678617

679618
void populate() {
@@ -690,15 +629,6 @@ void populate() {
690629
registerOperator(OR, SqlStdOperatorTable.OR);
691630
registerOperator(NOT, SqlStdOperatorTable.NOT);
692631

693-
// Register ADD (+ symbol) for numeric addition
694-
register(
695-
ADD,
696-
(RexBuilder builder, RexNode... args) -> builder.makeCall(SqlStdOperatorTable.PLUS, args),
697-
new PPLTypeChecker.PPLFamilyTypeChecker(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC));
698-
699-
// Register ADD (+ symbol) for string concatenation
700-
registerOperator(ADD, SqlStdOperatorTable.CONCAT);
701-
702632
// Register ADDFUNCTION for numeric addition only
703633
registerOperator(ADDFUNCTION, SqlStdOperatorTable.PLUS);
704634
registerOperator(SUBTRACT, SqlStdOperatorTable.MINUS);
@@ -853,15 +783,16 @@ void populate() {
853783
registerOperator(WEEK, PPLBuiltinOperators.WEEK);
854784
registerOperator(WEEK_OF_YEAR, PPLBuiltinOperators.WEEK);
855785
registerOperator(WEEKOFYEAR, PPLBuiltinOperators.WEEK);
856-
registerOperator(INTERNAL_PATTERN_PARSER, PPLBuiltinOperators.PATTERN_PARSER);
857786

787+
registerOperator(INTERNAL_PATTERN_PARSER, PPLBuiltinOperators.PATTERN_PARSER);
858788
registerOperator(ARRAY, PPLBuiltinOperators.ARRAY);
859789
registerOperator(ARRAY_LENGTH, SqlLibraryOperators.ARRAY_LENGTH);
860790
registerOperator(FORALL, PPLBuiltinOperators.FORALL);
861791
registerOperator(EXISTS, PPLBuiltinOperators.EXISTS);
862792
registerOperator(FILTER, PPLBuiltinOperators.FILTER);
863793
registerOperator(TRANSFORM, PPLBuiltinOperators.TRANSFORM);
864794
registerOperator(REDUCE, PPLBuiltinOperators.REDUCE);
795+
865796
// Register Json function
866797
register(
867798
JSON_ARRAY,
@@ -889,6 +820,53 @@ void populate() {
889820
registerOperator(JSON_APPEND, PPLBuiltinOperators.JSON_APPEND);
890821
registerOperator(JSON_EXTEND, PPLBuiltinOperators.JSON_EXTEND);
891822

823+
// Register operators with a different type checker
824+
825+
// Register ADD (+ symbol) for string concatenation
826+
// Replaced type checker since CONCAT also supports array concatenation
827+
registerOperator(
828+
ADD,
829+
SqlStdOperatorTable.CONCAT,
830+
PPLTypeChecker.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER));
831+
// Register ADD (+ symbol) for numeric addition
832+
// Replace type checker since PLUS also supports binary addition
833+
registerOperator(
834+
ADD,
835+
SqlStdOperatorTable.PLUS,
836+
PPLTypeChecker.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC));
837+
// Replace with a custom CompositeOperandTypeChecker to check both operands as
838+
// SqlStdOperatorTable.ITEM.getOperandTypeChecker() checks only the first operand instead
839+
// of all operands.
840+
registerOperator(
841+
INTERNAL_ITEM,
842+
SqlStdOperatorTable.ITEM,
843+
PPLTypeChecker.wrapComposite(
844+
(CompositeOperandTypeChecker)
845+
OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER)
846+
.or(OperandTypes.family(SqlTypeFamily.MAP, SqlTypeFamily.ANY)),
847+
false));
848+
registerOperator(
849+
XOR,
850+
SqlStdOperatorTable.NOT_EQUALS,
851+
PPLTypeChecker.family(SqlTypeFamily.BOOLEAN, SqlTypeFamily.BOOLEAN));
852+
// SqlStdOperatorTable.CASE.getOperandTypeChecker is null. We manually create a type checker
853+
// for it. The second and third operands are required to be of the same type. If not,
854+
// it will throw an IllegalArgumentException with information Can't find leastRestrictive type
855+
registerOperator(
856+
IF,
857+
SqlStdOperatorTable.CASE,
858+
PPLTypeChecker.family(SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.ANY));
859+
// Re-define the type checker for is not null, is present, and is null since their original
860+
// type checker ANY isn't compatible with struct types.
861+
registerOperator(
862+
IS_NOT_NULL,
863+
SqlStdOperatorTable.IS_NOT_NULL,
864+
PPLTypeChecker.family(SqlTypeFamily.IGNORE));
865+
registerOperator(
866+
IS_PRESENT, SqlStdOperatorTable.IS_NOT_NULL, PPLTypeChecker.family(SqlTypeFamily.IGNORE));
867+
registerOperator(
868+
IS_NULL, SqlStdOperatorTable.IS_NULL, PPLTypeChecker.family(SqlTypeFamily.IGNORE));
869+
892870
// Register implementation.
893871
// Note, make the implementation an individual class if too complex.
894872
register(
@@ -922,10 +900,9 @@ void populate() {
922900
builder.makeLiteral(" "),
923901
arg),
924902
PPLTypeChecker.family(SqlTypeFamily.CHARACTER));
925-
register(
903+
registerOperator(
926904
ATAN,
927-
(FunctionImp2)
928-
(builder, arg1, arg2) -> builder.makeCall(SqlStdOperatorTable.ATAN2, arg1, arg2),
905+
SqlStdOperatorTable.ATAN2,
929906
PPLTypeChecker.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC));
930907
register(
931908
STRCMP,
@@ -960,17 +937,6 @@ void populate() {
960937
SqlTypeFamily.INTEGER,
961938
SqlTypeFamily.INTEGER)),
962939
false));
963-
// SqlStdOperatorTable.ITEM.getOperandTypeChecker() checks only the first operand instead of
964-
// all operands. Therefore, we wrap it with a custom CompositeOperandTypeChecker to check both
965-
// operands.
966-
register(
967-
INTERNAL_ITEM,
968-
(RexBuilder builder, RexNode... args) -> builder.makeCall(SqlStdOperatorTable.ITEM, args),
969-
PPLTypeChecker.wrapComposite(
970-
(CompositeOperandTypeChecker)
971-
OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER)
972-
.or(OperandTypes.family(SqlTypeFamily.MAP, SqlTypeFamily.ANY)),
973-
false));
974940
register(
975941
LOG,
976942
(FunctionImp2)
@@ -1002,18 +968,6 @@ void populate() {
1002968
(builder, arg) ->
1003969
builder.makeLiteral(getLegacyTypeName(arg.getType(), QueryType.PPL)),
1004970
null);
1005-
register(
1006-
XOR,
1007-
(FunctionImp2)
1008-
(builder, arg1, arg2) -> builder.makeCall(SqlStdOperatorTable.NOT_EQUALS, arg1, arg2),
1009-
PPLTypeChecker.family(SqlTypeFamily.BOOLEAN, SqlTypeFamily.BOOLEAN));
1010-
// SqlStdOperatorTable.CASE.getOperandTypeChecker is null. We manually create a type checker
1011-
// for it. The second and third operands are required to be of the same type. If not,
1012-
// it will throw an IllegalArgumentException with information Can't find leastRestrictive type
1013-
register(
1014-
IF,
1015-
(RexBuilder builder, RexNode... args) -> builder.makeCall(SqlStdOperatorTable.CASE, args),
1016-
PPLTypeChecker.family(SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.ANY));
1017971
register(
1018972
NULLIF,
1019973
(FunctionImp2)
@@ -1048,20 +1002,6 @@ void populate() {
10481002
builder.makeLiteral(" "),
10491003
arg))),
10501004
PPLTypeChecker.family(SqlTypeFamily.ANY));
1051-
// Re-define the type checker for is not null, is present, and is null since their original
1052-
// type checker ANY isn't compatible with struct types.
1053-
register(
1054-
IS_NOT_NULL,
1055-
(builder, args) -> builder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, args),
1056-
PPLTypeChecker.family(SqlTypeFamily.IGNORE));
1057-
register(
1058-
IS_PRESENT,
1059-
(builder, args) -> builder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, args),
1060-
PPLTypeChecker.family(SqlTypeFamily.IGNORE));
1061-
register(
1062-
IS_NULL,
1063-
(builder, args) -> builder.makeCall(SqlStdOperatorTable.IS_NULL, args),
1064-
PPLTypeChecker.family(SqlTypeFamily.IGNORE));
10651005
register(
10661006
LIKE,
10671007
(FunctionImp2)
@@ -1075,6 +1015,13 @@ void populate() {
10751015
builder.makeLiteral("\\")),
10761016
PPLTypeChecker.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING));
10771017
}
1018+
1019+
private static SqlOperandTypeChecker extractTypeCheckerFromUDF(
1020+
SqlUserDefinedFunction udfOperator) {
1021+
UDFOperandMetadata udfOperandMetadata =
1022+
(UDFOperandMetadata) udfOperator.getOperandTypeChecker();
1023+
return (udfOperandMetadata == null) ? null : udfOperandMetadata.getInnerTypeChecker();
1024+
}
10781025
}
10791026

10801027
private static class Builder extends AbstractBuilder {
@@ -1213,4 +1160,71 @@ void populate() {
12131160
null);
12141161
}
12151162
}
1163+
1164+
/**
1165+
* Get a string representation of the argument types expressed in ExprType for error messages.
1166+
*
1167+
* @param argTypes the list of argument types as {@link RelDataType}
1168+
* @return a string in the format [type1,type2,...] representing the argument types
1169+
*/
1170+
private static String getActualSignature(List<RelDataType> argTypes) {
1171+
return "["
1172+
+ argTypes.stream()
1173+
.map(OpenSearchTypeFactory::convertRelDataTypeToExprType)
1174+
.map(Objects::toString)
1175+
.collect(Collectors.joining(","))
1176+
+ "]";
1177+
}
1178+
1179+
/**
1180+
* Wraps a {@link SqlOperandTypeChecker} into a {@link PPLTypeChecker} for use in function
1181+
* signature validation.
1182+
*
1183+
* @param typeChecker the original SQL operand type checker
1184+
* @param functionName the name of the function for error reporting
1185+
* @param isUserDefinedFunction true if the function is user-defined, false otherwise
1186+
* @return a {@link PPLTypeChecker} that delegates to the provided {@code typeChecker}
1187+
*/
1188+
private static PPLTypeChecker wrapSqlOperandTypeChecker(
1189+
SqlOperandTypeChecker typeChecker, String functionName, boolean isUserDefinedFunction) {
1190+
PPLTypeChecker pplTypeChecker;
1191+
// Only the composite operand type checker for UDFs are concerned here.
1192+
if (isUserDefinedFunction
1193+
&& typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
1194+
// UDFs implement their own composite type checkers, which always use OR logic for
1195+
// argument
1196+
// types. Verifying the composition type would require accessing a protected field in
1197+
// CompositeOperandTypeChecker. If access to this field is not allowed, type checking will
1198+
// be skipped, so we avoid checking the composition type here.
1199+
pplTypeChecker = PPLTypeChecker.wrapComposite(compositeTypeChecker, false);
1200+
} else if (typeChecker instanceof ImplicitCastOperandTypeChecker implicitCastTypeChecker) {
1201+
pplTypeChecker = PPLTypeChecker.wrapFamily(implicitCastTypeChecker);
1202+
} else if (typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
1203+
// If compositeTypeChecker contains operand checkers other than family type checkers or
1204+
// other than OR compositions, the function with be registered with a null type checker,
1205+
// which means the function will not be type checked.
1206+
try {
1207+
pplTypeChecker = PPLTypeChecker.wrapComposite(compositeTypeChecker, true);
1208+
} catch (IllegalArgumentException | UnsupportedOperationException e) {
1209+
logger.debug(
1210+
String.format(
1211+
"Failed to create composite type checker for operator: %s. Will skip its type"
1212+
+ " checking",
1213+
functionName),
1214+
e);
1215+
pplTypeChecker = null;
1216+
}
1217+
} else if (typeChecker instanceof SameOperandTypeChecker comparableTypeChecker) {
1218+
// Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
1219+
// SameOperandTypeCheckers like COALESCE, IFNULL, etc.
1220+
pplTypeChecker = PPLTypeChecker.wrapComparable(comparableTypeChecker);
1221+
} else if (typeChecker instanceof UDFOperandMetadata.UDTOperandMetadata udtOperandMetadata) {
1222+
pplTypeChecker = PPLTypeChecker.wrapUDT(udtOperandMetadata.allowedParamTypes());
1223+
} else {
1224+
logger.info(
1225+
"Cannot create type checker for function: {}. Will skip its type checking", functionName);
1226+
pplTypeChecker = null;
1227+
}
1228+
return pplTypeChecker;
1229+
}
12161230
}

0 commit comments

Comments
 (0)