Skip to content

Commit 97ad5d7

Browse files
committed
Define operand type inference for transform function
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent cb78ea7 commit 97ad5d7

3 files changed

Lines changed: 56 additions & 11 deletions

File tree

core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/TransformFunctionImpl.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import java.util.ArrayList;
1313
import java.util.List;
14+
import java.util.Map;
1415
import java.util.stream.IntStream;
1516
import org.apache.calcite.adapter.enumerable.NotNullImplementor;
1617
import org.apache.calcite.adapter.enumerable.NullPolicy;
@@ -22,18 +23,22 @@
2223
import org.apache.calcite.rel.type.RelDataTypeFactory;
2324
import org.apache.calcite.rex.RexCall;
2425
import org.apache.calcite.sql.SqlCallBinding;
26+
import org.apache.calcite.sql.SqlLambda;
2527
import org.apache.calcite.sql.SqlNode;
2628
import org.apache.calcite.sql.SqlOperandCountRange;
2729
import org.apache.calcite.sql.SqlOperator;
2830
import org.apache.calcite.sql.SqlUtil;
2931
import org.apache.calcite.sql.type.ArraySqlType;
3032
import org.apache.calcite.sql.type.FamilyOperandTypeChecker;
3133
import org.apache.calcite.sql.type.SqlOperandCountRanges;
34+
import org.apache.calcite.sql.type.SqlOperandTypeInference;
3235
import org.apache.calcite.sql.type.SqlReturnTypeInference;
3336
import org.apache.calcite.sql.type.SqlSingleOperandTypeChecker;
3437
import org.apache.calcite.sql.type.SqlTypeFamily;
3538
import org.apache.calcite.sql.type.SqlTypeName;
3639
import org.apache.calcite.sql.type.SqlTypeUtil;
40+
import org.apache.calcite.sql.validate.SqlLambdaScope;
41+
import org.apache.calcite.sql.validate.SqlValidator;
3742
import org.opensearch.sql.expression.function.ImplementorUDF;
3843
import org.opensearch.sql.expression.function.UDFOperandMetadata;
3944

@@ -58,6 +63,35 @@ public SqlReturnTypeInference getReturnTypeInference() {
5863
};
5964
}
6065

66+
@Override
67+
public SqlOperandTypeInference getOperandTypeInference() {
68+
// Pass the element type of TRANSFORM's first argument as the type of the first argument of the
69+
// lambda function.
70+
return (callBinding, returnType, operandTypes) -> {
71+
RelDataType arrayType = callBinding.getOperandType(0);
72+
operandTypes[0] = arrayType;
73+
if (callBinding.operand(1) instanceof SqlLambda lambdaNode) {
74+
SqlValidator validator = callBinding.getValidator();
75+
if (validator.getLambdaScope(lambdaNode) instanceof SqlLambdaScope lambdaScope) {
76+
RelDataType elementType = arrayType.getComponentType();
77+
Map<String, RelDataType> paramTypes = lambdaScope.getParameterTypes();
78+
List<SqlNode> params = lambdaNode.getParameters();
79+
// First parameter: array element type. Leave it as is (typically ANY) if element type is
80+
// null
81+
if (!params.isEmpty() && elementType != null) {
82+
paramTypes.put(params.get(0).toString(), elementType);
83+
}
84+
// Second parameter (if exists): INTEGER (for index)
85+
if (params.size() > 1) {
86+
RelDataType intType = callBinding.getTypeFactory().createSqlType(SqlTypeName.INTEGER);
87+
paramTypes.put(params.get(1).toString(), intType);
88+
}
89+
operandTypes[1] = SqlTypeUtil.deriveType(callBinding, lambdaNode);
90+
}
91+
}
92+
};
93+
}
94+
6195
@Override
6296
public UDFOperandMetadata getOperandMetadata() {
6397
// Only checks the first two arguments as it allows arbitrary number of arguments to follow them

core/src/main/java/org/opensearch/sql/expression/function/UserDefinedFunctionBuilder.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.apache.calcite.sql.SqlOperandCountRange;
1313
import org.apache.calcite.sql.parser.SqlParserPos;
1414
import org.apache.calcite.sql.type.InferTypes;
15+
import org.apache.calcite.sql.type.SqlOperandTypeInference;
1516
import org.apache.calcite.sql.type.SqlReturnTypeInference;
1617
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
1718

@@ -37,6 +38,16 @@ default SqlKind getKind() {
3738
return SqlKind.OTHER_FUNCTION;
3839
}
3940

41+
/**
42+
* Define the strategy to infer unknown types of the operands of an operator call.
43+
*
44+
* @return SqlOperandTypeInference the specified operand type inference. Default to {@link
45+
* InferTypes#ANY_NULLABLE}
46+
*/
47+
default SqlOperandTypeInference getOperandTypeInference() {
48+
return InferTypes.ANY_NULLABLE;
49+
}
50+
4051
default SqlUserDefinedFunction toUDF(String functionName) {
4152
return toUDF(functionName, true);
4253
}
@@ -57,7 +68,7 @@ default SqlUserDefinedFunction toUDF(String functionName, boolean isDeterministi
5768
udfLtrimIdentifier,
5869
getKind(),
5970
getReturnTypeInference(),
60-
InferTypes.ANY_NULLABLE,
71+
getOperandTypeInference(),
6172
getOperandMetadata(),
6273
getFunction()) {
6374
@Override

docs/user/ppl/functions/collection.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ Usage: `transform(array, function)` transform the element of array one by one us
162162
Argument type: array:ARRAY, function:LAMBDA
163163
Return type: ARRAY
164164
Example
165-
<!-- TODO: To be fixed with https://github.com/opensearch-project/sql/issues/4972 -->
166-
```ppl ignore
165+
166+
```ppl
167167
source=people
168168
| eval array = array(1, -2, 3), result = transform(array, x -> x + 2)
169169
| fields result
@@ -180,8 +180,8 @@ fetched rows / total rows = 1/1
180180
| [3,0,5] |
181181
+---------+
182182
```
183-
<!-- TODO: To be fixed with https://github.com/opensearch-project/sql/issues/4972 -->
184-
```ppl ignore
183+
184+
```ppl
185185
source=people
186186
| eval array = array(1, -2, 3), result = transform(array, (x, i) -> x + i)
187187
| fields result
@@ -814,8 +814,8 @@ Usage: mvmap(array, expression) iterates over each element of a multivalue array
814814
Argument type: array: ARRAY, expression: EXPRESSION
815815
Return type: ARRAY
816816
Example
817-
<!-- TODO: To be fixed with https://github.com/opensearch-project/sql/issues/4972 -->
818-
```ppl ignore
817+
818+
```ppl
819819
source=people
820820
| eval array = array(1, 2, 3), result = mvmap(array, array * 10)
821821
| fields result
@@ -832,8 +832,8 @@ fetched rows / total rows = 1/1
832832
| [10,20,30] |
833833
+------------+
834834
```
835-
<!-- TODO: To be fixed with https://github.com/opensearch-project/sql/issues/4972 -->
836-
```ppl ignore
835+
836+
```ppl
837837
source=people
838838
| eval array = array(1, 2, 3), result = mvmap(array, array + 5)
839839
| fields result
@@ -854,8 +854,8 @@ fetched rows / total rows = 1/1
854854
Note: For nested expressions like ``mvmap(mvindex(arr, 1, 3), arr * 2)``, the field name (``arr``) is extracted from the first argument and must match the field referenced in the expression.
855855

856856
The expression can also reference other single-value fields:
857-
<!-- TODO: To be fixed with https://github.com/opensearch-project/sql/issues/4972 -->
858-
```ppl ignore
857+
858+
```ppl
859859
source=people
860860
| eval array = array(1, 2, 3), multiplier = 10, result = mvmap(array, array * multiplier)
861861
| fields result

0 commit comments

Comments
 (0)