Skip to content

Commit 78ba6fb

Browse files
committed
Create a new PPLIPCompareTypeChecker for comparing IP
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 73bec7b commit 78ba6fb

6 files changed

Lines changed: 78 additions & 50 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,9 @@ public static SqlTypeName convertRelDataTypeToSqlTypeName(RelDataType type) {
115115
case EXPR_DATE -> SqlTypeName.DATE;
116116
case EXPR_TIME -> SqlTypeName.TIME;
117117
case EXPR_TIMESTAMP -> SqlTypeName.TIMESTAMP;
118-
// EXPR_IP is mapped to SqlTypeName.NULL since there is no
119-
// corresponding SqlTypeName in Calcite. This is a workaround to allow
120-
// type checking for IP types in UDFs.
121-
case EXPR_IP -> SqlTypeName.NULL;
118+
// EXPR_IP is mapped to SqlTypeName.OTHER since there is no
119+
// corresponding SqlTypeName in Calcite.
120+
case EXPR_IP -> SqlTypeName.OTHER;
122121
case EXPR_BINARY -> SqlTypeName.VARBINARY;
123122
default -> type.getSqlTypeName();
124123
};

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,10 @@ functionName, getActualSignature(argTypes), e.getMessage()),
450450
}
451451
StringJoiner allowedSignatures = new StringJoiner(",");
452452
for (var implement : implementList) {
453-
allowedSignatures.add(implement.getKey().typeChecker().getAllowedSignatures());
453+
String signature = implement.getKey().typeChecker().getAllowedSignatures();
454+
if (!signature.isEmpty()) {
455+
allowedSignatures.add(signature);
456+
}
454457
}
455458
throw new ExpressionEvaluationException(
456459
String.format(
@@ -500,6 +503,12 @@ void registerOperator(BuiltinFunctionName functionName, SqlOperator operator) {
500503
// Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
501504
// SameOperandTypeCheckers like COALESCE, IFNULL, etc.
502505
register(functionName, wrapWithComparableTypeChecker(operator, comparableTypeChecker));
506+
} else if (typeChecker instanceof UDFOperandMetadata.IPOperandMetadata) {
507+
register(
508+
functionName,
509+
createFunctionImpWithTypeChecker(
510+
(builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2),
511+
new PPLTypeChecker.PPLIPCompareTypeChecker()));
503512
} else {
504513
logger.info(
505514
"Cannot create type checker for function: {}. Will skip its type checking",

core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import org.apache.calcite.sql.type.SqlTypeFamily;
2424
import org.apache.calcite.sql.type.SqlTypeName;
2525
import org.apache.calcite.sql.type.SqlTypeUtil;
26-
import org.opensearch.sql.calcite.type.AbstractExprRelDataType;
26+
import org.opensearch.sql.calcite.type.ExprIPType;
2727
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
2828
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
2929
import org.opensearch.sql.data.type.ExprCoreType;
@@ -215,10 +215,6 @@ public boolean checkOperandTypes(List<RelDataType> types) {
215215
RelDataType type_l = types.get(i);
216216
RelDataType type_r = types.get(i + 1);
217217
if (!SqlTypeUtil.isComparable(type_l, type_r)) {
218-
if (areIpAndStringTypes(type_l, type_r) || areIpAndStringTypes(type_r, type_l)) {
219-
// Allow IP and string comparison
220-
continue;
221-
}
222218
return false;
223219
}
224220
// Disallow coercing between strings and numeric, boolean
@@ -239,14 +235,6 @@ private static boolean cannotConvertStringInCompare(SqlTypeFamily typeFamily) {
239235
};
240236
}
241237

242-
private static boolean areIpAndStringTypes(RelDataType typeIp, RelDataType typeString) {
243-
if (typeIp instanceof AbstractExprRelDataType<?> exprRelDataType) {
244-
return exprRelDataType.getExprType() == ExprCoreType.IP
245-
&& typeString.getFamily() == SqlTypeFamily.CHARACTER;
246-
}
247-
return false;
248-
}
249-
250238
@Override
251239
public String getAllowedSignatures() {
252240
int min = innerTypeChecker.getOperandCountRange().getMin();
@@ -269,6 +257,31 @@ public String getAllowedSignatures() {
269257
}
270258
}
271259

260+
class PPLIPCompareTypeChecker implements PPLTypeChecker {
261+
@Override
262+
public boolean checkOperandTypes(List<RelDataType> types) {
263+
if (types.size() != 2) {
264+
return false;
265+
}
266+
RelDataType type1 = types.get(0);
267+
RelDataType type2 = types.get(1);
268+
return areIpAndStringTypes(type1, type2)
269+
|| areIpAndStringTypes(type2, type1)
270+
|| (type1 instanceof ExprIPType && type2 instanceof ExprIPType);
271+
}
272+
273+
@Override
274+
public String getAllowedSignatures() {
275+
// Will be merged with the allowed signatures of comparable type checker,
276+
// shown as [COMPARABLE_TYPE,COMPARABLE_TYPE]
277+
return "";
278+
}
279+
280+
private static boolean areIpAndStringTypes(RelDataType typeIp, RelDataType typeString) {
281+
return typeIp instanceof ExprIPType && typeString.getFamily() == SqlTypeFamily.CHARACTER;
282+
}
283+
}
284+
272285
/**
273286
* Creates a {@link PPLFamilyTypeChecker} with a fixed operand count, validating that each operand
274287
* belongs to its corresponding {@link SqlTypeFamily}.

core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,16 @@
1414
import org.apache.calcite.sql.SqlOperator;
1515
import org.apache.calcite.sql.type.CompositeOperandTypeChecker;
1616
import org.apache.calcite.sql.type.FamilyOperandTypeChecker;
17-
import org.apache.calcite.sql.type.ImplicitCastOperandTypeChecker;
1817
import org.apache.calcite.sql.type.SqlOperandMetadata;
1918
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
20-
import org.apache.calcite.sql.type.SqlTypeFamily;
2119
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
2220

2321
/**
2422
* This class is created for the compatibility with {@link SqlUserDefinedFunction} constructors when
2523
* creating UDFs, so that a type checker can be passed to the constructor of {@link
2624
* SqlUserDefinedFunction} as a {@link SqlOperandMetadata}.
2725
*/
28-
public interface UDFOperandMetadata extends SqlOperandMetadata, ImplicitCastOperandTypeChecker {
26+
public interface UDFOperandMetadata extends SqlOperandMetadata {
2927
SqlOperandTypeChecker getInnerTypeChecker();
3028

3129
static UDFOperandMetadata wrap(FamilyOperandTypeChecker typeChecker) {
@@ -35,17 +33,6 @@ public SqlOperandTypeChecker getInnerTypeChecker() {
3533
return typeChecker;
3634
}
3735

38-
@Override
39-
public boolean checkOperandTypesWithoutTypeCoercion(
40-
SqlCallBinding callBinding, boolean throwOnFailure) {
41-
return typeChecker.checkOperandTypesWithoutTypeCoercion(callBinding, throwOnFailure);
42-
}
43-
44-
@Override
45-
public SqlTypeFamily getOperandSqlTypeFamily(int iFormalOperand) {
46-
return typeChecker.getOperandSqlTypeFamily(iFormalOperand);
47-
}
48-
4936
@Override
5037
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
5138
// This function is not used in the current context, so we return an empty list.
@@ -89,18 +76,6 @@ public SqlOperandTypeChecker getInnerTypeChecker() {
8976
return typeChecker;
9077
}
9178

92-
@Override
93-
public boolean checkOperandTypesWithoutTypeCoercion(
94-
SqlCallBinding callBinding, boolean throwOnFailure) {
95-
return typeChecker.checkOperandTypes(callBinding, throwOnFailure);
96-
}
97-
98-
@Override
99-
public SqlTypeFamily getOperandSqlTypeFamily(int iFormalOperand) {
100-
throw new IllegalStateException(
101-
"getOperandSqlTypeFamily is not supported for CompositeOperandTypeChecker");
102-
}
103-
10479
@Override
10580
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
10681
// This function is not used in the current context, so we return an empty list.
@@ -129,4 +104,40 @@ public String getAllowedSignatures(SqlOperator op, String opName) {
129104
}
130105
};
131106
}
107+
108+
/**
109+
* A named class that serves as an identifier for IP comparator's operand metadata. It does not
110+
* implement any actual type checking logic.
111+
*/
112+
class IPOperandMetadata implements UDFOperandMetadata {
113+
@Override
114+
public SqlOperandTypeChecker getInnerTypeChecker() {
115+
return this;
116+
}
117+
118+
@Override
119+
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
120+
return List.of();
121+
}
122+
123+
@Override
124+
public List<String> paramNames() {
125+
return List.of();
126+
}
127+
128+
@Override
129+
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
130+
return false;
131+
}
132+
133+
@Override
134+
public SqlOperandCountRange getOperandCountRange() {
135+
return null;
136+
}
137+
138+
@Override
139+
public String getAllowedSignatures(SqlOperator op, String opName) {
140+
return "";
141+
}
142+
}
132143
}

core/src/main/java/org/opensearch/sql/expression/function/udf/ip/LessIpFunction.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,7 @@ public SqlReturnTypeInference getReturnTypeInference() {
4444

4545
@Override
4646
public UDFOperandMetadata getOperandMetadata() {
47-
return UDFOperandMetadata.wrap(
48-
(CompositeOperandTypeChecker)
49-
OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.NULL)
50-
.or(OperandTypes.family(SqlTypeFamily.NULL, SqlTypeFamily.STRING))
51-
.or(OperandTypes.family(SqlTypeFamily.NULL, SqlTypeFamily.NULL)));
47+
return new UDFOperandMetadata.IPOperandMetadata();
5248
}
5349

5450
public static class LessImplementor implements NotNullImplementor {

ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public void testComparisonWithDifferentType() {
5050
Throwable t = Assert.assertThrows(ExpressionEvaluationException.class, () -> getRelNode(ppl));
5151
verifyErrorMessageContains(
5252
t,
53-
"LESS function expects {[STRING,IP],[IP,STRING],[IP,IP],[COMPARABLE_TYPE,COMPARABLE_TYPE]},"
53+
"LESS function expects {[COMPARABLE_TYPE,COMPARABLE_TYPE]},"
5454
+ " but got [STRING,INTEGER]");
5555
}
5656

0 commit comments

Comments
 (0)