Skip to content

Commit 6c3efa1

Browse files
ishaoxyyuancu
andauthored
Add compare_ip operator udfs (opensearch-project#3821)
* ip_compare operator added Signed-off-by: Xinyu Hao <haoxinyu@amazon.com> * only type checker issue left Signed-off-by: Xinyu Hao <haoxinyu@amazon.com> * fix by modifying ip.sqlTypeName from OTHER to NULL in type checker Signed-off-by: Xinyu Hao <haoxinyu@amazon.com> * fix less Signed-off-by: Xinyu Hao <haoxinyu@amazon.com> * modify the CalcitePPLFunctionTypeTest text Signed-off-by: Xinyu Hao <haoxinyu@amazon.com> * allow CalciteIPComparisonIT in CalciteNoPushdownIT Signed-off-by: Xinyu Hao <haoxinyu@amazon.com> * Modify the signature description in udf Signed-off-by: Xinyu Hao <haoxinyu@amazon.com> * fix some typing errors Signed-off-by: Xinyu Hao <haoxinyu@amazon.com> * modify the udfs for better style Signed-off-by: Xinyu Hao <haoxinyu@amazon.com> * Make IpComparisonOperators an inner enum of CompareIPFunction Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * modify registerOperator Signed-off-by: Xinyu Hao <haoxinyu@amazon.com> * modify registerOperator Signed-off-by: Xinyu Hao <haoxinyu@amazon.com> * add type checker for cidr Signed-off-by: Xinyu Hao <haoxinyu@amazon.com> * add javadoc Signed-off-by: Xinyu Hao <haoxinyu@amazon.com> * move switch case to the implement method Signed-off-by: Xinyu Hao <haoxinyu@amazon.com> --------- Signed-off-by: Xinyu Hao <haoxinyu@amazon.com> Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> Co-authored-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 9af1567 commit 6c3efa1

9 files changed

Lines changed: 344 additions & 87 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ public static SqlTypeName convertRelDataTypeToSqlTypeName(RelDataType type) {
121121
case EXPR_DATE -> SqlTypeName.DATE;
122122
case EXPR_TIME -> SqlTypeName.TIME;
123123
case EXPR_TIMESTAMP -> SqlTypeName.TIMESTAMP;
124-
case EXPR_IP -> SqlTypeName.VARCHAR;
124+
// EXPR_IP is mapped to SqlTypeName.OTHER since there is no
125+
// corresponding SqlTypeName in Calcite.
126+
case EXPR_IP -> SqlTypeName.OTHER;
125127
case EXPR_BINARY -> SqlTypeName.VARBINARY;
126128
default -> type.getSqlTypeName();
127129
};

core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
import org.opensearch.sql.expression.function.udf.datetime.WeekdayFunction;
7373
import org.opensearch.sql.expression.function.udf.datetime.YearweekFunction;
7474
import org.opensearch.sql.expression.function.udf.ip.CidrMatchFunction;
75+
import org.opensearch.sql.expression.function.udf.ip.CompareIpFunction;
7576
import org.opensearch.sql.expression.function.udf.math.CRC32Function;
7677
import org.opensearch.sql.expression.function.udf.math.ConvFunction;
7778
import org.opensearch.sql.expression.function.udf.math.DivideFunction;
@@ -103,6 +104,15 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
103104
public static final SqlOperator SHA2 = CryptographicFunction.sha2().toUDF("SHA2");
104105
public static final SqlOperator CIDRMATCH = new CidrMatchFunction().toUDF("CIDRMATCH");
105106

107+
// IP comparing functions
108+
public static final SqlOperator NOT_EQUALS_IP =
109+
CompareIpFunction.notEquals().toUDF("NOT_EQUALS_IP");
110+
public static final SqlOperator EQUALS_IP = CompareIpFunction.equals().toUDF("EQUALS_IP");
111+
public static final SqlOperator GREATER_IP = CompareIpFunction.greater().toUDF("GREATER_IP");
112+
public static final SqlOperator GTE_IP = CompareIpFunction.greaterOrEquals().toUDF("GTE_IP");
113+
public static final SqlOperator LESS_IP = CompareIpFunction.less().toUDF("LESS_IP");
114+
public static final SqlOperator LTE_IP = CompareIpFunction.lessOrEquals().toUDF("LTE_IP");
115+
106116
// Condition function
107117
public static final SqlOperator EARLIEST = new EarliestFunction().toUDF("EARLIEST");
108118
public static final SqlOperator LATEST = new LatestFunction().toUDF("LATEST");

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 75 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,10 @@ functionName, getActualSignature(argTypes), e.getMessage()),
458458
}
459459
StringJoiner allowedSignatures = new StringJoiner(",");
460460
for (var implement : implementList) {
461-
allowedSignatures.add(implement.getKey().typeChecker().getAllowedSignatures());
461+
String signature = implement.getKey().typeChecker().getAllowedSignatures();
462+
if (!signature.isEmpty()) {
463+
allowedSignatures.add(signature);
464+
}
462465
}
463466
throw new ExpressionEvaluationException(
464467
String.format(
@@ -481,40 +484,70 @@ private abstract static class AbstractBuilder {
481484
/** Maps an operator to an implementation. */
482485
abstract void register(BuiltinFunctionName functionName, FunctionImp functionImp);
483486

484-
void registerOperator(BuiltinFunctionName functionName, SqlOperator operator) {
485-
SqlOperandTypeChecker typeChecker;
486-
if (operator instanceof SqlUserDefinedFunction udfOperator) {
487-
typeChecker = extractTypeCheckerFromUDF(udfOperator);
488-
} else {
489-
typeChecker = operator.getOperandTypeChecker();
490-
}
487+
/**
488+
* Register one or multiple operators under a single function name. This allows function
489+
* overloading based on operand types.
490+
*
491+
* <p>When a function is called, the system will try each registered operator in sequence,
492+
* checking if the provided arguments match the operator's type requirements. The first operator
493+
* whose type checker accepts the arguments will be used to execute the function.
494+
*
495+
* @param functionName the built-in function name under which to register the operators
496+
* @param operators the operators to associate with this function name, tried in sequence until
497+
* one matches the argument types during resolution
498+
*/
499+
public void registerOperator(BuiltinFunctionName functionName, SqlOperator... operators) {
500+
for (SqlOperator operator : operators) {
501+
SqlOperandTypeChecker typeChecker;
502+
if (operator instanceof SqlUserDefinedFunction udfOperator) {
503+
typeChecker = extractTypeCheckerFromUDF(udfOperator);
504+
} else {
505+
typeChecker = operator.getOperandTypeChecker();
506+
}
491507

492-
// Only the composite operand type checker for UDFs are concerned here.
493-
if (operator instanceof SqlUserDefinedFunction
494-
&& typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
495-
// UDFs implement their own composite type checkers, which always use OR logic for argument
496-
// types. Verifying the composition type would require accessing a protected field in
497-
// CompositeOperandTypeChecker. If access to this field is not allowed, type checking will
498-
// be skipped, so we avoid checking the composition type here.
499-
register(functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, false));
500-
} else if (typeChecker instanceof ImplicitCastOperandTypeChecker implicitCastTypeChecker) {
501-
register(functionName, wrapWithImplicitCastTypeChecker(operator, implicitCastTypeChecker));
502-
} else if (typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
503-
// If compositeTypeChecker contains operand checkers other than family type checkers or
504-
// other than OR compositions, the function with be registered with a null type checker,
505-
// which means the function will not be type checked.
506-
register(functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, true));
507-
} else if (typeChecker instanceof SameOperandTypeChecker comparableTypeChecker) {
508-
// Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
509-
// SameOperandTypeCheckers like COALESCE, IFNULL, etc.
510-
register(functionName, wrapWithComparableTypeChecker(operator, comparableTypeChecker));
511-
} else {
512-
logger.info(
513-
"Cannot create type checker for function: {}. Will skip its type checking",
514-
functionName);
515-
register(
516-
functionName,
517-
(RexBuilder builder, RexNode... node) -> builder.makeCall(operator, node));
508+
// Only the composite operand type checker for UDFs are concerned here.
509+
if (operator instanceof SqlUserDefinedFunction
510+
&& typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
511+
// UDFs implement their own composite type checkers, which always use OR logic for
512+
// argument
513+
// types. Verifying the composition type would require accessing a protected field in
514+
// CompositeOperandTypeChecker. If access to this field is not allowed, type checking will
515+
// be skipped, so we avoid checking the composition type here.
516+
register(
517+
functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, false));
518+
} else if (typeChecker instanceof ImplicitCastOperandTypeChecker implicitCastTypeChecker) {
519+
register(
520+
functionName, wrapWithImplicitCastTypeChecker(operator, implicitCastTypeChecker));
521+
} else if (typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
522+
// If compositeTypeChecker contains operand checkers other than family type checkers or
523+
// other than OR compositions, the function with be registered with a null type checker,
524+
// which means the function will not be type checked.
525+
register(
526+
functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, true));
527+
} else if (typeChecker instanceof SameOperandTypeChecker comparableTypeChecker) {
528+
// Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
529+
// SameOperandTypeCheckers like COALESCE, IFNULL, etc.
530+
register(functionName, wrapWithComparableTypeChecker(operator, comparableTypeChecker));
531+
} else if (typeChecker instanceof UDFOperandMetadata.IPOperandMetadata) {
532+
register(
533+
functionName,
534+
createFunctionImpWithTypeChecker(
535+
(builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2),
536+
new PPLTypeChecker.PPLIPCompareTypeChecker()));
537+
} else if (typeChecker instanceof UDFOperandMetadata.CidrOperandMetadata) {
538+
register(
539+
functionName,
540+
createFunctionImpWithTypeChecker(
541+
(builder, arg1, arg2) -> builder.makeCall(operator, arg1, arg2),
542+
new PPLTypeChecker.PPLCidrTypeChecker()));
543+
} else {
544+
logger.info(
545+
"Cannot create type checker for function: {}. Will skip its type checking",
546+
functionName);
547+
register(
548+
functionName,
549+
(RexBuilder builder, RexNode... node) -> builder.makeCall(operator, node));
550+
}
518551
}
519552
}
520553

@@ -622,16 +655,18 @@ public PPLTypeChecker getTypeChecker() {
622655
}
623656

624657
void populate() {
658+
// register operators for comparison
659+
registerOperator(NOTEQUAL, PPLBuiltinOperators.NOT_EQUALS_IP, SqlStdOperatorTable.NOT_EQUALS);
660+
registerOperator(EQUAL, PPLBuiltinOperators.EQUALS_IP, SqlStdOperatorTable.EQUALS);
661+
registerOperator(GREATER, PPLBuiltinOperators.GREATER_IP, SqlStdOperatorTable.GREATER_THAN);
662+
registerOperator(GTE, PPLBuiltinOperators.GTE_IP, SqlStdOperatorTable.GREATER_THAN_OR_EQUAL);
663+
registerOperator(LESS, PPLBuiltinOperators.LESS_IP, SqlStdOperatorTable.LESS_THAN);
664+
registerOperator(LTE, PPLBuiltinOperators.LTE_IP, SqlStdOperatorTable.LESS_THAN_OR_EQUAL);
665+
625666
// Register std operator
626667
registerOperator(AND, SqlStdOperatorTable.AND);
627668
registerOperator(OR, SqlStdOperatorTable.OR);
628669
registerOperator(NOT, SqlStdOperatorTable.NOT);
629-
registerOperator(NOTEQUAL, SqlStdOperatorTable.NOT_EQUALS);
630-
registerOperator(EQUAL, SqlStdOperatorTable.EQUALS);
631-
registerOperator(GREATER, SqlStdOperatorTable.GREATER_THAN);
632-
registerOperator(GTE, SqlStdOperatorTable.GREATER_THAN_OR_EQUAL);
633-
registerOperator(LESS, SqlStdOperatorTable.LESS_THAN);
634-
registerOperator(LTE, SqlStdOperatorTable.LESS_THAN_OR_EQUAL);
635670
registerOperator(ADD, SqlStdOperatorTable.PLUS);
636671
registerOperator(SUBTRACT, SqlStdOperatorTable.MINUS);
637672
registerOperator(MULTIPLY, SqlStdOperatorTable.MULTIPLY);

core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import org.apache.calcite.sql.type.SqlTypeFamily;
2424
import org.apache.calcite.sql.type.SqlTypeName;
2525
import org.apache.calcite.sql.type.SqlTypeUtil;
26-
import org.opensearch.sql.calcite.type.AbstractExprRelDataType;
26+
import org.opensearch.sql.calcite.type.ExprIPType;
2727
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
2828
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
2929
import org.opensearch.sql.data.type.ExprCoreType;
@@ -215,10 +215,6 @@ public boolean checkOperandTypes(List<RelDataType> types) {
215215
RelDataType type_l = types.get(i);
216216
RelDataType type_r = types.get(i + 1);
217217
if (!SqlTypeUtil.isComparable(type_l, type_r)) {
218-
if (areIpAndStringTypes(type_l, type_r) || areIpAndStringTypes(type_r, type_l)) {
219-
// Allow IP and string comparison
220-
continue;
221-
}
222218
return false;
223219
}
224220
// Disallow coercing between strings and numeric, boolean
@@ -239,14 +235,6 @@ private static boolean cannotConvertStringInCompare(SqlTypeFamily typeFamily) {
239235
};
240236
}
241237

242-
private static boolean areIpAndStringTypes(RelDataType typeIp, RelDataType typeString) {
243-
if (typeIp instanceof AbstractExprRelDataType<?> exprRelDataType) {
244-
return exprRelDataType.getExprType() == ExprCoreType.IP
245-
&& typeString.getFamily() == SqlTypeFamily.CHARACTER;
246-
}
247-
return false;
248-
}
249-
250238
@Override
251239
public String getAllowedSignatures() {
252240
int min = innerTypeChecker.getOperandCountRange().getMin();
@@ -269,6 +257,53 @@ public String getAllowedSignatures() {
269257
}
270258
}
271259

260+
class PPLIPCompareTypeChecker implements PPLTypeChecker {
261+
@Override
262+
public boolean checkOperandTypes(List<RelDataType> types) {
263+
if (types.size() != 2) {
264+
return false;
265+
}
266+
RelDataType type1 = types.get(0);
267+
RelDataType type2 = types.get(1);
268+
return areIpAndStringTypes(type1, type2)
269+
|| areIpAndStringTypes(type2, type1)
270+
|| (type1 instanceof ExprIPType && type2 instanceof ExprIPType);
271+
}
272+
273+
@Override
274+
public String getAllowedSignatures() {
275+
// Will be merged with the allowed signatures of comparable type checker,
276+
// shown as [COMPARABLE_TYPE,COMPARABLE_TYPE]
277+
return "";
278+
}
279+
280+
private static boolean areIpAndStringTypes(RelDataType typeIp, RelDataType typeString) {
281+
return typeIp instanceof ExprIPType && typeString.getFamily() == SqlTypeFamily.CHARACTER;
282+
}
283+
}
284+
285+
class PPLCidrTypeChecker implements PPLTypeChecker {
286+
@Override
287+
public boolean checkOperandTypes(List<RelDataType> types) {
288+
if (types.size() != 2) {
289+
return false;
290+
}
291+
RelDataType type1 = types.get(0);
292+
RelDataType type2 = types.get(1);
293+
294+
// accept (STRING, STRING) or (IP, STRING)
295+
if (type2.getFamily() != SqlTypeFamily.CHARACTER) {
296+
return false;
297+
}
298+
return type1 instanceof ExprIPType || type1.getFamily() == SqlTypeFamily.CHARACTER;
299+
}
300+
301+
@Override
302+
public String getAllowedSignatures() {
303+
return "[STRING,STRING],[IP,STRING]";
304+
}
305+
}
306+
272307
/**
273308
* Creates a {@link PPLFamilyTypeChecker} with a fixed operand count, validating that each operand
274309
* belongs to its corresponding {@link SqlTypeFamily}.

core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,16 @@
1414
import org.apache.calcite.sql.SqlOperator;
1515
import org.apache.calcite.sql.type.CompositeOperandTypeChecker;
1616
import org.apache.calcite.sql.type.FamilyOperandTypeChecker;
17-
import org.apache.calcite.sql.type.ImplicitCastOperandTypeChecker;
1817
import org.apache.calcite.sql.type.SqlOperandMetadata;
1918
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
20-
import org.apache.calcite.sql.type.SqlTypeFamily;
2119
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
2220

2321
/**
2422
* This class is created for the compatibility with {@link SqlUserDefinedFunction} constructors when
2523
* creating UDFs, so that a type checker can be passed to the constructor of {@link
2624
* SqlUserDefinedFunction} as a {@link SqlOperandMetadata}.
2725
*/
28-
public interface UDFOperandMetadata extends SqlOperandMetadata, ImplicitCastOperandTypeChecker {
26+
public interface UDFOperandMetadata extends SqlOperandMetadata {
2927
SqlOperandTypeChecker getInnerTypeChecker();
3028

3129
static UDFOperandMetadata wrap(FamilyOperandTypeChecker typeChecker) {
@@ -35,17 +33,6 @@ public SqlOperandTypeChecker getInnerTypeChecker() {
3533
return typeChecker;
3634
}
3735

38-
@Override
39-
public boolean checkOperandTypesWithoutTypeCoercion(
40-
SqlCallBinding callBinding, boolean throwOnFailure) {
41-
return typeChecker.checkOperandTypesWithoutTypeCoercion(callBinding, throwOnFailure);
42-
}
43-
44-
@Override
45-
public SqlTypeFamily getOperandSqlTypeFamily(int iFormalOperand) {
46-
return typeChecker.getOperandSqlTypeFamily(iFormalOperand);
47-
}
48-
4936
@Override
5037
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
5138
// This function is not used in the current context, so we return an empty list.
@@ -89,18 +76,6 @@ public SqlOperandTypeChecker getInnerTypeChecker() {
8976
return typeChecker;
9077
}
9178

92-
@Override
93-
public boolean checkOperandTypesWithoutTypeCoercion(
94-
SqlCallBinding callBinding, boolean throwOnFailure) {
95-
return typeChecker.checkOperandTypes(callBinding, throwOnFailure);
96-
}
97-
98-
@Override
99-
public SqlTypeFamily getOperandSqlTypeFamily(int iFormalOperand) {
100-
throw new IllegalStateException(
101-
"getOperandSqlTypeFamily is not supported for CompositeOperandTypeChecker");
102-
}
103-
10479
@Override
10580
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
10681
// This function is not used in the current context, so we return an empty list.
@@ -129,4 +104,76 @@ public String getAllowedSignatures(SqlOperator op, String opName) {
129104
}
130105
};
131106
}
107+
108+
/**
109+
* A named class that serves as an identifier for IP comparator's operand metadata. It does not
110+
* implement any actual type checking logic.
111+
*/
112+
class IPOperandMetadata implements UDFOperandMetadata {
113+
@Override
114+
public SqlOperandTypeChecker getInnerTypeChecker() {
115+
return this;
116+
}
117+
118+
@Override
119+
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
120+
return List.of();
121+
}
122+
123+
@Override
124+
public List<String> paramNames() {
125+
return List.of();
126+
}
127+
128+
@Override
129+
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
130+
return false;
131+
}
132+
133+
@Override
134+
public SqlOperandCountRange getOperandCountRange() {
135+
return null;
136+
}
137+
138+
@Override
139+
public String getAllowedSignatures(SqlOperator op, String opName) {
140+
return "";
141+
}
142+
}
143+
144+
/**
145+
* A named class that serves as an identifier for cidr's operand metadata. It does not implement
146+
* any actual type checking logic.
147+
*/
148+
class CidrOperandMetadata implements UDFOperandMetadata {
149+
@Override
150+
public SqlOperandTypeChecker getInnerTypeChecker() {
151+
return this;
152+
}
153+
154+
@Override
155+
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
156+
return List.of();
157+
}
158+
159+
@Override
160+
public List<String> paramNames() {
161+
return List.of();
162+
}
163+
164+
@Override
165+
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
166+
return false;
167+
}
168+
169+
@Override
170+
public SqlOperandCountRange getOperandCountRange() {
171+
return null;
172+
}
173+
174+
@Override
175+
public String getAllowedSignatures(SqlOperator op, String opName) {
176+
return "";
177+
}
178+
}
132179
}

0 commit comments

Comments
 (0)