Skip to content

Commit 7be9575

Browse files
committed
feat(api): Add pre-compilation adapter and aggregate fix for datetime UDF
Add preCompilationRules to bridge the type mismatch between normalized standard types (DATE/TIME/TIMESTAMP as int/long) and PPL UDF implementors (which expect and produce String values): 1. DatetimeUdfCompilationAdapterRule inserts CAST nodes around datetime UDFs so implementors receive String input and produce String output, with CASTs bridging int/long <-> String conversion. 2. DatetimeUdtNormalizeRule enhanced to handle LogicalAggregate (rebuild AggregateCall with re-inferred types) and LogicalProject (refresh RexInputRef types from new child row type) to prevent type mismatch assertions when datetime UDF results feed into aggregates. Both fixes are only needed for the UnifiedQueryCompiler (Enumerable) path. The Analytics Engine (Substrait/DataFusion) path is unaffected. Signed-off-by: Chen Dai <daichen@amazon.com>
1 parent 0ff1eec commit 7be9575

4 files changed

Lines changed: 280 additions & 63 deletions

File tree

api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeExtension.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ public List<RelShuttle> postAnalysisRules() {
2525
return List.of(DatetimeUdtNormalizeRule.INSTANCE, DatetimeOutputCastRule.INSTANCE);
2626
}
2727

28+
@Override
29+
public List<RelShuttle> preCompilationRules() {
30+
return List.of(DatetimeUdfCompilationAdapterRule.INSTANCE);
31+
}
32+
2833
/** Maps datetime UDT types to their standard Calcite equivalents. */
2934
@Getter
3035
@RequiredArgsConstructor
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.api.spec.datetime;
7+
8+
import static org.opensearch.sql.api.spec.datetime.DatetimeExtension.UdtMapping.isDatetimeType;
9+
10+
import java.util.ArrayList;
11+
import java.util.List;
12+
import lombok.AccessLevel;
13+
import lombok.NoArgsConstructor;
14+
import org.apache.calcite.rel.RelHomogeneousShuttle;
15+
import org.apache.calcite.rel.RelNode;
16+
import org.apache.calcite.rel.type.RelDataType;
17+
import org.apache.calcite.rel.type.RelDataTypeFactory;
18+
import org.apache.calcite.rex.RexBuilder;
19+
import org.apache.calcite.rex.RexCall;
20+
import org.apache.calcite.rex.RexNode;
21+
import org.apache.calcite.rex.RexShuttle;
22+
import org.apache.calcite.sql.type.SqlTypeName;
23+
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
24+
25+
/**
26+
* Adapts datetime UDF calls for Enumerable compilation. PPL UDF implementors expect String
27+
* input/output, but after normalization the plan uses standard DATE/TIME/TIMESTAMP types
28+
* (int/long). This rule inserts CASTs to bridge the mismatch:
29+
*
30+
* <pre>
31+
* Before: LAST_DAY($2:DATE) : DATE
32+
* After: CAST(LAST_DAY(CAST($2 AS VARCHAR)):VARCHAR AS DATE)
33+
* </pre>
34+
*/
35+
@NoArgsConstructor(access = AccessLevel.PACKAGE)
36+
class DatetimeUdfCompilationAdapterRule extends RelHomogeneousShuttle {
37+
38+
static final DatetimeUdfCompilationAdapterRule INSTANCE = new DatetimeUdfCompilationAdapterRule();
39+
40+
@Override
41+
public RelNode visit(RelNode other) {
42+
RelNode visited = super.visit(other);
43+
RexBuilder rexBuilder = visited.getCluster().getRexBuilder();
44+
RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
45+
return visited.accept(
46+
new RexShuttle() {
47+
@Override
48+
public RexNode visitCall(RexCall call) {
49+
call = (RexCall) super.visitCall(call);
50+
if (!(call.getOperator() instanceof SqlUserDefinedFunction)) {
51+
return call;
52+
}
53+
54+
// Adapt operands: CAST(datetime_operand AS VARCHAR) for UDF implementor
55+
List<RexNode> adapted = new ArrayList<>(call.getOperands().size());
56+
boolean operandsChanged = false;
57+
for (RexNode operand : call.getOperands()) {
58+
if (isDatetimeType(operand.getType().getSqlTypeName())) {
59+
RelDataType varcharType =
60+
typeFactory.createTypeWithNullability(
61+
typeFactory.createSqlType(SqlTypeName.VARCHAR),
62+
operand.getType().isNullable());
63+
adapted.add(rexBuilder.makeCast(varcharType, operand));
64+
operandsChanged = true;
65+
} else {
66+
adapted.add(operand);
67+
}
68+
}
69+
70+
// Adapt result: if return type is datetime, wrap call with VARCHAR return + CAST back
71+
if (isDatetimeType(call.getType().getSqlTypeName())) {
72+
RelDataType declaredType = call.getType();
73+
RelDataType varcharType =
74+
typeFactory.createTypeWithNullability(
75+
typeFactory.createSqlType(SqlTypeName.VARCHAR), declaredType.isNullable());
76+
RexCall varcharCall =
77+
call.clone(varcharType, operandsChanged ? adapted : call.getOperands());
78+
return rexBuilder.makeCast(declaredType, varcharCall);
79+
}
80+
81+
return operandsChanged ? call.clone(call.getType(), adapted) : call;
82+
}
83+
});
84+
}
85+
}

api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtNormalizeRule.java

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,21 @@
55

66
package org.opensearch.sql.api.spec.datetime;
77

8+
import java.util.ArrayList;
9+
import java.util.List;
810
import java.util.Optional;
911
import lombok.AccessLevel;
1012
import lombok.NoArgsConstructor;
1113
import org.apache.calcite.rel.RelHomogeneousShuttle;
1214
import org.apache.calcite.rel.RelNode;
15+
import org.apache.calcite.rel.core.AggregateCall;
16+
import org.apache.calcite.rel.logical.LogicalAggregate;
17+
import org.apache.calcite.rel.logical.LogicalProject;
1318
import org.apache.calcite.rel.type.RelDataType;
1419
import org.apache.calcite.rel.type.RelDataTypeFactory;
1520
import org.apache.calcite.rex.RexBuilder;
1621
import org.apache.calcite.rex.RexCall;
22+
import org.apache.calcite.rex.RexInputRef;
1723
import org.apache.calcite.rex.RexNode;
1824
import org.apache.calcite.rex.RexShuttle;
1925
import org.apache.calcite.sql.type.SqlTypeName;
@@ -30,10 +36,82 @@ class DatetimeUdtNormalizeRule extends RelHomogeneousShuttle {
3036

3137
@Override
3238
public RelNode visit(RelNode other) {
33-
RelNode visited = super.visit(other);
34-
RexBuilder rexBuilder = visited.getCluster().getRexBuilder();
39+
// Visit children first
40+
List<RelNode> newInputs = new ArrayList<>();
41+
boolean childChanged = false;
42+
for (RelNode input : other.getInputs()) {
43+
RelNode newInput = input.accept(this);
44+
newInputs.add(newInput);
45+
if (newInput != input) {
46+
childChanged = true;
47+
}
48+
}
49+
50+
// Rebuild current node if children changed
51+
RelNode current = other;
52+
if (childChanged) {
53+
if (current instanceof LogicalAggregate agg) {
54+
// Aggregate needs AggregateCall types rebuilt
55+
RelNode newInput = newInputs.get(0);
56+
List<AggregateCall> newAggCalls =
57+
agg.getAggCallList().stream()
58+
.map(
59+
call ->
60+
AggregateCall.create(
61+
call.getAggregation(),
62+
call.isDistinct(),
63+
call.isApproximate(),
64+
call.ignoreNulls(),
65+
call.rexList,
66+
call.getArgList(),
67+
call.filterArg,
68+
call.distinctKeys,
69+
call.collation,
70+
agg.getGroupCount(),
71+
newInput,
72+
null,
73+
call.getName()))
74+
.toList();
75+
current =
76+
agg.copy(
77+
agg.getTraitSet(), newInput, agg.getGroupSet(), agg.getGroupSets(), newAggCalls);
78+
} else if (current instanceof LogicalProject proj) {
79+
// Project needs RexInputRef types refreshed from new child
80+
RelNode newInput = newInputs.get(0);
81+
RexBuilder rexBuilder = proj.getCluster().getRexBuilder();
82+
List<RexNode> newProjects =
83+
proj.getProjects().stream()
84+
.map(
85+
expr ->
86+
expr.accept(
87+
new RexShuttle() {
88+
@Override
89+
public RexNode visitInputRef(RexInputRef ref) {
90+
RelDataType newType =
91+
newInput
92+
.getRowType()
93+
.getFieldList()
94+
.get(ref.getIndex())
95+
.getType();
96+
if (!newType.equals(ref.getType())) {
97+
return rexBuilder.makeInputRef(newType, ref.getIndex());
98+
}
99+
return ref;
100+
}
101+
}))
102+
.toList();
103+
current =
104+
LogicalProject.create(
105+
newInput, proj.getHints(), newProjects, proj.getRowType().getFieldNames());
106+
} else {
107+
current = current.copy(current.getTraitSet(), newInputs);
108+
}
109+
}
110+
111+
// Apply RexShuttle to normalize UDT types in this node's expressions
112+
RexBuilder rexBuilder = current.getCluster().getRexBuilder();
35113
RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
36-
return visited.accept(
114+
return current.accept(
37115
new RexShuttle() {
38116
@Override
39117
public RexNode visitCall(RexCall call) {
@@ -43,7 +121,6 @@ public RexNode visitCall(RexCall call) {
43121
return call;
44122
}
45123

46-
// Normalize UDT return type to standard Calcite DATE/TIME/TIMESTAMP
47124
UdtMapping m = mapping.get();
48125
SqlTypeName stdTypeName = m.getStdType();
49126
RelDataType baseType =

0 commit comments

Comments
 (0)