Skip to content

Commit 5564501

Browse files
authored
[Backport 2.19-dev] Support Automatic Type Conversion for REX/SPATH/PARSE Command Extraction (opensearch-project#4599) (opensearch-project#4650)
* Support Automatic Type Conversion for REX/SPATH/PARSE Command Extractions (opensearch-project#4599) --------- Signed-off-by: Peng Huo <penghuo@gmail.com>
1 parent b19c967 commit 5564501

11 files changed

Lines changed: 635 additions & 96 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -446,9 +446,26 @@ public RexNode visitWindowFunction(WindowFunction node, CalcitePlanContext conte
446446
(arguments.isEmpty() || arguments.size() == 1)
447447
? Collections.emptyList()
448448
: arguments.subList(1, arguments.size());
449-
PPLFuncImpTable.INSTANCE.validateAggFunctionSignature(functionName, field, args);
450-
return PlanUtils.makeOver(
451-
context, functionName, field, args, partitions, List.of(), node.getWindowFrame());
449+
List<RexNode> nodes =
450+
PPLFuncImpTable.INSTANCE.validateAggFunctionSignature(
451+
functionName, field, args, context.rexBuilder);
452+
return nodes != null
453+
? PlanUtils.makeOver(
454+
context,
455+
functionName,
456+
nodes.get(0),
457+
nodes.size() <= 1 ? Collections.emptyList() : nodes.subList(1, nodes.size()),
458+
partitions,
459+
List.of(),
460+
node.getWindowFrame())
461+
: PlanUtils.makeOver(
462+
context,
463+
functionName,
464+
field,
465+
args,
466+
partitions,
467+
List.of(),
468+
node.getWindowFrame());
452469
})
453470
.orElseThrow(
454471
() ->

core/src/main/java/org/opensearch/sql/expression/function/CoercionUtils.java

Lines changed: 159 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,27 @@
55

66
package org.opensearch.sql.expression.function;
77

8+
import static org.opensearch.sql.data.type.ExprCoreType.UNKNOWN;
9+
10+
import com.google.common.annotations.VisibleForTesting;
811
import java.util.ArrayList;
912
import java.util.List;
13+
import java.util.Optional;
14+
import java.util.PriorityQueue;
15+
import java.util.Set;
16+
import java.util.function.BiPredicate;
17+
import java.util.function.BinaryOperator;
1018
import java.util.stream.Collectors;
1119
import javax.annotation.Nullable;
1220
import org.apache.calcite.rex.RexBuilder;
1321
import org.apache.calcite.rex.RexNode;
22+
import org.apache.commons.lang3.tuple.Pair;
1423
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
1524
import org.opensearch.sql.data.type.ExprCoreType;
1625
import org.opensearch.sql.data.type.ExprType;
17-
import org.opensearch.sql.data.type.WideningTypeRule;
1826
import org.opensearch.sql.exception.ExpressionEvaluationException;
1927

20-
public class CoercionUtils {
21-
28+
public final class CoercionUtils {
2229
/**
2330
* Casts the arguments to the types specified in the typeChecker. Returns null if no combination
2431
* of parameter types matches the arguments or if casting fails.
@@ -32,15 +39,26 @@ public class CoercionUtils {
3239
RexBuilder builder, PPLTypeChecker typeChecker, List<RexNode> arguments) {
3340
List<List<ExprType>> paramTypeCombinations = typeChecker.getParameterTypes();
3441

35-
// TODO: var args?
36-
42+
List<ExprType> sourceTypes =
43+
arguments.stream()
44+
.map(node -> OpenSearchTypeFactory.convertRelDataTypeToExprType(node.getType()))
45+
.collect(Collectors.toList());
46+
// Candidate parameter signatures ordered by decreasing widening distance
47+
PriorityQueue<Pair<List<ExprType>, Integer>> rankedSignatures =
48+
new PriorityQueue<>((left, right) -> Integer.compare(right.getValue(), left.getValue()));
3749
for (List<ExprType> paramTypes : paramTypeCombinations) {
38-
List<RexNode> castedArguments = castArguments(builder, paramTypes, arguments);
39-
if (castedArguments != null) {
40-
return castedArguments;
50+
int distance = distance(sourceTypes, paramTypes);
51+
if (distance == TYPE_EQUAL) {
52+
return castArguments(builder, paramTypes, arguments);
4153
}
54+
Optional.of(distance)
55+
.filter(value -> value != IMPOSSIBLE_WIDENING)
56+
.ifPresent(value -> rankedSignatures.add(Pair.of(paramTypes, value)));
4257
}
43-
return null;
58+
return Optional.ofNullable(rankedSignatures.peek())
59+
.map(Pair::getKey)
60+
.map(paramTypes -> castArguments(builder, paramTypes, arguments))
61+
.orElse(null);
4462
}
4563

4664
/**
@@ -91,11 +109,16 @@ public class CoercionUtils {
91109
if (!argType.shouldCast(targetType)) {
92110
return arg;
93111
}
94-
95-
if (WideningTypeRule.distance(argType, targetType) != WideningTypeRule.IMPOSSIBLE_WIDENING) {
96-
return builder.makeCast(OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg);
112+
if (distance(argType, targetType) != IMPOSSIBLE_WIDENING) {
113+
return builder.makeCast(
114+
OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg, true, true);
97115
}
98-
return null;
116+
return resolveCommonType(argType, targetType)
117+
.map(
118+
exprType ->
119+
builder.makeCast(
120+
OpenSearchTypeFactory.convertExprTypeToRelDataType(exprType), arg, true, true))
121+
.orElse(null);
99122
}
100123

101124
/**
@@ -119,12 +142,8 @@ public class CoercionUtils {
119142
for (int i = 1; i < arguments.size(); i++) {
120143
var type = OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.get(i).getType());
121144
try {
122-
if (areDateAndTime(widestType, type)) {
123-
// If one is date and the other is time, we consider timestamp as the widest type
124-
widestType = ExprCoreType.TIMESTAMP;
125-
} else {
126-
widestType = WideningTypeRule.max(widestType, type);
127-
}
145+
final ExprType tempType = widestType;
146+
widestType = resolveCommonType(widestType, type).orElseGet(() -> max(tempType, type));
128147
} catch (ExpressionEvaluationException e) {
129148
// the two types are not compatible, return null
130149
return null;
@@ -137,4 +156,125 @@ private static boolean areDateAndTime(ExprType type1, ExprType type2) {
137156
return (type1 == ExprCoreType.DATE && type2 == ExprCoreType.TIME)
138157
|| (type1 == ExprCoreType.TIME && type2 == ExprCoreType.DATE);
139158
}
159+
160+
@VisibleForTesting
161+
public static Optional<ExprType> resolveCommonType(ExprType left, ExprType right) {
162+
return COMMON_COERCION_RULES.stream()
163+
.map(rule -> rule.apply(left, right))
164+
.flatMap(Optional::stream)
165+
.findFirst();
166+
}
167+
168+
public static boolean hasString(List<RexNode> rexNodeList) {
169+
return rexNodeList.stream()
170+
.map(RexNode::getType)
171+
.map(OpenSearchTypeFactory::convertRelDataTypeToExprType)
172+
.anyMatch(t -> t == ExprCoreType.STRING);
173+
}
174+
175+
private static final Set<ExprType> NUMBER_TYPES = ExprCoreType.numberTypes();
176+
177+
private static final List<CoercionRule> COMMON_COERCION_RULES =
178+
List.of(
179+
CoercionRule.of(
180+
(left, right) -> areDateAndTime(left, right),
181+
(left, right) -> ExprCoreType.TIMESTAMP),
182+
CoercionRule.of(
183+
(left, right) -> hasString(left, right) && hasNumber(left, right),
184+
(left, right) -> ExprCoreType.DOUBLE));
185+
186+
private static boolean hasString(ExprType left, ExprType right) {
187+
return left == ExprCoreType.STRING || right == ExprCoreType.STRING;
188+
}
189+
190+
private static boolean hasNumber(ExprType left, ExprType right) {
191+
return NUMBER_TYPES.contains(left) || NUMBER_TYPES.contains(right);
192+
}
193+
194+
private static boolean hasBoolean(ExprType left, ExprType right) {
195+
return left == ExprCoreType.BOOLEAN || right == ExprCoreType.BOOLEAN;
196+
}
197+
198+
private static class CoercionRule {
199+
private final BiPredicate<ExprType, ExprType> predicate;
200+
private final BinaryOperator<ExprType> resolver;
201+
202+
public CoercionRule(BiPredicate<ExprType, ExprType> predicate, BinaryOperator<ExprType> resolver) {
203+
this.predicate = predicate;
204+
this.resolver = resolver;
205+
}
206+
207+
Optional<ExprType> apply(ExprType left, ExprType right) {
208+
return predicate.test(left, right)
209+
? Optional.of(resolver.apply(left, right))
210+
: Optional.empty();
211+
}
212+
213+
static CoercionRule of(
214+
BiPredicate<ExprType, ExprType> predicate, BinaryOperator<ExprType> resolver) {
215+
return new CoercionRule(predicate, resolver);
216+
}
217+
}
218+
219+
private static final int IMPOSSIBLE_WIDENING = Integer.MAX_VALUE;
220+
private static final int TYPE_EQUAL = 0;
221+
222+
private static int distance(ExprType type1, ExprType type2) {
223+
return distance(type1, type2, TYPE_EQUAL);
224+
}
225+
226+
private static int distance(ExprType type1, ExprType type2, int distance) {
227+
if (type1 == type2) {
228+
return distance;
229+
} else if (type1 == UNKNOWN) {
230+
return IMPOSSIBLE_WIDENING;
231+
} else if (type1 == ExprCoreType.STRING && type2 == ExprCoreType.DOUBLE) {
232+
return 1;
233+
} else {
234+
return type1.getParent().stream()
235+
.map(parentOfType1 -> distance(parentOfType1, type2, distance + 1))
236+
.reduce(Math::min)
237+
.get();
238+
}
239+
}
240+
241+
/**
242+
* The max type among two types. The max is defined as follow if type1 could widen to type2, then
243+
* max is type2, vice versa if type1 couldn't widen to type2 and type2 could't widen to type1,
244+
* then throw {@link ExpressionEvaluationException}.
245+
*
246+
* @param type1 type1
247+
* @param type2 type2
248+
* @return the max type among two types.
249+
*/
250+
public static ExprType max(ExprType type1, ExprType type2) {
251+
int type1To2 = distance(type1, type2);
252+
int type2To1 = distance(type2, type1);
253+
254+
if (type1To2 == Integer.MAX_VALUE && type2To1 == Integer.MAX_VALUE) {
255+
throw new ExpressionEvaluationException(
256+
String.format("no max type of %s and %s ", type1, type2));
257+
} else {
258+
return type1To2 == Integer.MAX_VALUE ? type1 : type2;
259+
}
260+
}
261+
262+
public static int distance(List<ExprType> sourceTypes, List<ExprType> targetTypes) {
263+
if (sourceTypes.size() != targetTypes.size()) {
264+
return IMPOSSIBLE_WIDENING;
265+
}
266+
267+
int totalDistance = 0;
268+
for (int i = 0; i < sourceTypes.size(); i++) {
269+
ExprType source = sourceTypes.get(i);
270+
ExprType target = targetTypes.get(i);
271+
int distance = distance(source, target);
272+
if (distance == IMPOSSIBLE_WIDENING) {
273+
return IMPOSSIBLE_WIDENING;
274+
} else {
275+
totalDistance += distance;
276+
}
277+
}
278+
return totalDistance;
279+
}
140280
}

0 commit comments

Comments
 (0)