Skip to content

Commit fd94ea0

Browse files
committed
Refactor type checking for UDT (specifically, IP)
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 1a73843 commit fd94ea0

8 files changed

Lines changed: 133 additions & 147 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/ExtendedRexBuilder.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
import org.opensearch.sql.ast.expression.SpanUnit;
2828
import org.opensearch.sql.calcite.type.AbstractExprRelDataType;
2929
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
30+
import org.opensearch.sql.data.type.ExprCoreType;
31+
import org.opensearch.sql.exception.ExpressionEvaluationException;
3032
import org.opensearch.sql.exception.SemanticCheckException;
3133
import org.opensearch.sql.expression.function.PPLBuiltinOperators;
3234

@@ -129,13 +131,28 @@ public RexNode makeCast(
129131
}
130132
} else if (OpenSearchTypeFactory.isUserDefinedType(type)) {
131133
var udt = ((AbstractExprRelDataType<?>) type).getUdt();
134+
var argExprType = OpenSearchTypeFactory.convertRelDataTypeToExprType(exp.getType());
132135
return switch (udt) {
133136
case EXPR_DATE -> makeCall(type, PPLBuiltinOperators.DATE, List.of(exp));
134137
case EXPR_TIME -> makeCall(type, PPLBuiltinOperators.TIME, List.of(exp));
135138
case EXPR_TIMESTAMP -> makeCall(type, PPLBuiltinOperators.TIMESTAMP, List.of(exp));
136-
case EXPR_IP -> makeCall(type, PPLBuiltinOperators.CAST_IP, List.of(exp));
139+
case EXPR_IP -> {
140+
if (argExprType == ExprCoreType.IP) {
141+
yield exp;
142+
} else if (argExprType == ExprCoreType.STRING) {
143+
yield makeCall(type, PPLBuiltinOperators.CAST_IP, List.of(exp));
144+
}
145+
// Throwing error inside implementation will be suppressed by Calcite, thus
146+
// throwing 500 error. Therefore, we throw error here to ensure the error
147+
// information is displayed properly.
148+
throw new ExpressionEvaluationException(
149+
String.format(
150+
Locale.ROOT,
151+
"Cannot convert %s to IP, only STRING and IP types are supported",
152+
argExprType));
153+
}
137154
default -> throw new SemanticCheckException(
138-
String.format(Locale.ROOT, "Unsupported cast type: %s", udt.name()));
155+
String.format(Locale.ROOT, "Cannot cast from %s to %s", argExprType, udt.name()));
139156
};
140157
}
141158
return super.makeCast(pos, type, exp, matchNullability, safe, format);

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -528,18 +528,9 @@ public void registerOperator(BuiltinFunctionName functionName, SqlOperator... op
528528
// Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
529529
// SameOperandTypeCheckers like COALESCE, IFNULL, etc.
530530
register(functionName, wrapWithComparableTypeChecker(operator, comparableTypeChecker));
531-
} else if (typeChecker instanceof UDFOperandMetadata.IPOperandMetadata) {
532-
register(
533-
functionName,
534-
createFunctionImpWithTypeChecker(
535-
(builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2),
536-
new PPLTypeChecker.PPLIPCompareTypeChecker()));
537-
} else if (typeChecker instanceof UDFOperandMetadata.CidrOperandMetadata) {
538-
register(
539-
functionName,
540-
createFunctionImpWithTypeChecker(
541-
(builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2),
542-
new PPLTypeChecker.PPLCidrTypeChecker()));
531+
} else if (typeChecker
532+
instanceof UDFOperandMetadata.UDTOperandMetadata udtOperandMetadata) {
533+
register(functionName, wrapWithUdtTypeChecker(operator, udtOperandMetadata));
543534
} else {
544535
logger.info(
545536
"Cannot create type checker for function: {}. Will skip its type checking",
@@ -558,6 +549,13 @@ private static SqlOperandTypeChecker extractTypeCheckerFromUDF(
558549
return (udfOperandMetadata == null) ? null : udfOperandMetadata.getInnerTypeChecker();
559550
}
560551

552+
// Such wrapWith*TypeChecker methods are useful in that we don't have to create explicit
553+
// overrides of resolve function for different number of operands.
554+
// I.e. we don't have to explicitly call
555+
// (FuncImp1) (builder, arg1) -> builder.makeCall(operator, arg1);
556+
// (FuncImp2) (builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2);
557+
// etc.
558+
561559
/**
562560
* Wrap a SqlOperator into a FunctionImp with a composite type checker.
563561
*
@@ -624,6 +622,21 @@ public PPLTypeChecker getTypeChecker() {
624622
};
625623
}
626624

625+
private static FunctionImp wrapWithUdtTypeChecker(
626+
SqlOperator operator, UDFOperandMetadata.UDTOperandMetadata udtOperandMetadata) {
627+
return new FunctionImp() {
628+
@Override
629+
public RexNode resolve(RexBuilder builder, RexNode... args) {
630+
return builder.makeCall(operator, args);
631+
}
632+
633+
@Override
634+
public PPLTypeChecker getTypeChecker() {
635+
return PPLTypeChecker.wrapUDT(udtOperandMetadata.allowedParamTypes());
636+
}
637+
};
638+
}
639+
627640
private static FunctionImp createFunctionImpWithTypeChecker(
628641
BiFunction<RexBuilder, RexNode, RexNode> resolver, PPLTypeChecker typeChecker) {
629642
return new FunctionImp1() {

core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java

Lines changed: 40 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import org.apache.calcite.sql.type.SqlTypeFamily;
2424
import org.apache.calcite.sql.type.SqlTypeName;
2525
import org.apache.calcite.sql.type.SqlTypeUtil;
26-
import org.opensearch.sql.calcite.type.ExprIPType;
2726
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
2827
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
2928
import org.opensearch.sql.data.type.ExprCoreType;
@@ -257,53 +256,6 @@ public String getAllowedSignatures() {
257256
}
258257
}
259258

260-
class PPLIPCompareTypeChecker implements PPLTypeChecker {
261-
@Override
262-
public boolean checkOperandTypes(List<RelDataType> types) {
263-
if (types.size() != 2) {
264-
return false;
265-
}
266-
RelDataType type1 = types.get(0);
267-
RelDataType type2 = types.get(1);
268-
return areIpAndStringTypes(type1, type2)
269-
|| areIpAndStringTypes(type2, type1)
270-
|| (type1 instanceof ExprIPType && type2 instanceof ExprIPType);
271-
}
272-
273-
@Override
274-
public String getAllowedSignatures() {
275-
// Will be merged with the allowed signatures of comparable type checker,
276-
// shown as [COMPARABLE_TYPE,COMPARABLE_TYPE]
277-
return "";
278-
}
279-
280-
private static boolean areIpAndStringTypes(RelDataType typeIp, RelDataType typeString) {
281-
return typeIp instanceof ExprIPType && typeString.getFamily() == SqlTypeFamily.CHARACTER;
282-
}
283-
}
284-
285-
class PPLCidrTypeChecker implements PPLTypeChecker {
286-
@Override
287-
public boolean checkOperandTypes(List<RelDataType> types) {
288-
if (types.size() != 2) {
289-
return false;
290-
}
291-
RelDataType type1 = types.get(0);
292-
RelDataType type2 = types.get(1);
293-
294-
// accept (STRING, STRING) or (IP, STRING)
295-
if (type2.getFamily() != SqlTypeFamily.CHARACTER) {
296-
return false;
297-
}
298-
return type1 instanceof ExprIPType || type1.getFamily() == SqlTypeFamily.CHARACTER;
299-
}
300-
301-
@Override
302-
public String getAllowedSignatures() {
303-
return "[STRING,STRING],[IP,STRING]";
304-
}
305-
}
306-
307259
/**
308260
* Creates a {@link PPLFamilyTypeChecker} with a fixed operand count, validating that each operand
309261
* belongs to its corresponding {@link SqlTypeFamily}.
@@ -379,6 +331,42 @@ static PPLComparableTypeChecker wrapComparable(SameOperandTypeChecker typeChecke
379331
return new PPLComparableTypeChecker(typeChecker);
380332
}
381333

334+
/**
335+
* Create a {@link PPLTypeChecker} from a list of allowed signatures consisted of {@link
336+
* ExprType}. This is useful to validate arguments against user-defined types (UDT) that does not
337+
* match any Calcite {@link SqlTypeFamily}.
338+
*
339+
* @param allowedSignatures a list of allowed signatures, where each signature is a list of {@link
340+
* ExprType} representing the expected types of the function arguments.
341+
* @return a {@link PPLTypeChecker} that checks if the operand types match any of the allowed
342+
* signatures
343+
*/
344+
static PPLTypeChecker wrapUDT(List<List<ExprType>> allowedSignatures) {
345+
return new PPLTypeChecker() {
346+
@Override
347+
public boolean checkOperandTypes(List<RelDataType> types) {
348+
List<ExprType> argExprTypes =
349+
types.stream().map(OpenSearchTypeFactory::convertRelDataTypeToExprType).toList();
350+
for (var allowedSignature : allowedSignatures) {
351+
if (allowedSignature.size() != types.size()) {
352+
continue; // Skip signatures that do not match the operand count
353+
}
354+
// Check if the argument types match the allowed signature
355+
if (IntStream.range(0, allowedSignature.size())
356+
.allMatch(i -> allowedSignature.get(i).equals(argExprTypes.get(i)))) {
357+
return true;
358+
}
359+
}
360+
return false;
361+
}
362+
363+
@Override
364+
public String getAllowedSignatures() {
365+
return PPLTypeChecker.getExprFamilySignature(allowedSignatures);
366+
}
367+
};
368+
}
369+
382370
// Util Functions
383371
/**
384372
* Generates a list of allowed function signatures based on the provided {@link
@@ -464,6 +452,10 @@ private static String getFamilySignature(List<SqlTypeFamily> families) {
464452
List<List<ExprType>> signatures = Lists.cartesianProduct(exprTypes);
465453

466454
// Convert each signature to a string representation and then concatenate them
455+
return getExprFamilySignature(signatures);
456+
}
457+
458+
private static String getExprFamilySignature(List<List<ExprType>> signatures) {
467459
return signatures.stream()
468460
.map(
469461
types ->

core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java

Lines changed: 4 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.apache.calcite.sql.type.SqlOperandMetadata;
1818
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
1919
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
20+
import org.opensearch.sql.data.type.ExprType;
2021

2122
/**
2223
* This class is created for the compatibility with {@link SqlUserDefinedFunction} constructors when
@@ -105,83 +106,11 @@ public String getAllowedSignatures(SqlOperator op, String opName) {
105106
};
106107
}
107108

108-
/**
109-
* A named class that serves as an identifier for IP comparator's operand metadata. It does not
110-
* implement any actual type checking logic.
111-
*/
112-
class IPOperandMetadata implements UDFOperandMetadata {
113-
@Override
114-
public SqlOperandTypeChecker getInnerTypeChecker() {
115-
return this;
116-
}
117-
118-
@Override
119-
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
120-
return List.of();
121-
}
122-
123-
@Override
124-
public List<String> paramNames() {
125-
return List.of();
126-
}
127-
128-
@Override
129-
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
130-
return false;
131-
}
132-
133-
@Override
134-
public SqlOperandCountRange getOperandCountRange() {
135-
return null;
136-
}
137-
138-
@Override
139-
public String getAllowedSignatures(SqlOperator op, String opName) {
140-
return "";
141-
}
142-
}
143-
144-
/**
145-
* A named class that serves as an identifier for cidr's operand metadata. It does not implement
146-
* any actual type checking logic.
147-
*/
148-
class CidrOperandMetadata implements UDFOperandMetadata {
149-
@Override
150-
public SqlOperandTypeChecker getInnerTypeChecker() {
151-
return this;
152-
}
153-
154-
@Override
155-
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
156-
return List.of();
157-
}
158-
159-
@Override
160-
public List<String> paramNames() {
161-
return List.of();
162-
}
163-
164-
@Override
165-
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
166-
return false;
167-
}
168-
169-
@Override
170-
public SqlOperandCountRange getOperandCountRange() {
171-
return null;
172-
}
173-
174-
@Override
175-
public String getAllowedSignatures(SqlOperator op, String opName) {
176-
return "";
177-
}
109+
static UDFOperandMetadata wrapUDT(List<List<ExprType>> allowSignatures) {
110+
return new UDTOperandMetadata(allowSignatures);
178111
}
179112

180-
/**
181-
* A named class that serves as an identifier for IP cast's operand metadata. It does not
182-
* implement any actual type checking logic.
183-
*/
184-
class IPCastOperandMetadata implements UDFOperandMetadata {
113+
record UDTOperandMetadata(List<List<ExprType>> allowedParamTypes) implements UDFOperandMetadata {
185114
@Override
186115
public SqlOperandTypeChecker getInnerTypeChecker() {
187116
return this;

core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CidrMatchFunction.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.opensearch.sql.data.model.ExprIpValue;
1818
import org.opensearch.sql.data.model.ExprValue;
1919
import org.opensearch.sql.data.model.ExprValueUtils;
20+
import org.opensearch.sql.data.type.ExprCoreType;
2021
import org.opensearch.sql.expression.function.ImplementorUDF;
2122
import org.opensearch.sql.expression.function.UDFOperandMetadata;
2223
import org.opensearch.sql.expression.ip.IPFunctions;
@@ -46,7 +47,10 @@ public UDFOperandMetadata getOperandMetadata() {
4647
// EXPR_IP is mapped to SqlTypeFamily.OTHER in
4748
// UserDefinedFunctionUtils.convertRelDataTypeToSqlTypeName
4849
// We use a specific type checker to serve
49-
return new UDFOperandMetadata.CidrOperandMetadata();
50+
return UDFOperandMetadata.wrapUDT(
51+
List.of(
52+
List.of(ExprCoreType.IP, ExprCoreType.STRING),
53+
List.of(ExprCoreType.STRING, ExprCoreType.STRING)));
5054
}
5155

5256
public static class CidrMatchImplementor implements NotNullImplementor {

core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.apache.calcite.sql.type.ReturnTypes;
1616
import org.apache.calcite.sql.type.SqlReturnTypeInference;
1717
import org.opensearch.sql.data.model.ExprIpValue;
18+
import org.opensearch.sql.data.type.ExprCoreType;
1819
import org.opensearch.sql.expression.function.ImplementorUDF;
1920
import org.opensearch.sql.expression.function.UDFOperandMetadata;
2021

@@ -66,7 +67,11 @@ public SqlReturnTypeInference getReturnTypeInference() {
6667

6768
@Override
6869
public UDFOperandMetadata getOperandMetadata() {
69-
return new UDFOperandMetadata.IPOperandMetadata();
70+
return UDFOperandMetadata.wrapUDT(
71+
List.of(
72+
List.of(ExprCoreType.IP, ExprCoreType.IP),
73+
List.of(ExprCoreType.IP, ExprCoreType.STRING),
74+
List.of(ExprCoreType.STRING, ExprCoreType.IP)));
7075
}
7176

7277
public static class CompareImplementor implements NotNullImplementor {

0 commit comments

Comments
 (0)