Skip to content

Commit 45f1566

Browse files
committed
Support parameter validation for if
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 281f755 commit 45f1566

2 files changed

Lines changed: 37 additions & 10 deletions

File tree

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,18 +193,18 @@ void registerOperator(BuiltinFunctionName functionName, SqlOperator operator) {
193193
// types. Verifying the composition type would require accessing a protected field in
194194
// CompositeOperandTypeChecker. If access to this field is not allowed, type checking will
195195
// be skipped, so we avoid checking the composition type here.
196-
register(functionName, createCompositeFunctionImp(operator, compositeTypeChecker, false));
196+
register(functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, false));
197197
} else if (typeChecker instanceof ImplicitCastOperandTypeChecker implicitCastTypeChecker) {
198-
register(functionName, createImplicitCastFunctionImp(operator, implicitCastTypeChecker));
198+
register(functionName, wrapWithImplicitCastTypeChecker(operator, implicitCastTypeChecker));
199199
} else if (typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
200200
// If compositeTypeChecker contains operand checkers other than family type checkers or
201201
// other than OR compositions, the function with be registered with a null type checker,
202202
// which means the function will not be type checked.
203-
register(functionName, createCompositeFunctionImp(operator, compositeTypeChecker, true));
203+
register(functionName, wrapWithCompositeTypeChecker(operator, compositeTypeChecker, true));
204204
} else if (typeChecker instanceof SameOperandTypeChecker comparableTypeChecker) {
205205
// Comparison operators like EQUAL, GREATER_THAN, LESS_THAN, etc.
206206
// SameOperandTypeCheckers like COALESCE, IFNULL, etc.
207-
register(functionName, createComparableFunctionImp(operator, comparableTypeChecker));
207+
register(functionName, wrapWithComparableTypeChecker(operator, comparableTypeChecker));
208208
} else {
209209
logger.info(
210210
"Cannot create type checker for function: {}. Will skip its type checking",
@@ -222,7 +222,7 @@ private static SqlOperandTypeChecker extractTypeCheckerFromUDF(
222222
return (udfOperandMetadata == null) ? null : udfOperandMetadata.getInnerTypeChecker();
223223
}
224224

225-
private static FunctionImp createCompositeFunctionImp(
225+
private static FunctionImp wrapWithCompositeTypeChecker(
226226
SqlOperator operator,
227227
CompositeOperandTypeChecker typeChecker,
228228
boolean checkCompositionType) {
@@ -249,7 +249,7 @@ public PPLTypeChecker getTypeChecker() {
249249
};
250250
}
251251

252-
private static FunctionImp createImplicitCastFunctionImp(
252+
private static FunctionImp wrapWithImplicitCastTypeChecker(
253253
SqlOperator operator, ImplicitCastOperandTypeChecker typeChecker) {
254254
return new FunctionImp() {
255255
@Override
@@ -264,7 +264,7 @@ public PPLTypeChecker getTypeChecker() {
264264
};
265265
}
266266

267-
private static FunctionImp createComparableFunctionImp(
267+
private static FunctionImp wrapWithComparableTypeChecker(
268268
SqlOperator operator, SameOperandTypeChecker typeChecker) {
269269
return new FunctionImp() {
270270
@Override
@@ -356,7 +356,6 @@ void populate() {
356356
registerOperator(IS_NOT_NULL, SqlStdOperatorTable.IS_NOT_NULL);
357357
registerOperator(IS_PRESENT, SqlStdOperatorTable.IS_NOT_NULL);
358358
registerOperator(IS_NULL, SqlStdOperatorTable.IS_NULL);
359-
registerOperator(IF, SqlStdOperatorTable.CASE);
360359
registerOperator(IFNULL, SqlStdOperatorTable.COALESCE);
361360
registerOperator(COALESCE, SqlStdOperatorTable.COALESCE);
362361

@@ -501,14 +500,14 @@ void populate() {
501500
// checker for it.
502501
register(
503502
SUBSTRING,
504-
createCompositeFunctionImp(
503+
wrapWithCompositeTypeChecker(
505504
SqlStdOperatorTable.SUBSTRING,
506505
(CompositeOperandTypeChecker)
507506
OperandTypes.STRING_INTEGER.or(OperandTypes.STRING_INTEGER_INTEGER),
508507
false));
509508
register(
510509
SUBSTR,
511-
createCompositeFunctionImp(
510+
wrapWithCompositeTypeChecker(
512511
SqlStdOperatorTable.SUBSTRING,
513512
(CompositeOperandTypeChecker)
514513
OperandTypes.STRING_INTEGER.or(OperandTypes.STRING_INTEGER_INTEGER),
@@ -544,6 +543,14 @@ void populate() {
544543
(builder, arg) ->
545544
builder.makeLiteral(getLegacyTypeName(arg.getType(), QueryType.PPL)));
546545
register(XOR, new XOR_FUNC());
546+
// SqlStdOperatorTable.CASE.getOperandTypeChecker is null. We manually create a type checker
547+
// for it. The second and third operands are required to be of the same type. If not,
548+
// it will throw an IllegalArgumentException with information Can't find leastRestrictive type
549+
register(
550+
IF,
551+
wrapWithImplicitCastTypeChecker(
552+
SqlStdOperatorTable.CASE,
553+
OperandTypes.family(SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.ANY)));
547554
register(
548555
NULLIF,
549556
(FunctionImp2)

ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLFunctionTypeTest.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,24 @@ public void testSubstringWithWrongType() {
7171
"SUBSTRING function expects {[STRING,INTEGER], [STRING,INTEGER,INTEGER]}, but got"
7272
+ " [STRING,INTEGER,STRING]");
7373
}
74+
75+
@Test
76+
public void testIfWithWrongType() {
77+
getRelNode("source=EMP | eval if_name = if(EMPNO > 6, 'Jack', ENAME) | fields if_name");
78+
getRelNode("source=EMP | eval if_name = if(EMPNO > 6, EMPNO, DEPTNO) | fields if_name");
79+
String pplWrongCondition = "source=EMP | eval if_name = if(EMPNO, 1, DEPTNO) | fields if_name";
80+
Throwable t1 =
81+
Assert.assertThrows(
82+
ExpressionEvaluationException.class, () -> getRelNode(pplWrongCondition));
83+
verifyErrorMessageContains(
84+
t1, "IF function expects {[BOOLEAN,ANY,ANY]}, but got [SHORT,INTEGER,BYTE]");
85+
String pplIncompatibleType =
86+
"source=EMP | eval if_name = if(EMPNO > 6, 'Jack', 1) | fields if_name";
87+
Throwable t2 =
88+
Assert.assertThrows(IllegalArgumentException.class, () -> getRelNode(pplIncompatibleType));
89+
verifyErrorMessageContains(
90+
t2,
91+
"Cannot resolve function: IF, arguments: [BOOLEAN, VARCHAR, INTEGER], caused by: Can't find"
92+
+ " leastRestrictive type for [VARCHAR, INTEGER]");
93+
}
7494
}

0 commit comments

Comments
 (0)