Skip to content

Commit 18ead6c

Browse files
committed
Support function argument coercion with Calcite (opensearch-project#3914)
* Change the use of SqlTypeFamily.STRING to SqlTypeFamily.CHARACTER as the string family contains binary, which is not expected for most functions Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Implement basic argument type coercion at RelNode level Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Conform type checkers with their definition in documentation - string as an input is removed if it is not in the document - string as an input is kept if it is in the document, even if it can be implicitly cast - use PPLOperandTypes as much as possible Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Implement type widening for comparator functions - Add COMPARATORS set to BuiltinFunctionName for identifying comparison operators - Implement widenArguments method in CoercionUtils to find widest compatible type - Apply type widening to comparator functions before applying type casting - Add detailed JavaDoc to explain coercion methods Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Update error messages of datetime functions with invalid args Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Simplify datetime-string compare logic with implict coercion Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Refactor resolve with coercion Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Move down argument cast for reduce function Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Merge comparators and their IP variants so that coercion works for IP comparison - when not merging, ip comparing will also pass the type checker of Calcite's comparators Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Refactor ip comparator to comparator Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Revert "Refactor ip comparator to comparator" This reverts commit c539056. Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Revert "Merge comparators and their IP variants so that coercion works for IP comparison" This reverts commit bd9f3bb. Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Rule out ip from built-in comparator via its type checker Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Restrict CompareIP's parameter type Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Revert to previous implementation of CompareIpFunction to temporarily fix ip comparison pushdown problems (udt not correctly serialized; ip comparison is not converted to range query) Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Test argument coercion explain Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> * Fix error msg in CalcitePPLFunctionTypeTest Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> --------- Signed-off-by: Yuanchun Shen <yuanchu@amazon.com> (cherry picked from commit 484f49e)
1 parent 54a59d9 commit 18ead6c

39 files changed

Lines changed: 693 additions & 382 deletions

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
import static org.opensearch.sql.ast.expression.SpanUnit.NONE;
1212
import static org.opensearch.sql.ast.expression.SpanUnit.UNKNOWN;
1313
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;
14-
import static org.opensearch.sql.utils.DateTimeUtils.findCastType;
15-
import static org.opensearch.sql.utils.DateTimeUtils.transferCompareForDateRelated;
1614

1715
import java.math.BigDecimal;
1816
import java.util.ArrayList;
@@ -23,7 +21,6 @@
2321

2422
import java.util.Locale;
2523
import java.util.Map;
26-
import java.util.stream.Collectors;
2724
import java.util.stream.IntStream;
2825
import javax.annotation.Nullable;
2926
import lombok.RequiredArgsConstructor;
@@ -32,7 +29,6 @@
3229
import org.apache.calcite.rel.type.RelDataTypeFactory;
3330
import org.apache.calcite.rex.RexBuilder;
3431
import org.apache.calcite.rex.RexCall;
35-
import org.apache.calcite.rex.RexLambda;
3632
import org.apache.calcite.rex.RexLambdaRef;
3733
import org.apache.calcite.rex.RexNode;
3834
import org.apache.calcite.sql.SqlIntervalQualifier;
@@ -217,11 +213,8 @@ public RexNode visitIn(In node, CalcitePlanContext context) {
217213

218214
@Override
219215
public RexNode visitCompare(Compare node, CalcitePlanContext context) {
220-
RexNode leftCandidate = analyze(node.getLeft(), context);
221-
RexNode rightCandidate = analyze(node.getRight(), context);
222-
SqlTypeName castTarget = findCastType(leftCandidate, rightCandidate);
223-
final RexNode left = transferCompareForDateRelated(leftCandidate, context, castTarget);
224-
final RexNode right = transferCompareForDateRelated(rightCandidate, context, castTarget);
216+
RexNode left = analyze(node.getLeft(), context);
217+
RexNode right = analyze(node.getRight(), context);
225218
return PPLFuncImpTable.INSTANCE.resolve(context.rexBuilder, node.getOperator(), left, right);
226219
}
227220

@@ -470,19 +463,6 @@ private List<RelDataType> modifyLambdaTypeByFunction(
470463
}
471464
}
472465

473-
private List<RexNode> castArgument(
474-
List<RexNode> originalArguments, String functionName, ExtendedRexBuilder rexBuilder) {
475-
switch (functionName.toUpperCase(Locale.ROOT)) {
476-
case "REDUCE":
477-
RexLambda call = (RexLambda) originalArguments.get(2);
478-
originalArguments.set(
479-
1, rexBuilder.makeCast(call.getType(), originalArguments.get(1), true, true));
480-
return originalArguments;
481-
default:
482-
return originalArguments;
483-
}
484-
}
485-
486466
@Override
487467
public RexNode visitFunction(Function node, CalcitePlanContext context) {
488468
List<UnresolvedExpression> args = node.getFuncArgs();
@@ -509,7 +489,6 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
509489
}
510490
}
511491

512-
arguments = castArgument(arguments, node.getFuncName(), context.rexBuilder);
513492
RexNode resolvedNode =
514493
PPLFuncImpTable.INSTANCE.resolve(
515494
context.rexBuilder, node.getFuncName(), arguments.toArray(new RexNode[0]));

core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java

Lines changed: 93 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,40 +27,121 @@ private PPLOperandTypes() {}
2727
UDFOperandMetadata.wrap(
2828
(CompositeOperandTypeChecker) OperandTypes.INTEGER.or(OperandTypes.family()));
2929
public static final UDFOperandMetadata STRING =
30-
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.STRING);
30+
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.CHARACTER);
3131
public static final UDFOperandMetadata INTEGER =
3232
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.INTEGER);
3333
public static final UDFOperandMetadata NUMERIC =
3434
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.NUMERIC);
35+
36+
public static final UDFOperandMetadata NUMERIC_OPTIONAL_STRING =
37+
UDFOperandMetadata.wrap(
38+
(CompositeOperandTypeChecker)
39+
OperandTypes.NUMERIC.or(
40+
OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER)));
41+
3542
public static final UDFOperandMetadata INTEGER_INTEGER =
3643
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.INTEGER_INTEGER);
3744
public static final UDFOperandMetadata STRING_STRING =
38-
UDFOperandMetadata.wrap(OperandTypes.STRING_STRING);
45+
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.CHARACTER_CHARACTER);
3946
public static final UDFOperandMetadata NUMERIC_NUMERIC =
4047
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.NUMERIC_NUMERIC);
48+
public static final UDFOperandMetadata STRING_INTEGER =
49+
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER));
50+
4151
public static final UDFOperandMetadata NUMERIC_NUMERIC_NUMERIC =
4252
UDFOperandMetadata.wrap(
4353
OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC));
54+
public static final UDFOperandMetadata STRING_OR_INTEGER_INTEGER_INTEGER =
55+
UDFOperandMetadata.wrap(
56+
(CompositeOperandTypeChecker)
57+
OperandTypes.family(
58+
SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)
59+
.or(
60+
OperandTypes.family(
61+
SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER)));
62+
63+
public static final UDFOperandMetadata OPTIONAL_DATE_OR_TIMESTAMP_OR_NUMERIC =
64+
UDFOperandMetadata.wrap(
65+
(CompositeOperandTypeChecker)
66+
OperandTypes.DATETIME.or(OperandTypes.NUMERIC).or(OperandTypes.family()));
4467

4568
public static final UDFOperandMetadata DATETIME_OR_STRING =
4669
UDFOperandMetadata.wrap(
47-
(CompositeOperandTypeChecker) OperandTypes.DATETIME.or(OperandTypes.STRING));
70+
(CompositeOperandTypeChecker) OperandTypes.DATETIME.or(OperandTypes.CHARACTER));
71+
public static final UDFOperandMetadata TIME_OR_TIMESTAMP_OR_STRING =
72+
UDFOperandMetadata.wrap(
73+
(CompositeOperandTypeChecker)
74+
OperandTypes.CHARACTER.or(OperandTypes.TIME).or(OperandTypes.TIMESTAMP));
75+
public static final UDFOperandMetadata DATE_OR_TIMESTAMP_OR_STRING =
76+
UDFOperandMetadata.wrap(
77+
(CompositeOperandTypeChecker) OperandTypes.DATE_OR_TIMESTAMP.or(OperandTypes.CHARACTER));
78+
public static final UDFOperandMetadata DATETIME_OR_STRING_OR_INTEGER =
79+
UDFOperandMetadata.wrap(
80+
(CompositeOperandTypeChecker)
81+
OperandTypes.DATETIME.or(OperandTypes.CHARACTER).or(OperandTypes.INTEGER));
82+
83+
public static final UDFOperandMetadata DATETIME_OPTIONAL_INTEGER =
84+
UDFOperandMetadata.wrap(
85+
(CompositeOperandTypeChecker)
86+
OperandTypes.DATETIME.or(
87+
OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.INTEGER)));
88+
4889
public static final UDFOperandMetadata DATETIME_DATETIME =
4990
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME));
91+
public static final UDFOperandMetadata DATETIME_OR_STRING_STRING =
92+
UDFOperandMetadata.wrap(
93+
(CompositeOperandTypeChecker)
94+
OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER)
95+
.or(OperandTypes.CHARACTER_CHARACTER));
5096
public static final UDFOperandMetadata DATETIME_OR_STRING_DATETIME_OR_STRING =
5197
UDFOperandMetadata.wrap(
5298
(CompositeOperandTypeChecker)
53-
OperandTypes.STRING_STRING
99+
OperandTypes.CHARACTER_CHARACTER
54100
.or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME))
55-
.or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.STRING))
56-
.or(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.DATETIME)));
57-
public static final UDFOperandMetadata TIME_OR_TIMESTAMP_OR_STRING =
101+
.or(OperandTypes.family(SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER))
102+
.or(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME)));
103+
public static final UDFOperandMetadata STRING_TIMESTAMP =
104+
UDFOperandMetadata.wrap(
105+
OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP));
106+
public static final UDFOperandMetadata STRING_DATETIME =
107+
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME));
108+
public static final UDFOperandMetadata DATETIME_INTERVAL =
109+
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.DATETIME_INTERVAL);
110+
public static final UDFOperandMetadata TIME_TIME =
111+
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.TIME, SqlTypeFamily.TIME));
112+
113+
public static final UDFOperandMetadata TIMESTAMP_OR_STRING_STRING_STRING =
58114
UDFOperandMetadata.wrap(
59115
(CompositeOperandTypeChecker)
60-
OperandTypes.STRING.or(OperandTypes.TIME).or(OperandTypes.TIMESTAMP));
61-
public static final UDFOperandMetadata DATE_OR_TIMESTAMP_OR_STRING =
116+
OperandTypes.family(
117+
SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
118+
.or(
119+
OperandTypes.family(
120+
SqlTypeFamily.CHARACTER,
121+
SqlTypeFamily.CHARACTER,
122+
SqlTypeFamily.CHARACTER)));
123+
public static final UDFOperandMetadata STRING_INTEGER_DATETIME_OR_STRING =
62124
UDFOperandMetadata.wrap(
63-
(CompositeOperandTypeChecker) OperandTypes.DATE_OR_TIMESTAMP.or(OperandTypes.STRING));
64-
public static final UDFOperandMetadata STRING_TIMESTAMP =
65-
UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.TIMESTAMP));
125+
(CompositeOperandTypeChecker)
126+
OperandTypes.family(
127+
SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER)
128+
.or(
129+
OperandTypes.family(
130+
SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.DATETIME)));
131+
public static final UDFOperandMetadata INTERVAL_DATETIME_DATETIME =
132+
UDFOperandMetadata.wrap(
133+
(CompositeOperandTypeChecker)
134+
OperandTypes.family(
135+
SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME, SqlTypeFamily.DATETIME)
136+
.or(
137+
OperandTypes.family(
138+
SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME))
139+
.or(
140+
OperandTypes.family(
141+
SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME, SqlTypeFamily.CHARACTER))
142+
.or(
143+
OperandTypes.family(
144+
SqlTypeFamily.CHARACTER,
145+
SqlTypeFamily.CHARACTER,
146+
SqlTypeFamily.CHARACTER)));
66147
}

core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import java.util.Locale;
1010
import java.util.Map;
1111
import java.util.Optional;
12+
import java.util.Set;
1213
import lombok.AllArgsConstructor;
1314
import lombok.Getter;
1415
import lombok.RequiredArgsConstructor;
@@ -380,4 +381,13 @@ public static Optional<BuiltinFunctionName> ofWindowFunction(String functionName
380381
return Optional.ofNullable(
381382
WINDOW_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null));
382383
}
384+
385+
public static final Set<BuiltinFunctionName> COMPARATORS =
386+
Set.of(
387+
BuiltinFunctionName.EQUAL,
388+
BuiltinFunctionName.NOTEQUAL,
389+
BuiltinFunctionName.LESS,
390+
BuiltinFunctionName.LTE,
391+
BuiltinFunctionName.GREATER,
392+
BuiltinFunctionName.GTE);
383393
}

core/src/main/java/org/opensearch/sql/expression/function/CalciteFuncSignature.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ public PPLTypeChecker getTypeChecker() {
2727
return typeChecker;
2828
}
2929

30-
public boolean match(FunctionName functionName, List<RelDataType> paramTypeList) {
30+
public boolean match(FunctionName functionName, List<RelDataType> argTypes) {
3131
if (!functionName.equals(this.functionName)) return false;
3232
// For complex type checkers (e.g., OperandTypes.COMPARABLE_UNORDERED_COMPARABLE_UNORDERED),
3333
// the typeChecker will be null because only simple family-based type checks are currently
3434
// supported.
3535
if (typeChecker == null) return true;
36-
return typeChecker.checkOperandTypes(paramTypeList);
36+
return typeChecker.checkOperandTypes(argTypes);
3737
}
3838
}
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.expression.function;
7+
8+
import java.util.ArrayList;
9+
import java.util.List;
10+
import java.util.stream.Collectors;
11+
import javax.annotation.Nullable;
12+
import org.apache.calcite.rex.RexBuilder;
13+
import org.apache.calcite.rex.RexNode;
14+
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
15+
import org.opensearch.sql.data.type.ExprCoreType;
16+
import org.opensearch.sql.data.type.ExprType;
17+
import org.opensearch.sql.data.type.WideningTypeRule;
18+
import org.opensearch.sql.exception.ExpressionEvaluationException;
19+
20+
public class CoercionUtils {
21+
22+
/**
23+
* Casts the arguments to the types specified in the typeChecker. Returns null if no combination
24+
* of parameter types matches the arguments or if casting fails.
25+
*
26+
* @param builder RexBuilder to create casts
27+
* @param typeChecker PPLTypeChecker that provides the parameter types
28+
* @param arguments List of RexNode arguments to be cast
29+
* @return List of cast RexNode arguments or null if casting fails
30+
*/
31+
public static @Nullable List<RexNode> castArguments(
32+
RexBuilder builder, PPLTypeChecker typeChecker, List<RexNode> arguments) {
33+
List<List<ExprType>> paramTypeCombinations = typeChecker.getParameterTypes();
34+
35+
// TODO: var args?
36+
37+
for (List<ExprType> paramTypes : paramTypeCombinations) {
38+
List<RexNode> castedArguments = castArguments(builder, paramTypes, arguments);
39+
if (castedArguments != null) {
40+
return castedArguments;
41+
}
42+
}
43+
return null;
44+
}
45+
46+
/**
47+
* Widen the arguments to the widest type found among them. If no widest type can be determined,
48+
* returns null.
49+
*
50+
* @param builder RexBuilder to create casts
51+
* @param arguments List of RexNode arguments to be widened
52+
* @return List of widened RexNode arguments or null if no widest type can be determined
53+
*/
54+
public static @Nullable List<RexNode> widenArguments(
55+
RexBuilder builder, List<RexNode> arguments) {
56+
// TODO: Add test on e.g. IP
57+
ExprType widestType = findWidestType(arguments);
58+
if (widestType == null) {
59+
return null; // No widest type found, return null
60+
}
61+
return arguments.stream().map(arg -> cast(builder, widestType, arg)).collect(Collectors.toList());
62+
}
63+
64+
/**
65+
* Casts the arguments to the types specified in paramTypes. Returns null if the number of
66+
* parameters does not match or if casting fails.
67+
*/
68+
private static @Nullable List<RexNode> castArguments(
69+
RexBuilder builder, List<ExprType> paramTypes, List<RexNode> arguments) {
70+
if (paramTypes.size() != arguments.size()) {
71+
return null; // Skip if the number of parameters does not match
72+
}
73+
74+
List<RexNode> castedArguments = new ArrayList<>();
75+
for (int i = 0; i < paramTypes.size(); i++) {
76+
ExprType toType = paramTypes.get(i);
77+
RexNode arg = arguments.get(i);
78+
79+
RexNode castedArg = cast(builder, toType, arg);
80+
81+
if (castedArg == null) {
82+
return null;
83+
}
84+
castedArguments.add(castedArg);
85+
}
86+
return castedArguments;
87+
}
88+
89+
private static @Nullable RexNode cast(RexBuilder builder, ExprType targetType, RexNode arg) {
90+
ExprType argType = OpenSearchTypeFactory.convertRelDataTypeToExprType(arg.getType());
91+
if (!argType.shouldCast(targetType)) {
92+
return arg;
93+
}
94+
95+
if (WideningTypeRule.distance(argType, targetType) != WideningTypeRule.IMPOSSIBLE_WIDENING) {
96+
return builder.makeCast(OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg);
97+
}
98+
return null;
99+
}
100+
101+
/**
102+
* Finds the widest type among the given arguments. The widest type is determined by applying the
103+
* widening type rule to each pair of types in the arguments.
104+
*
105+
* @param arguments List of RexNode arguments to find the widest type from
106+
* @return the widest ExprType if found, otherwise null
107+
*/
108+
private static @Nullable ExprType findWidestType(List<RexNode> arguments) {
109+
if (arguments.isEmpty()) {
110+
return null; // No arguments to process
111+
}
112+
ExprType widestType =
113+
OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.get(0).getType());
114+
if (arguments.size() == 1) {
115+
return widestType;
116+
}
117+
118+
// Iterate pairwise through the arguments and find the widest type
119+
for (int i = 1; i < arguments.size(); i++) {
120+
var type = OpenSearchTypeFactory.convertRelDataTypeToExprType(arguments.get(i).getType());
121+
try {
122+
if (areDateAndTime(widestType, type)) {
123+
// If one is date and the other is time, we consider timestamp as the widest type
124+
widestType = ExprCoreType.TIMESTAMP;
125+
} else {
126+
widestType = WideningTypeRule.max(widestType, type);
127+
}
128+
} catch (ExpressionEvaluationException e) {
129+
// the two types are not compatible, return null
130+
return null;
131+
}
132+
}
133+
return widestType;
134+
}
135+
136+
private static boolean areDateAndTime(ExprType type1, ExprType type2) {
137+
return (type1 == ExprCoreType.DATE && type2 == ExprCoreType.TIME)
138+
|| (type1 == ExprCoreType.TIME && type2 == ExprCoreType.DATE);
139+
}
140+
}

0 commit comments

Comments
 (0)