55
66package org .opensearch .sql .api .spec .datetime ;
77
8+ import java .util .List ;
89import java .util .Optional ;
910import org .apache .calcite .rel .RelHomogeneousShuttle ;
1011import org .apache .calcite .rel .RelNode ;
12+ import org .apache .calcite .rel .core .Filter ;
13+ import org .apache .calcite .rel .core .Join ;
14+ import org .apache .calcite .rel .core .Project ;
1115import org .apache .calcite .rel .type .RelDataType ;
1216import org .apache .calcite .rel .type .RelDataTypeFactory ;
17+ import org .apache .calcite .rel .type .RelDataTypeField ;
1318import org .apache .calcite .rex .RexBuilder ;
1419import org .apache .calcite .rex .RexCall ;
20+ import org .apache .calcite .rex .RexInputRef ;
1521import org .apache .calcite .rex .RexNode ;
1622import org .apache .calcite .rex .RexShuttle ;
23+ import org .apache .calcite .rex .RexUtil ;
1724import org .apache .calcite .sql .type .SqlTypeName ;
1825import org .opensearch .sql .api .spec .datetime .DatetimeExtension .UdtMapping ;
1926
2027/**
2128 * Temporary patch that rewrites datetime UDT return types on RexCall nodes to standard Calcite
22- * types.
29+ * types. Also re-aligns {@link RexInputRef} declared types against the (already-normalized) child's
30+ * row type so a parent {@link org.apache.calcite.rel.core.Filter}/{@link
31+ * org.apache.calcite.rel.core.Project}/{@link org.apache.calcite.rel.core.Aggregate} constructor
32+ * assertion does not see {@code ref:EXPR_TIMESTAMP VARCHAR input:TIMESTAMP(9)} type mismatch when
33+ * an upstream node just got its UDT-typed column normalized to a standard type.
2334 *
2435 * <p>Not a singleton: {@link RelHomogeneousShuttle} inherits a stateful {@code stack} field from
2536 * {@link org.apache.calcite.rel.RelShuttleImpl}, so a fresh instance must be used per plan().
@@ -28,31 +39,147 @@ class DatetimeUdtNormalizeRule extends RelHomogeneousShuttle {
2839
2940 @ Override
3041 public RelNode visit (RelNode other ) {
31- RelNode visited = super .visit (other );
32- RexBuilder rexBuilder = visited .getCluster ().getRexBuilder ();
42+ // Recurse into children first so each child's row type is fully normalized; then re-align
43+ // the parent's RexNodes (RexCall return types AND RexInputRef stored types) against the new
44+ // input schema BEFORE invoking the parent's copy(). Going through super.visit() would call
45+ // parent.copy(traitSet, inputs) right after each child swap, firing the parent's
46+ // constructor assertion (ref:EXPR_TIMESTAMP VARCHAR vs input:TIMESTAMP(9)) before we get to
47+ // patch the stale RexInputRefs.
48+ List <RelNode > normalizedChildren = recurseChildren (other );
49+ boolean inputsChanged = false ;
50+ for (int i = 0 ; i < normalizedChildren .size (); i ++) {
51+ if (normalizedChildren .get (i ) != other .getInputs ().get (i )) {
52+ inputsChanged = true ;
53+ break ;
54+ }
55+ }
56+ RexBuilder rexBuilder = other .getCluster ().getRexBuilder ();
3357 RelDataTypeFactory typeFactory = rexBuilder .getTypeFactory ();
34- return visited .accept (
35- new RexShuttle () {
36- @ Override
37- public RexNode visitCall (RexCall call ) {
38- call = (RexCall ) super .visitCall (call );
39- Optional <UdtMapping > mapping = UdtMapping .fromUdtType (call .getType ());
40- if (mapping .isEmpty ()) {
41- return call ;
42- }
43-
44- // Normalize UDT return type to standard Calcite DATE/TIME/TIMESTAMP
45- UdtMapping m = mapping .get ();
46- SqlTypeName stdTypeName = m .getStdType ();
47- RelDataType baseType =
48- stdTypeName .allowsPrec ()
49- ? typeFactory .createSqlType (
50- stdTypeName , typeFactory .getTypeSystem ().getMaxPrecision (stdTypeName ))
51- : typeFactory .createSqlType (stdTypeName );
52- RelDataType stdType =
53- typeFactory .createTypeWithNullability (baseType , call .getType ().isNullable ());
54- return call .clone (stdType , call .getOperands ());
55- }
56- });
58+ List <RelDataType > inputFieldTypes = concatFieldTypes (normalizedChildren );
59+ NormalizeShuttle shuttle = new NormalizeShuttle (typeFactory , inputFieldTypes );
60+ return rebuild (other , normalizedChildren , inputsChanged , shuttle );
61+ }
62+
63+ /** Recurse into each child via {@link #visit(RelNode)}, returning the (possibly new) children. */
64+ private List <RelNode > recurseChildren (RelNode rel ) {
65+ java .util .ArrayList <RelNode > out = new java .util .ArrayList <>(rel .getInputs ().size ());
66+ stack .push (rel );
67+ try {
68+ for (RelNode input : rel .getInputs ()) {
69+ out .add (input .accept (this ));
70+ }
71+ } finally {
72+ stack .pop ();
73+ }
74+ return out ;
75+ }
76+
77+ /**
78+ * Concatenated field types of all children, in input-index order; matches RexInputRef indexing.
79+ */
80+ private static List <RelDataType > concatFieldTypes (List <RelNode > children ) {
81+ java .util .ArrayList <RelDataType > out = new java .util .ArrayList <>();
82+ for (RelNode child : children ) {
83+ for (RelDataTypeField f : child .getRowType ().getFieldList ()) {
84+ out .add (f .getType ());
85+ }
86+ }
87+ return out ;
88+ }
89+
90+ /**
91+ * Reassemble {@code rel} with normalized children + RexShuttle-rewritten RexNodes. We dispatch
92+ * per-RelNode type to the {@code copy(traits, input, ...rex)} variant that takes both new input
93+ * and new rex args together — using {@code copy(traits, inputs)} would copy the original (stale)
94+ * RexNodes with the new input, firing the parent's constructor assertion.
95+ */
96+ private static RelNode rebuild (
97+ RelNode rel , List <RelNode > children , boolean inputsChanged , NormalizeShuttle shuttle ) {
98+ if (rel instanceof Project project ) {
99+ List <RexNode > rewrittenExps = shuttle .apply (project .getProjects ());
100+ boolean expsChanged = rewrittenExps != project .getProjects ();
101+ if (!inputsChanged && !expsChanged ) {
102+ return project ;
103+ }
104+ RelDataType newRowType =
105+ RexUtil .createStructType (
106+ shuttle .typeFactory , rewrittenExps , project .getRowType ().getFieldNames (), null );
107+ return project .copy (project .getTraitSet (), children .get (0 ), rewrittenExps , newRowType );
108+ }
109+ if (rel instanceof Filter filter ) {
110+ RexNode rewrittenCondition = filter .getCondition ().accept (shuttle );
111+ boolean conditionChanged = rewrittenCondition != filter .getCondition ();
112+ if (!inputsChanged && !conditionChanged ) {
113+ return filter ;
114+ }
115+ return filter .copy (filter .getTraitSet (), children .get (0 ), rewrittenCondition );
116+ }
117+ if (rel instanceof Join join ) {
118+ RexNode rewrittenCondition = join .getCondition ().accept (shuttle );
119+ boolean conditionChanged = rewrittenCondition != join .getCondition ();
120+ if (!inputsChanged && !conditionChanged ) {
121+ return join ;
122+ }
123+ return join .copy (
124+ join .getTraitSet (),
125+ rewrittenCondition ,
126+ children .get (0 ),
127+ children .get (1 ),
128+ join .getJoinType (),
129+ join .isSemiJoinDone ());
130+ }
131+ // Aggregate, Sort, TableScan, Union, etc.: row-type-stable nodes (no RexNodes carrying stale
132+ // input refs in a way that requires per-type rebuild) — accept(RexShuttle) handles their
133+ // RexNode payloads (e.g. Sort collations) and we plug the new inputs via copy(traits, inputs).
134+ RelNode withRex = rel .accept (shuttle );
135+ if (!inputsChanged ) {
136+ return withRex ;
137+ }
138+ return withRex .copy (withRex .getTraitSet (), children );
139+ }
140+
141+ /** Rewrites UDT return types on calls and stale UDT types on input refs to the standard type. */
142+ private static final class NormalizeShuttle extends RexShuttle {
143+ private final RelDataTypeFactory typeFactory ;
144+ private final List <RelDataType > inputFieldTypes ;
145+
146+ NormalizeShuttle (RelDataTypeFactory typeFactory , List <RelDataType > inputFieldTypes ) {
147+ this .typeFactory = typeFactory ;
148+ this .inputFieldTypes = inputFieldTypes ;
149+ }
150+
151+ @ Override
152+ public RexNode visitCall (RexCall call ) {
153+ call = (RexCall ) super .visitCall (call );
154+ Optional <UdtMapping > mapping = UdtMapping .fromUdtType (call .getType ());
155+ if (mapping .isEmpty ()) {
156+ return call ;
157+ }
158+ return call .clone (toStdType (call .getType (), mapping .get ()), call .getOperands ());
159+ }
160+
161+ @ Override
162+ public RexNode visitInputRef (RexInputRef inputRef ) {
163+ // Re-align stored type against the post-normalization input field type at this index.
164+ int index = inputRef .getIndex ();
165+ if (index < 0 || index >= inputFieldTypes .size ()) {
166+ return inputRef ;
167+ }
168+ RelDataType actual = inputFieldTypes .get (index );
169+ if (actual .equals (inputRef .getType ())) {
170+ return inputRef ;
171+ }
172+ return new RexInputRef (index , actual );
173+ }
174+
175+ private RelDataType toStdType (RelDataType original , UdtMapping mapping ) {
176+ SqlTypeName stdTypeName = mapping .getStdType ();
177+ RelDataType baseType =
178+ stdTypeName .allowsPrec ()
179+ ? typeFactory .createSqlType (
180+ stdTypeName , typeFactory .getTypeSystem ().getMaxPrecision (stdTypeName ))
181+ : typeFactory .createSqlType (stdTypeName );
182+ return typeFactory .createTypeWithNullability (baseType , original .isNullable ());
183+ }
57184 }
58185}
0 commit comments