1111
1212import java .util .ArrayList ;
1313import java .util .List ;
14+ import java .util .Map ;
1415import java .util .stream .IntStream ;
1516import org .apache .calcite .adapter .enumerable .NotNullImplementor ;
1617import org .apache .calcite .adapter .enumerable .NullPolicy ;
2223import org .apache .calcite .rel .type .RelDataTypeFactory ;
2324import org .apache .calcite .rex .RexCall ;
2425import org .apache .calcite .sql .SqlCallBinding ;
26+ import org .apache .calcite .sql .SqlLambda ;
2527import org .apache .calcite .sql .SqlNode ;
2628import org .apache .calcite .sql .SqlOperandCountRange ;
2729import org .apache .calcite .sql .SqlOperator ;
2830import org .apache .calcite .sql .SqlUtil ;
2931import org .apache .calcite .sql .type .ArraySqlType ;
3032import org .apache .calcite .sql .type .FamilyOperandTypeChecker ;
3133import org .apache .calcite .sql .type .SqlOperandCountRanges ;
34+ import org .apache .calcite .sql .type .SqlOperandTypeInference ;
3235import org .apache .calcite .sql .type .SqlReturnTypeInference ;
3336import org .apache .calcite .sql .type .SqlSingleOperandTypeChecker ;
3437import org .apache .calcite .sql .type .SqlTypeFamily ;
3538import org .apache .calcite .sql .type .SqlTypeName ;
3639import org .apache .calcite .sql .type .SqlTypeUtil ;
40+ import org .apache .calcite .sql .validate .SqlLambdaScope ;
41+ import org .apache .calcite .sql .validate .SqlValidator ;
3742import org .opensearch .sql .expression .function .ImplementorUDF ;
3843import org .opensearch .sql .expression .function .UDFOperandMetadata ;
3944
@@ -58,6 +63,35 @@ public SqlReturnTypeInference getReturnTypeInference() {
5863 };
5964 }
6065
66+ @ Override
67+ public SqlOperandTypeInference getOperandTypeInference () {
68+ // Pass the element type of TRANSFORM's first argument as the type of the first argument of the
69+ // lambda function.
70+ return (callBinding , returnType , operandTypes ) -> {
71+ RelDataType arrayType = callBinding .getOperandType (0 );
72+ operandTypes [0 ] = arrayType ;
73+ if (callBinding .operand (1 ) instanceof SqlLambda lambdaNode ) {
74+ SqlValidator validator = callBinding .getValidator ();
75+ if (validator .getLambdaScope (lambdaNode ) instanceof SqlLambdaScope lambdaScope ) {
76+ RelDataType elementType = arrayType .getComponentType ();
77+ Map <String , RelDataType > paramTypes = lambdaScope .getParameterTypes ();
78+ List <SqlNode > params = lambdaNode .getParameters ();
79+ // First parameter: array element type. Leave it as is (typically ANY) if element type is
80+ // null
81+ if (!params .isEmpty () && elementType != null ) {
82+ paramTypes .put (params .get (0 ).toString (), elementType );
83+ }
84+ // Second parameter (if exists): INTEGER (for index)
85+ if (params .size () > 1 ) {
86+ RelDataType intType = callBinding .getTypeFactory ().createSqlType (SqlTypeName .INTEGER );
87+ paramTypes .put (params .get (1 ).toString (), intType );
88+ }
89+ operandTypes [1 ] = SqlTypeUtil .deriveType (callBinding , lambdaNode );
90+ }
91+ }
92+ };
93+ }
94+
6195 @ Override
6296 public UDFOperandMetadata getOperandMetadata () {
6397 // Only checks the first two arguments as it allows arbitrary number of arguments to follow them
0 commit comments