66package org .opensearch .sql .expression .function .CollectionUDF ;
77
88import static org .apache .calcite .sql .type .SqlTypeUtil .createArrayType ;
9+ import static org .apache .calcite .util .Static .RESOURCE ;
910import static org .opensearch .sql .expression .function .CollectionUDF .LambdaUtils .transferLambdaOutputToTargetType ;
1011
1112import java .util .ArrayList ;
1213import java .util .List ;
14+ import java .util .stream .IntStream ;
1315import org .apache .calcite .adapter .enumerable .NotNullImplementor ;
1416import org .apache .calcite .adapter .enumerable .NullPolicy ;
1517import org .apache .calcite .adapter .enumerable .RexToLixTranslator ;
1921import org .apache .calcite .rel .type .RelDataType ;
2022import org .apache .calcite .rel .type .RelDataTypeFactory ;
2123import org .apache .calcite .rex .RexCall ;
24+ import org .apache .calcite .sql .SqlCallBinding ;
25+ import org .apache .calcite .sql .SqlNode ;
26+ import org .apache .calcite .sql .SqlOperandCountRange ;
27+ import org .apache .calcite .sql .SqlOperator ;
28+ import org .apache .calcite .sql .SqlUtil ;
2229import org .apache .calcite .sql .type .ArraySqlType ;
23- import org .apache .calcite .sql .type .OperandTypes ;
30+ import org .apache .calcite .sql .type .FamilyOperandTypeChecker ;
31+ import org .apache .calcite .sql .type .SqlOperandCountRanges ;
2432import org .apache .calcite .sql .type .SqlReturnTypeInference ;
33+ import org .apache .calcite .sql .type .SqlSingleOperandTypeChecker ;
2534import org .apache .calcite .sql .type .SqlTypeFamily ;
2635import org .apache .calcite .sql .type .SqlTypeName ;
36+ import org .apache .calcite .sql .type .SqlTypeUtil ;
2737import org .opensearch .sql .expression .function .ImplementorUDF ;
2838import org .opensearch .sql .expression .function .UDFOperandMetadata ;
2939
@@ -50,8 +60,94 @@ public SqlReturnTypeInference getReturnTypeInference() {
5060
5161 @ Override
5262 public UDFOperandMetadata getOperandMetadata () {
63+ // Only checks the first two arguments as it allows arbitrary number of arguments to follow them
5364 return UDFOperandMetadata .wrap (
54- OperandTypes .family (SqlTypeFamily .ARRAY , SqlTypeFamily .FUNCTION ));
65+ new SqlSingleOperandTypeChecker () {
66+ private static final List <SqlTypeFamily > families =
67+ List .of (SqlTypeFamily .ARRAY , SqlTypeFamily .FUNCTION );
68+
69+ /**
70+ * Copied from {@link FamilyOperandTypeChecker#checkSingleOperandType(SqlCallBinding
71+ * callBinding, SqlNode node, int iFormalOperand, boolean throwOnFailure)}
72+ */
73+ @ Override
74+ public boolean checkSingleOperandType (
75+ SqlCallBinding callBinding ,
76+ SqlNode operand ,
77+ int iFormalOperand ,
78+ boolean throwOnFailure ) {
79+ // Do not check types after the second operands
80+ if (iFormalOperand > 1 ) {
81+ return true ;
82+ }
83+ SqlTypeFamily family = families .get (iFormalOperand );
84+ switch (family ) {
85+ case ANY :
86+ final RelDataType type = SqlTypeUtil .deriveType (callBinding , operand );
87+ SqlTypeName typeName = type .getSqlTypeName ();
88+
89+ if (typeName == SqlTypeName .CURSOR ) {
90+ // We do not allow CURSOR operands, even for ANY
91+ if (throwOnFailure ) {
92+ throw callBinding .newValidationSignatureError ();
93+ }
94+ return false ;
95+ }
96+ // fall through
97+ case IGNORE :
98+ // no need to check
99+ return true ;
100+ default :
101+ break ;
102+ }
103+ if (SqlUtil .isNullLiteral (operand , false )) {
104+ if (callBinding .isTypeCoercionEnabled ()) {
105+ return true ;
106+ } else if (throwOnFailure ) {
107+ throw callBinding
108+ .getValidator ()
109+ .newValidationError (operand , RESOURCE .nullIllegal ());
110+ } else {
111+ return false ;
112+ }
113+ }
114+ RelDataType type = SqlTypeUtil .deriveType (callBinding , operand );
115+ SqlTypeName typeName = type .getSqlTypeName ();
116+
117+ // Pass type checking for operators if it's of type 'ANY'.
118+ if (typeName .getFamily () == SqlTypeFamily .ANY ) {
119+ return true ;
120+ }
121+
122+ if (!family .getTypeNames ().contains (typeName )) {
123+ if (throwOnFailure ) {
124+ throw callBinding .newValidationSignatureError ();
125+ }
126+ return false ;
127+ }
128+ return true ;
129+ }
130+
131+ @ Override
132+ public boolean checkOperandTypes (SqlCallBinding callBinding , boolean throwOnFailure ) {
133+ if (!getOperandCountRange ().isValidCount (callBinding .getOperandCount ())) {
134+ return false ;
135+ }
136+ return IntStream .range (0 , 2 )
137+ .allMatch (
138+ i -> checkSingleOperandType (callBinding , callBinding .operand (i ), i , false ));
139+ }
140+
141+ @ Override
142+ public SqlOperandCountRange getOperandCountRange () {
143+ return SqlOperandCountRanges .from (2 );
144+ }
145+
146+ @ Override
147+ public String getAllowedSignatures (SqlOperator op , String opName ) {
148+ return "<ARRAY, FUNCTION, ...ANY OTHER ARGUMENTS>" ;
149+ }
150+ });
55151 }
56152
57153 public static class TransformImplementor implements NotNullImplementor {
0 commit comments