Skip to content

Commit d19ba53

Browse files
committed
Update the type checker of transform function to allow arbitrary additional arguments (1931/2069)
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 485c038 commit d19ba53

1 file changed

Lines changed: 98 additions & 2 deletions

File tree

core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/TransformFunctionImpl.java

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
package org.opensearch.sql.expression.function.CollectionUDF;
77

88
import static org.apache.calcite.sql.type.SqlTypeUtil.createArrayType;
9+
import static org.apache.calcite.util.Static.RESOURCE;
910
import static org.opensearch.sql.expression.function.CollectionUDF.LambdaUtils.transferLambdaOutputToTargetType;
1011

1112
import java.util.ArrayList;
1213
import java.util.List;
14+
import java.util.stream.IntStream;
1315
import org.apache.calcite.adapter.enumerable.NotNullImplementor;
1416
import org.apache.calcite.adapter.enumerable.NullPolicy;
1517
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
@@ -19,11 +21,19 @@
1921
import org.apache.calcite.rel.type.RelDataType;
2022
import org.apache.calcite.rel.type.RelDataTypeFactory;
2123
import org.apache.calcite.rex.RexCall;
24+
import org.apache.calcite.sql.SqlCallBinding;
25+
import org.apache.calcite.sql.SqlNode;
26+
import org.apache.calcite.sql.SqlOperandCountRange;
27+
import org.apache.calcite.sql.SqlOperator;
28+
import org.apache.calcite.sql.SqlUtil;
2229
import org.apache.calcite.sql.type.ArraySqlType;
23-
import org.apache.calcite.sql.type.OperandTypes;
30+
import org.apache.calcite.sql.type.FamilyOperandTypeChecker;
31+
import org.apache.calcite.sql.type.SqlOperandCountRanges;
2432
import org.apache.calcite.sql.type.SqlReturnTypeInference;
33+
import org.apache.calcite.sql.type.SqlSingleOperandTypeChecker;
2534
import org.apache.calcite.sql.type.SqlTypeFamily;
2635
import org.apache.calcite.sql.type.SqlTypeName;
36+
import org.apache.calcite.sql.type.SqlTypeUtil;
2737
import org.opensearch.sql.expression.function.ImplementorUDF;
2838
import org.opensearch.sql.expression.function.UDFOperandMetadata;
2939

@@ -50,8 +60,94 @@ public SqlReturnTypeInference getReturnTypeInference() {
5060

5161
@Override
5262
public UDFOperandMetadata getOperandMetadata() {
63+
// Only checks the first two arguments as it allows arbitrary number of arguments to follow them
5364
return UDFOperandMetadata.wrap(
54-
OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.FUNCTION));
65+
new SqlSingleOperandTypeChecker() {
66+
private static final List<SqlTypeFamily> families =
67+
List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.FUNCTION);
68+
69+
/**
70+
* Copied from {@link FamilyOperandTypeChecker#checkSingleOperandType(SqlCallBinding
71+
* callBinding, SqlNode node, int iFormalOperand, boolean throwOnFailure)}
72+
*/
73+
@Override
74+
public boolean checkSingleOperandType(
75+
SqlCallBinding callBinding,
76+
SqlNode operand,
77+
int iFormalOperand,
78+
boolean throwOnFailure) {
79+
// Do not check types after the second operands
80+
if (iFormalOperand > 1) {
81+
return true;
82+
}
83+
SqlTypeFamily family = families.get(iFormalOperand);
84+
switch (family) {
85+
case ANY:
86+
final RelDataType type = SqlTypeUtil.deriveType(callBinding, operand);
87+
SqlTypeName typeName = type.getSqlTypeName();
88+
89+
if (typeName == SqlTypeName.CURSOR) {
90+
// We do not allow CURSOR operands, even for ANY
91+
if (throwOnFailure) {
92+
throw callBinding.newValidationSignatureError();
93+
}
94+
return false;
95+
}
96+
// fall through
97+
case IGNORE:
98+
// no need to check
99+
return true;
100+
default:
101+
break;
102+
}
103+
if (SqlUtil.isNullLiteral(operand, false)) {
104+
if (callBinding.isTypeCoercionEnabled()) {
105+
return true;
106+
} else if (throwOnFailure) {
107+
throw callBinding
108+
.getValidator()
109+
.newValidationError(operand, RESOURCE.nullIllegal());
110+
} else {
111+
return false;
112+
}
113+
}
114+
RelDataType type = SqlTypeUtil.deriveType(callBinding, operand);
115+
SqlTypeName typeName = type.getSqlTypeName();
116+
117+
// Pass type checking for operators if it's of type 'ANY'.
118+
if (typeName.getFamily() == SqlTypeFamily.ANY) {
119+
return true;
120+
}
121+
122+
if (!family.getTypeNames().contains(typeName)) {
123+
if (throwOnFailure) {
124+
throw callBinding.newValidationSignatureError();
125+
}
126+
return false;
127+
}
128+
return true;
129+
}
130+
131+
@Override
132+
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
133+
if (!getOperandCountRange().isValidCount(callBinding.getOperandCount())) {
134+
return false;
135+
}
136+
return IntStream.range(0, 2)
137+
.allMatch(
138+
i -> checkSingleOperandType(callBinding, callBinding.operand(i), i, false));
139+
}
140+
141+
@Override
142+
public SqlOperandCountRange getOperandCountRange() {
143+
return SqlOperandCountRanges.from(2);
144+
}
145+
146+
@Override
147+
public String getAllowedSignatures(SqlOperator op, String opName) {
148+
return "<ARRAY, FUNCTION, ...ANY OTHER ARGUMENTS>";
149+
}
150+
});
55151
}
56152

57153
public static class TransformImplementor implements NotNullImplementor {

0 commit comments

Comments
 (0)