Skip to content

Commit 70d6a7a

Browse files
committed
fix reduce and add doc
Signed-off-by: xinyual <xinyual@amazon.com>
1 parent 17a19d3 commit 70d6a7a

3 files changed

Lines changed: 32 additions & 25 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.Map;
2121
import java.util.stream.Collectors;
2222
import java.util.stream.IntStream;
23+
import javax.annotation.Nullable;
2324
import lombok.RequiredArgsConstructor;
2425
import org.apache.calcite.rel.RelNode;
2526
import org.apache.calcite.rel.type.RelDataType;
@@ -381,15 +382,17 @@ public CalcitePlanContext prepareLambdaContext(
381382
CalcitePlanContext context,
382383
LambdaFunction node,
383384
List<RexNode> previousArgument,
384-
String functionName) {
385+
String functionName,
386+
@Nullable RelDataType defaultTypeForReduceAcc) {
385387
try {
386388
CalcitePlanContext lambdaContext = context.clone();
387389
List<RelDataType> candidateType = new ArrayList<>();
388390
candidateType.add(
389391
((ArraySqlType) previousArgument.get(0).getType())
390392
.getComponentType()); // The first argument should be array type
391393
candidateType.addAll(previousArgument.stream().skip(1).map(RexNode::getType).toList());
392-
candidateType = modifyLambdaTypeByFunction(functionName, candidateType);
394+
candidateType =
395+
modifyLambdaTypeByFunction(functionName, candidateType, defaultTypeForReduceAcc);
393396
List<QualifiedName> argNames = node.getFuncArgs();
394397
Map<String, RexLambdaRef> lambdaTypes = new HashMap<>();
395398
int candidateIndex;
@@ -421,12 +424,19 @@ public CalcitePlanContext prepareLambdaContext(
421424
* reduce has special logic.
422425
*/
423426
private List<RelDataType> modifyLambdaTypeByFunction(
424-
String functionName, List<RelDataType> originalType) {
427+
String functionName,
428+
List<RelDataType> originalType,
429+
@Nullable RelDataType defaultTypeForReduceAcc) {
425430
switch (functionName.toUpperCase(Locale.ROOT)) {
426431
case "REDUCE": // For reduce case, the first type is acc should be any since it is the output
427432
// of accumulator lambda function
428433
if (originalType.size() == 2) {
429-
return List.of(originalType.get(1), originalType.get(0));
434+
if (defaultTypeForReduceAcc == null
435+
|| defaultTypeForReduceAcc.equals(originalType.get(1))) {
436+
return List.of(originalType.get(1), originalType.get(0));
437+
}
438+
return List.of(TYPE_FACTORY.createSqlType(SqlTypeName.ANY, true), originalType.get(0));
439+
430440
} else {
431441
return List.of(originalType.get(2));
432442
}
@@ -442,8 +452,20 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
442452
for (UnresolvedExpression arg : args) {
443453
if (arg instanceof LambdaFunction) {
444454
CalcitePlanContext lambdaContext =
445-
prepareLambdaContext(context, (LambdaFunction) arg, arguments, node.getFuncName());
446-
arguments.add(analyze(arg, lambdaContext));
455+
prepareLambdaContext(
456+
context, (LambdaFunction) arg, arguments, node.getFuncName(), null);
457+
RexNode lambdaNode = analyze(arg, lambdaContext);
458+
if (node.getFuncName().equalsIgnoreCase("reduce")) { // analyze again with calculate type
459+
lambdaContext =
460+
prepareLambdaContext(
461+
context,
462+
(LambdaFunction) arg,
463+
arguments,
464+
node.getFuncName(),
465+
lambdaNode.getType());
466+
lambdaNode = analyze(arg, lambdaContext);
467+
}
468+
arguments.add(lambdaNode);
447469
} else {
448470
arguments.add(analyze(arg, context));
449471
}

core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/ReduceFunctionImpl.java

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,6 @@ public SqlReturnTypeInference getReturnTypeInference() {
6060
finalReturnType = mergedReturnType;
6161
}
6262
return finalReturnType;
63-
64-
/*
65-
RelDataTypeFactory typeFactory = sqlOperatorBinding.getTypeFactory();
66-
RexCallBinding rexCallBinding = (RexCallBinding) sqlOperatorBinding;
67-
List<RexNode> operands = rexCallBinding.operands();
68-
RelDataType mergedReturnType =
69-
((RexLambda) operands.get(2)).getExpression().getType();
70-
if (operands.size() > 3) {
71-
RelDataType reduceReturnType =
72-
((RexLambda) operands.get(3)).getExpression().getType();
73-
return typeFactory.leastRestrictive(List.of(mergedReturnType, reduceReturnType));
74-
}
75-
return mergedReturnType;
76-
77-
*/
7863
};
7964
}
8065

core/src/test/java/org/opensearch/sql/calcite/CalciteRexNodeVisitorTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ public void testPrepareLambdaForBasicLambda() {
8585
when(lambdaFunction.getFuncArgs()).thenReturn(List.of(functionArg1));
8686

8787
CalcitePlanContext lambdaContext =
88-
visitor.prepareLambdaContext(context, lambdaFunction, previousArguments, "forall");
88+
visitor.prepareLambdaContext(context, lambdaFunction, previousArguments, "forall", null);
8989

9090
assertNotNull(lambdaContext);
9191
assertNotNull(lambdaContext.getRexLambdaRefMap());
@@ -108,7 +108,7 @@ public void testPrepareLambdaForTransform() {
108108
when(lambdaFunction.getFuncArgs()).thenReturn(List.of(functionArg1, functionArg2));
109109

110110
CalcitePlanContext lambdaContext =
111-
visitor.prepareLambdaContext(context, lambdaFunction, previousArguments, "transform");
111+
visitor.prepareLambdaContext(context, lambdaFunction, previousArguments, "transform", null);
112112

113113
assertNotNull(lambdaContext);
114114
assertNotNull(lambdaContext.getRexLambdaRefMap());
@@ -137,7 +137,7 @@ public void testPrepareLambdaForReduce() {
137137
when(lambdaFunction.getFuncArgs()).thenReturn(List.of(functionArg1, functionArg2));
138138

139139
CalcitePlanContext lambdaContext =
140-
visitor.prepareLambdaContext(context, lambdaFunction, previousArguments, "reduce");
140+
visitor.prepareLambdaContext(context, lambdaFunction, previousArguments, "reduce", null);
141141

142142
assertNotNull(lambdaContext);
143143
assertNotNull(lambdaContext.getRexLambdaRefMap());
@@ -165,7 +165,7 @@ public void testPrepareLambdaForReduceFinalizerFunction() {
165165
when(lambdaFunction.getFuncArgs()).thenReturn(List.of(functionArg1));
166166

167167
CalcitePlanContext lambdaContext =
168-
visitor.prepareLambdaContext(context, lambdaFunction, previousArguments, "reduce");
168+
visitor.prepareLambdaContext(context, lambdaFunction, previousArguments, "reduce", null);
169169

170170
assertNotNull(lambdaContext);
171171
assertNotNull(lambdaContext.getRexLambdaRefMap());

0 commit comments

Comments
 (0)