55
66package org .opensearch .sql .expression .function ;
77
8+ import static org .opensearch .sql .data .type .ExprCoreType .UNKNOWN ;
9+
10+ import com .google .common .annotations .VisibleForTesting ;
811import java .util .ArrayList ;
912import java .util .List ;
13+ import java .util .Optional ;
14+ import java .util .PriorityQueue ;
15+ import java .util .Set ;
16+ import java .util .function .BiPredicate ;
17+ import java .util .function .BinaryOperator ;
1018import java .util .stream .Collectors ;
1119import javax .annotation .Nullable ;
1220import org .apache .calcite .rex .RexBuilder ;
1321import org .apache .calcite .rex .RexNode ;
22+ import org .apache .commons .lang3 .tuple .Pair ;
1423import org .opensearch .sql .calcite .utils .OpenSearchTypeFactory ;
1524import org .opensearch .sql .data .type .ExprCoreType ;
1625import org .opensearch .sql .data .type .ExprType ;
17- import org .opensearch .sql .data .type .WideningTypeRule ;
1826import org .opensearch .sql .exception .ExpressionEvaluationException ;
1927
20- public class CoercionUtils {
21-
28+ public final class CoercionUtils {
2229 /**
2330 * Casts the arguments to the types specified in the typeChecker. Returns null if no combination
2431 * of parameter types matches the arguments or if casting fails.
@@ -32,15 +39,26 @@ public class CoercionUtils {
3239 RexBuilder builder , PPLTypeChecker typeChecker , List <RexNode > arguments ) {
3340 List <List <ExprType >> paramTypeCombinations = typeChecker .getParameterTypes ();
3441
35- // TODO: var args?
36-
42+ List <ExprType > sourceTypes =
43+ arguments .stream ()
44+ .map (node -> OpenSearchTypeFactory .convertRelDataTypeToExprType (node .getType ()))
45+ .collect (Collectors .toList ());
46+ // Candidate parameter signatures ordered by decreasing widening distance
47+ PriorityQueue <Pair <List <ExprType >, Integer >> rankedSignatures =
48+ new PriorityQueue <>((left , right ) -> Integer .compare (right .getValue (), left .getValue ()));
3749 for (List <ExprType > paramTypes : paramTypeCombinations ) {
38- List < RexNode > castedArguments = castArguments ( builder , paramTypes , arguments );
39- if (castedArguments != null ) {
40- return castedArguments ;
50+ int distance = distance ( sourceTypes , paramTypes );
51+ if (distance == TYPE_EQUAL ) {
52+ return castArguments ( builder , paramTypes , arguments ) ;
4153 }
54+ Optional .of (distance )
55+ .filter (value -> value != IMPOSSIBLE_WIDENING )
56+ .ifPresent (value -> rankedSignatures .add (Pair .of (paramTypes , value )));
4257 }
43- return null ;
58+ return Optional .ofNullable (rankedSignatures .peek ())
59+ .map (Pair ::getKey )
60+ .map (paramTypes -> castArguments (builder , paramTypes , arguments ))
61+ .orElse (null );
4462 }
4563
4664 /**
@@ -91,11 +109,16 @@ public class CoercionUtils {
91109 if (!argType .shouldCast (targetType )) {
92110 return arg ;
93111 }
94-
95- if ( WideningTypeRule . distance ( argType , targetType ) != WideningTypeRule . IMPOSSIBLE_WIDENING ) {
96- return builder . makeCast ( OpenSearchTypeFactory .convertExprTypeToRelDataType (targetType ), arg );
112+ if ( distance ( argType , targetType ) != IMPOSSIBLE_WIDENING ) {
113+ return builder . makeCast (
114+ OpenSearchTypeFactory .convertExprTypeToRelDataType (targetType ), arg , true , true );
97115 }
98- return null ;
116+ return resolveCommonType (argType , targetType )
117+ .map (
118+ exprType ->
119+ builder .makeCast (
120+ OpenSearchTypeFactory .convertExprTypeToRelDataType (exprType ), arg , true , true ))
121+ .orElse (null );
99122 }
100123
101124 /**
@@ -119,12 +142,8 @@ public class CoercionUtils {
119142 for (int i = 1 ; i < arguments .size (); i ++) {
120143 var type = OpenSearchTypeFactory .convertRelDataTypeToExprType (arguments .get (i ).getType ());
121144 try {
122- if (areDateAndTime (widestType , type )) {
123- // If one is date and the other is time, we consider timestamp as the widest type
124- widestType = ExprCoreType .TIMESTAMP ;
125- } else {
126- widestType = WideningTypeRule .max (widestType , type );
127- }
145+ final ExprType tempType = widestType ;
146+ widestType = resolveCommonType (widestType , type ).orElseGet (() -> max (tempType , type ));
128147 } catch (ExpressionEvaluationException e ) {
129148 // the two types are not compatible, return null
130149 return null ;
@@ -137,4 +156,125 @@ private static boolean areDateAndTime(ExprType type1, ExprType type2) {
137156 return (type1 == ExprCoreType .DATE && type2 == ExprCoreType .TIME )
138157 || (type1 == ExprCoreType .TIME && type2 == ExprCoreType .DATE );
139158 }
159+
160+ @ VisibleForTesting
161+ public static Optional <ExprType > resolveCommonType (ExprType left , ExprType right ) {
162+ return COMMON_COERCION_RULES .stream ()
163+ .map (rule -> rule .apply (left , right ))
164+ .flatMap (Optional ::stream )
165+ .findFirst ();
166+ }
167+
168+ public static boolean hasString (List <RexNode > rexNodeList ) {
169+ return rexNodeList .stream ()
170+ .map (RexNode ::getType )
171+ .map (OpenSearchTypeFactory ::convertRelDataTypeToExprType )
172+ .anyMatch (t -> t == ExprCoreType .STRING );
173+ }
174+
175+ private static final Set <ExprType > NUMBER_TYPES = ExprCoreType .numberTypes ();
176+
177+ private static final List <CoercionRule > COMMON_COERCION_RULES =
178+ List .of (
179+ CoercionRule .of (
180+ (left , right ) -> areDateAndTime (left , right ),
181+ (left , right ) -> ExprCoreType .TIMESTAMP ),
182+ CoercionRule .of (
183+ (left , right ) -> hasString (left , right ) && hasNumber (left , right ),
184+ (left , right ) -> ExprCoreType .DOUBLE ));
185+
186+ private static boolean hasString (ExprType left , ExprType right ) {
187+ return left == ExprCoreType .STRING || right == ExprCoreType .STRING ;
188+ }
189+
190+ private static boolean hasNumber (ExprType left , ExprType right ) {
191+ return NUMBER_TYPES .contains (left ) || NUMBER_TYPES .contains (right );
192+ }
193+
194+ private static boolean hasBoolean (ExprType left , ExprType right ) {
195+ return left == ExprCoreType .BOOLEAN || right == ExprCoreType .BOOLEAN ;
196+ }
197+
198+ private static class CoercionRule {
199+ private final BiPredicate <ExprType , ExprType > predicate ;
200+ private final BinaryOperator <ExprType > resolver ;
201+
202+ public CoercionRule (BiPredicate <ExprType , ExprType > predicate , BinaryOperator <ExprType > resolver ) {
203+ this .predicate = predicate ;
204+ this .resolver = resolver ;
205+ }
206+
207+ Optional <ExprType > apply (ExprType left , ExprType right ) {
208+ return predicate .test (left , right )
209+ ? Optional .of (resolver .apply (left , right ))
210+ : Optional .empty ();
211+ }
212+
213+ static CoercionRule of (
214+ BiPredicate <ExprType , ExprType > predicate , BinaryOperator <ExprType > resolver ) {
215+ return new CoercionRule (predicate , resolver );
216+ }
217+ }
218+
219+ private static final int IMPOSSIBLE_WIDENING = Integer .MAX_VALUE ;
220+ private static final int TYPE_EQUAL = 0 ;
221+
222+ private static int distance (ExprType type1 , ExprType type2 ) {
223+ return distance (type1 , type2 , TYPE_EQUAL );
224+ }
225+
226+ private static int distance (ExprType type1 , ExprType type2 , int distance ) {
227+ if (type1 == type2 ) {
228+ return distance ;
229+ } else if (type1 == UNKNOWN ) {
230+ return IMPOSSIBLE_WIDENING ;
231+ } else if (type1 == ExprCoreType .STRING && type2 == ExprCoreType .DOUBLE ) {
232+ return 1 ;
233+ } else {
234+ return type1 .getParent ().stream ()
235+ .map (parentOfType1 -> distance (parentOfType1 , type2 , distance + 1 ))
236+ .reduce (Math ::min )
237+ .get ();
238+ }
239+ }
240+
241+ /**
242+ * The max type among two types. The max is defined as follow if type1 could widen to type2, then
243+ * max is type2, vice versa if type1 couldn't widen to type2 and type2 could't widen to type1,
244+ * then throw {@link ExpressionEvaluationException}.
245+ *
246+ * @param type1 type1
247+ * @param type2 type2
248+ * @return the max type among two types.
249+ */
250+ public static ExprType max (ExprType type1 , ExprType type2 ) {
251+ int type1To2 = distance (type1 , type2 );
252+ int type2To1 = distance (type2 , type1 );
253+
254+ if (type1To2 == Integer .MAX_VALUE && type2To1 == Integer .MAX_VALUE ) {
255+ throw new ExpressionEvaluationException (
256+ String .format ("no max type of %s and %s " , type1 , type2 ));
257+ } else {
258+ return type1To2 == Integer .MAX_VALUE ? type1 : type2 ;
259+ }
260+ }
261+
262+ public static int distance (List <ExprType > sourceTypes , List <ExprType > targetTypes ) {
263+ if (sourceTypes .size () != targetTypes .size ()) {
264+ return IMPOSSIBLE_WIDENING ;
265+ }
266+
267+ int totalDistance = 0 ;
268+ for (int i = 0 ; i < sourceTypes .size (); i ++) {
269+ ExprType source = sourceTypes .get (i );
270+ ExprType target = targetTypes .get (i );
271+ int distance = distance (source , target );
272+ if (distance == IMPOSSIBLE_WIDENING ) {
273+ return IMPOSSIBLE_WIDENING ;
274+ } else {
275+ totalDistance += distance ;
276+ }
277+ }
278+ return totalDistance ;
279+ }
140280}
0 commit comments