Skip to content

Commit e4f8e5e

Browse files
committed
optimize reduce logic
Signed-off-by: xinyual <xinyual@amazon.com>
1 parent e7acda2 commit e4f8e5e

6 files changed

Lines changed: 87 additions & 8 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.apache.calcite.rel.type.RelDataTypeFactory;
2828
import org.apache.calcite.rex.RexBuilder;
2929
import org.apache.calcite.rex.RexCall;
30+
import org.apache.calcite.rex.RexLambda;
3031
import org.apache.calcite.rex.RexLambdaRef;
3132
import org.apache.calcite.rex.RexNode;
3233
import org.apache.calcite.sql.SqlIntervalQualifier;
@@ -430,12 +431,10 @@ private List<RelDataType> modifyLambdaTypeByFunction(
430431
case "REDUCE": // For reduce case, the first type is acc should be any since it is the output
431432
// of accumulator lambda function
432433
if (originalType.size() == 2) {
433-
if (defaultTypeForReduceAcc == null
434-
|| defaultTypeForReduceAcc.equals(originalType.get(1))) {
435-
return List.of(originalType.get(1), originalType.get(0));
434+
if (defaultTypeForReduceAcc != null) {
435+
return List.of(defaultTypeForReduceAcc, originalType.get(0));
436436
}
437-
return List.of(TYPE_FACTORY.createSqlType(SqlTypeName.ANY, true), originalType.get(0));
438-
437+
return List.of(originalType.get(1), originalType.get(0));
439438
} else {
440439
return List.of(originalType.get(2));
441440
}
@@ -444,6 +443,19 @@ private List<RelDataType> modifyLambdaTypeByFunction(
444443
}
445444
}
446445

446+
private List<RexNode> castArgument(
447+
List<RexNode> originalArguments, String functionName, ExtendedRexBuilder rexBuilder) {
448+
switch (functionName.toUpperCase(Locale.ROOT)) {
449+
case "REDUCE":
450+
RexLambda call = (RexLambda) originalArguments.get(2);
451+
originalArguments.set(
452+
1, rexBuilder.makeCast(call.getType(), originalArguments.get(1), true, true));
453+
return originalArguments;
454+
default:
455+
return originalArguments;
456+
}
457+
}
458+
447459
@Override
448460
public RexNode visitFunction(Function node, CalcitePlanContext context) {
449461
List<UnresolvedExpression> args = node.getFuncArgs();
@@ -469,6 +481,9 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
469481
arguments.add(analyze(arg, context));
470482
}
471483
}
484+
485+
arguments = castArgument(arguments, node.getFuncName(), context.rexBuilder);
486+
472487
RexNode resolvedNode =
473488
PPLFuncImpTable.INSTANCE.resolve(
474489
context.rexBuilder, node.getFuncName(), arguments.toArray(new RexNode[0]));

core/src/main/java/org/opensearch/sql/calcite/utils/OpenSearchTypeFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ public static ExprType convertSqlTypeNameToExprType(SqlTypeName sqlTypeName) {
245245
case ARRAY -> ARRAY;
246246
case MAP -> STRUCT;
247247
case GEOMETRY -> GEO_POINT;
248-
case NULL, ANY -> UNDEFINED;
248+
case NULL, ANY, OTHER -> UNDEFINED;
249249
default -> UNKNOWN;
250250
};
251251
}

core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/ArrayFunctionImpl.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static org.apache.calcite.sql.type.SqlTypeUtil.createArrayType;
99

10+
import java.math.BigDecimal;
1011
import java.util.ArrayList;
1112
import java.util.Arrays;
1213
import java.util.List;
@@ -82,7 +83,20 @@ public static Object internalCast(Object... args) {
8283
SqlTypeName targetType = (SqlTypeName) args[args.length - 1];
8384
List<Object> result;
8485
switch (targetType) {
85-
case DOUBLE, DECIMAL:
86+
case DECIMAL:
87+
result =
88+
originalList.stream()
89+
.map(
90+
num -> {
91+
if (num instanceof BigDecimal) {
92+
return (BigDecimal) num;
93+
} else {
94+
return BigDecimal.valueOf(((Number) num).doubleValue());
95+
}
96+
})
97+
.collect(Collectors.toList());
98+
break;
99+
case DOUBLE:
86100
result =
87101
originalList.stream()
88102
.map(i -> (Object) ((Number) i).doubleValue())

core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,9 @@ private static List<ExprType> getExprTypes(SqlTypeFamily family) {
398398
// Integer is mapped to BIGINT in family.getDefaultConcreteType
399399
case INTEGER -> List.of(
400400
OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER));
401-
case ANY, IGNORE -> List.of(
401+
// case DATETIME_INTERVAL ->
402+
// SqlTypeName.INTERVAL_TYPES.stream().map(OpenSearchTypeFactory.TYPE_FACTORY::createSqlIntervalType).toList();
403+
case ANY, IGNORE, DATETIME_INTERVAL -> List.of(
402404
OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.ANY));
403405
default -> {
404406
RelDataType type = family.getDefaultConcreteType(OpenSearchTypeFactory.TYPE_FACTORY);

core/src/test/java/org/opensearch/sql/calcite/CalciteRexNodeVisitorTest.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,39 @@ public void testPrepareLambdaForReduce() {
152152
SqlTypeName.VARCHAR);
153153
}
154154

155+
@Test
156+
public void testPrepareLambdaForReduceWithDefaultType() {
157+
when(componentType.getSqlTypeName()).thenReturn(SqlTypeName.DOUBLE);
158+
when(arrayArg.getType()).thenReturn(arraySqlType);
159+
when(arraySqlType.getComponentType()).thenReturn(componentType);
160+
when(extraArg.getType()).thenReturn(extraType);
161+
162+
List<RexNode> previousArguments = List.of(arrayArg, extraArg);
163+
when(functionArg1.toString()).thenReturn("acc");
164+
when(functionArg2.toString()).thenReturn("arg1");
165+
when(lambdaFunction.getFuncArgs()).thenReturn(List.of(functionArg1, functionArg2));
166+
167+
CalcitePlanContext lambdaContext =
168+
visitor.prepareLambdaContext(
169+
context,
170+
lambdaFunction,
171+
previousArguments,
172+
"reduce",
173+
TYPE_FACTORY.createSqlType(SqlTypeName.BIGINT));
174+
175+
assertNotNull(lambdaContext);
176+
assertNotNull(lambdaContext.getRexLambdaRefMap());
177+
assertEquals(2, lambdaContext.getRexLambdaRefMap().size());
178+
assertTrue(lambdaContext.getRexLambdaRefMap().containsKey("arg1"));
179+
assertTrue(lambdaContext.getRexLambdaRefMap().containsKey("acc"));
180+
assertEquals(
181+
lambdaContext.getRexLambdaRefMap().get("arg1").getType().getSqlTypeName(),
182+
SqlTypeName.DOUBLE);
183+
assertEquals(
184+
lambdaContext.getRexLambdaRefMap().get("acc").getType().getSqlTypeName(),
185+
SqlTypeName.BIGINT);
186+
}
187+
155188
@Test
156189
public void testPrepareLambdaForReduceFinalizerFunction() {
157190
when(arrayArg.getType()).thenReturn(arraySqlType);

integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalciteArrayFunctionIT.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,21 @@ public void testReduce2() {
211211
verifyDataRows(actual, rows(1230));
212212
}
213213

214+
@Test
215+
public void testReduce3() {
216+
JSONObject actual =
217+
executeQuery(
218+
String.format(
219+
"source=%s | where age=28 | eval array = array(1.0, 2.0, 3.0), result3 ="
220+
+ " reduce(array, age, (acc, x) -> acc * 1.0 + x, acc -> acc * 10.0) | fields"
221+
+ " result3 | head 1",
222+
TEST_INDEX_BANK));
223+
224+
verifySchema(actual, schema("result3", "double"));
225+
226+
verifyDataRows(actual, rows(340));
227+
}
228+
214229
@Test
215230
public void testReduceWithUDF() {
216231
JSONObject actual =

0 commit comments

Comments
 (0)