2020import java .util .Map ;
2121import java .util .stream .Collectors ;
2222import java .util .stream .IntStream ;
23+ import javax .annotation .Nullable ;
2324import lombok .RequiredArgsConstructor ;
2425import org .apache .calcite .rel .RelNode ;
2526import org .apache .calcite .rel .type .RelDataType ;
@@ -381,15 +382,17 @@ public CalcitePlanContext prepareLambdaContext(
381382 CalcitePlanContext context ,
382383 LambdaFunction node ,
383384 List <RexNode > previousArgument ,
384- String functionName ) {
385+ String functionName ,
386+ @ Nullable RelDataType defaultTypeForReduceAcc ) {
385387 try {
386388 CalcitePlanContext lambdaContext = context .clone ();
387389 List <RelDataType > candidateType = new ArrayList <>();
388390 candidateType .add (
389391 ((ArraySqlType ) previousArgument .get (0 ).getType ())
390392 .getComponentType ()); // The first argument should be array type
391393 candidateType .addAll (previousArgument .stream ().skip (1 ).map (RexNode ::getType ).toList ());
392- candidateType = modifyLambdaTypeByFunction (functionName , candidateType );
394+ candidateType =
395+ modifyLambdaTypeByFunction (functionName , candidateType , defaultTypeForReduceAcc );
393396 List <QualifiedName > argNames = node .getFuncArgs ();
394397 Map <String , RexLambdaRef > lambdaTypes = new HashMap <>();
395398 int candidateIndex ;
@@ -421,12 +424,19 @@ public CalcitePlanContext prepareLambdaContext(
421424 * reduce has special logic.
422425 */
423426 private List <RelDataType > modifyLambdaTypeByFunction (
424- String functionName , List <RelDataType > originalType ) {
427+ String functionName ,
428+ List <RelDataType > originalType ,
429+ @ Nullable RelDataType defaultTypeForReduceAcc ) {
425430 switch (functionName .toUpperCase (Locale .ROOT )) {
426431 case "REDUCE" : // For reduce case, the first type is acc should be any since it is the output
427432 // of accumulator lambda function
428433 if (originalType .size () == 2 ) {
429- return List .of (originalType .get (1 ), originalType .get (0 ));
434+ if (defaultTypeForReduceAcc == null
435+ || defaultTypeForReduceAcc .equals (originalType .get (1 ))) {
436+ return List .of (originalType .get (1 ), originalType .get (0 ));
437+ }
438+ return List .of (TYPE_FACTORY .createSqlType (SqlTypeName .ANY , true ), originalType .get (0 ));
439+
430440 } else {
431441 return List .of (originalType .get (2 ));
432442 }
@@ -442,8 +452,20 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
442452 for (UnresolvedExpression arg : args ) {
443453 if (arg instanceof LambdaFunction ) {
444454 CalcitePlanContext lambdaContext =
445- prepareLambdaContext (context , (LambdaFunction ) arg , arguments , node .getFuncName ());
446- arguments .add (analyze (arg , lambdaContext ));
455+ prepareLambdaContext (
456+ context , (LambdaFunction ) arg , arguments , node .getFuncName (), null );
457+ RexNode lambdaNode = analyze (arg , lambdaContext );
458+ if (node .getFuncName ().equalsIgnoreCase ("reduce" )) { // analyze again with calculate type
459+ lambdaContext =
460+ prepareLambdaContext (
461+ context ,
462+ (LambdaFunction ) arg ,
463+ arguments ,
464+ node .getFuncName (),
465+ lambdaNode .getType ());
466+ lambdaNode = analyze (arg , lambdaContext );
467+ }
468+ arguments .add (lambdaNode );
447469 } else {
448470 arguments .add (analyze (arg , context ));
449471 }
0 commit comments