Skip to content

Commit a7f252f

Browse files
committed
Implement composite type checker
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent e639dcf commit a7f252f

9 files changed

Lines changed: 200 additions & 93 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/utils/BuiltinFunctionUtils.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@
6565
import org.opensearch.sql.calcite.udf.datetimeUDF.StrToDateFunction;
6666
import org.opensearch.sql.calcite.udf.datetimeUDF.SysdateFunction;
6767
import org.opensearch.sql.calcite.udf.datetimeUDF.TimeAddSubFunction;
68-
import org.opensearch.sql.calcite.udf.datetimeUDF.TimeDiffFunction;
6968
import org.opensearch.sql.calcite.udf.datetimeUDF.TimeFormatFunction;
7069
import org.opensearch.sql.calcite.udf.datetimeUDF.TimeFunction;
7170
import org.opensearch.sql.calcite.udf.datetimeUDF.TimeToSecondFunction;
@@ -231,8 +230,6 @@ static SqlOperator translate(String op) {
231230
return TransferUserDefinedFunction(SysdateFunction.class, "SYSDATE", timestampInference);
232231
case "TIME":
233232
return TransferUserDefinedFunction(TimeFunction.class, "TIME", timeInference);
234-
case "TIMEDIFF":
235-
return TransferUserDefinedFunction(TimeDiffFunction.class, "TIMEDIFF", timeInference);
236233
case "TIME_TO_SEC":
237234
return TransferUserDefinedFunction(
238235
TimeToSecondFunction.class, "TIME_TO_SEC", ReturnTypes.BIGINT);

core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java

Lines changed: 5 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import java.util.Objects;
2525
import org.apache.calcite.adapter.enumerable.NotNullImplementor;
2626
import org.apache.calcite.adapter.enumerable.NullPolicy;
27-
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
2827
import org.apache.calcite.linq4j.tree.Expression;
2928
import org.apache.calcite.linq4j.tree.Expressions;
3029
import org.apache.calcite.linq4j.tree.Types;
@@ -55,8 +54,7 @@
5554
import org.opensearch.sql.executor.QueryType;
5655
import org.opensearch.sql.expression.function.FunctionProperties;
5756
import org.opensearch.sql.expression.function.ImplementorUDF;
58-
import org.opensearch.sql.expression.function.PPLTypeChecker;
59-
import org.opensearch.sql.expression.function.UDFTypeChecker;
57+
import org.opensearch.sql.expression.function.UDFOperandMetadata;
6058

6159
public class UserDefinedFunctionUtils {
6260
public static SqlReturnTypeInference INTEGER_FORCE_NULLABLE =
@@ -293,30 +291,18 @@ public static List<Expression> convertToExprValues(
293291
return exprValues;
294292
}
295293

296-
private static List<Expression> prependTimestampAsProperty(
297-
List<Expression> operands, RexToLixTranslator translator) {
298-
List<Expression> operandsWithProperties = new ArrayList<>(operands);
299-
Expression properties =
300-
Expressions.call(
301-
UserDefinedFunctionUtils.class, "restoreFunctionProperties", translator.getRoot());
302-
operandsWithProperties.addFirst(properties);
303-
return Collections.unmodifiableList(operandsWithProperties);
304-
}
305-
306-
public static ImplementorUDF adaptExprMethodWithPropertiesToUDF(
294+
public static ImplementorUDF adaptExprMethodToUDF(
307295
java.lang.reflect.Type type,
308296
String methodName,
309297
SqlReturnTypeInference returnTypeInference,
310298
NullPolicy nullPolicy,
311-
UDFTypeChecker typeChecker) {
299+
UDFOperandMetadata typeChecker) {
312300
NotNullImplementor implementor =
313301
(translator, call, translatedOperands) -> {
314302
List<Expression> operands =
315303
convertToExprValues(
316304
translatedOperands, call.getOperands().stream().map(RexNode::getType).toList());
317-
List<Expression> operandsWithProperties =
318-
prependTimestampAsProperty(operands, translator);
319-
Expression exprResult = Expressions.call(type, methodName, operandsWithProperties);
305+
Expression exprResult = Expressions.call(type, methodName, operands);
320306
return Expressions.call(exprResult, "valueForCalcite");
321307
};
322308
return new ImplementorUDF(implementor, nullPolicy) {
@@ -326,44 +312,7 @@ public SqlReturnTypeInference getReturnTypeInference() {
326312
}
327313

328314
@Override
329-
public UDFTypeChecker getOperandTypeChecker() {
330-
return typeChecker;
331-
}
332-
333-
};
334-
}
335-
336-
public static ImplementorUDF adaptExprMethodWithPropertiesToUDF(
337-
java.lang.reflect.Type type,
338-
String methodName,
339-
SqlReturnTypeInference returnTypeInference,
340-
NullPolicy nullPolicy) {
341-
return adaptExprMethodWithPropertiesToUDF(
342-
type, methodName, returnTypeInference, nullPolicy, null);
343-
}
344-
345-
public static ImplementorUDF adaptExprMethodToUDF(
346-
java.lang.reflect.Type type,
347-
String methodName,
348-
SqlReturnTypeInference returnTypeInference,
349-
NullPolicy nullPolicy,
350-
UDFTypeChecker typeChecker) {
351-
NotNullImplementor implementor =
352-
(translator, call, translatedOperands) -> {
353-
List<Expression> operands =
354-
convertToExprValues(
355-
translatedOperands, call.getOperands().stream().map(RexNode::getType).toList());
356-
Expression exprResult = Expressions.call(type, methodName, operands);
357-
return Expressions.call(exprResult, "valueForCalcite");
358-
};
359-
return new ImplementorUDF(implementor, nullPolicy) {
360-
@Override
361-
public SqlReturnTypeInference getReturnTypeInference() {
362-
return returnTypeInference;
363-
}
364-
365-
@Override
366-
public UDFTypeChecker getOperandTypeChecker() {
315+
public UDFOperandMetadata getOperandMetadata() {
367316
return typeChecker;
368317
}
369318
};

core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,38 @@
88
import java.lang.reflect.InvocationTargetException;
99
import java.lang.reflect.Method;
1010
import java.util.List;
11+
import org.apache.calcite.adapter.enumerable.NullPolicy;
1112
import org.apache.calcite.adapter.enumerable.RexImpTable;
1213
import org.apache.calcite.adapter.enumerable.RexImpTable.RexCallImplementor;
1314
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
1415
import org.apache.calcite.linq4j.tree.Expression;
1516
import org.apache.calcite.rex.RexCall;
1617
import org.apache.calcite.sql.SqlOperator;
18+
import org.apache.calcite.sql.type.CompositeOperandTypeChecker;
19+
import org.apache.calcite.sql.type.OperandTypes;
20+
import org.apache.calcite.sql.type.SqlTypeFamily;
1721
import org.apache.calcite.sql.util.ReflectiveSqlOperatorTable;
1822
import org.apache.calcite.util.BuiltInMethod;
23+
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
24+
import org.opensearch.sql.expression.datetime.DateTimeFunctions;
1925

2026
/** Defines functions and operators that are implemented only by PPL */
2127
public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
2228

2329
public static final SqlOperator SPAN = new SpanFunctionImpl().toUDF("SPAN");
2430

31+
public static final SqlOperator TIMEDIFF =
32+
UserDefinedFunctionUtils.adaptExprMethodToUDF(
33+
DateTimeFunctions.class,
34+
"exprTimeDiff",
35+
UserDefinedFunctionUtils.timeInference,
36+
NullPolicy.ANY,
37+
UDFOperandMetadata.wrap(
38+
(CompositeOperandTypeChecker)
39+
OperandTypes.family(SqlTypeFamily.TIME, SqlTypeFamily.TIME)
40+
.or(OperandTypes.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING))))
41+
.toUDF("TIME_DIFF");
42+
2543
/**
2644
* Invoking an implementor registered in {@link RexImpTable}, need to use reflection since they're
2745
* all private Use method directly in {@link BuiltInMethod} if possible, most operators'

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,12 @@
6363
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBSTR;
6464
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBSTRING;
6565
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBTRACT;
66+
import static org.opensearch.sql.expression.function.BuiltinFunctionName.TIMEDIFF;
6667
import static org.opensearch.sql.expression.function.BuiltinFunctionName.TRIM;
6768
import static org.opensearch.sql.expression.function.BuiltinFunctionName.TYPEOF;
6869
import static org.opensearch.sql.expression.function.BuiltinFunctionName.UPPER;
6970
import static org.opensearch.sql.expression.function.BuiltinFunctionName.XOR;
71+
import static org.opensearch.sql.expression.function.PPLTypeChecker.compositeWrapper;
7072
import static org.opensearch.sql.expression.function.PPLTypeChecker.family;
7173
import static org.opensearch.sql.expression.function.PPLTypeChecker.familyWrapper;
7274

@@ -85,10 +87,12 @@
8587
import org.apache.calcite.sql.fun.SqlLibraryOperators;
8688
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
8789
import org.apache.calcite.sql.fun.SqlTrimFunction.Flag;
90+
import org.apache.calcite.sql.type.CompositeOperandTypeChecker;
8891
import org.apache.calcite.sql.type.ImplicitCastOperandTypeChecker;
8992
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
9093
import org.apache.calcite.sql.type.SqlTypeFamily;
9194
import org.apache.calcite.sql.type.SqlTypeName;
95+
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
9296
import org.checkerframework.checker.nullness.qual.Nullable;
9397
import org.opensearch.sql.executor.QueryType;
9498

@@ -210,28 +214,63 @@ private abstract static class AbstractBuilder {
210214
abstract void register(BuiltinFunctionName functionName, FunctionImp functionImp);
211215

212216
void registerOperator(BuiltinFunctionName functionName, SqlOperator operator) {
213-
SqlOperandTypeChecker typeChecker = operator.getOperandTypeChecker();
214-
if (typeChecker instanceof ImplicitCastOperandTypeChecker innerTypeChecker) {
215-
FunctionImp func =
216-
new FunctionImp() {
217-
@Override
218-
public RexNode resolve(RexBuilder builder, RexNode... args) {
219-
return builder.makeCall(operator, args);
220-
}
221-
222-
@Override
223-
public PPLTypeChecker getTypeChecker() {
224-
return familyWrapper(innerTypeChecker);
225-
}
226-
};
227-
register(functionName, func);
217+
SqlOperandTypeChecker typeChecker;
218+
if (operator instanceof SqlUserDefinedFunction udfOperator) {
219+
typeChecker = extractTypeCheckerFromUDF(udfOperator);
220+
} else {
221+
typeChecker = operator.getOperandTypeChecker();
222+
}
223+
224+
// Only the composite operand type checker for UDFs are concerned here.
225+
if (operator instanceof SqlUserDefinedFunction
226+
&& typeChecker instanceof CompositeOperandTypeChecker compositeTypeChecker) {
227+
register(functionName, createCompositeFunctionImp(operator, compositeTypeChecker));
228+
} else if (typeChecker instanceof ImplicitCastOperandTypeChecker implicitCastTypeChecker) {
229+
register(functionName, createImplicitCastFunctionImp(operator, implicitCastTypeChecker));
228230
} else {
229231
register(
230232
functionName,
231233
(RexBuilder builder, RexNode... node) -> builder.makeCall(operator, node));
232234
}
233235
}
234236

237+
private static SqlOperandTypeChecker extractTypeCheckerFromUDF(
238+
SqlUserDefinedFunction udfOperator) {
239+
UDFOperandMetadata udfOperandMetadata =
240+
(UDFOperandMetadata) udfOperator.getOperandTypeChecker();
241+
return (udfOperandMetadata == null) ? null : udfOperandMetadata.getInnerTypeChecker();
242+
}
243+
244+
private static FunctionImp createCompositeFunctionImp(
245+
SqlOperator operator, CompositeOperandTypeChecker typeChecker) {
246+
return new FunctionImp() {
247+
@Override
248+
public RexNode resolve(RexBuilder builder, RexNode... args) {
249+
return builder.makeCall(operator, args);
250+
}
251+
252+
@Override
253+
public PPLTypeChecker getTypeChecker() {
254+
return compositeWrapper(typeChecker);
255+
}
256+
};
257+
}
258+
259+
private static FunctionImp createImplicitCastFunctionImp(
260+
SqlOperator operator, ImplicitCastOperandTypeChecker typeChecker) {
261+
return new FunctionImp() {
262+
@Override
263+
public RexNode resolve(RexBuilder builder, RexNode... args) {
264+
return builder.makeCall(operator, args);
265+
}
266+
267+
@Override
268+
public PPLTypeChecker getTypeChecker() {
269+
return familyWrapper(typeChecker);
270+
}
271+
};
272+
}
273+
235274
void populate() {
236275
// Register std operator
237276
registerOperator(AND, SqlStdOperatorTable.AND);
@@ -292,6 +331,7 @@ void populate() {
292331

293332
// Register PPL UDF operator
294333
registerOperator(SPAN, PPLBuiltinOperators.SPAN);
334+
registerOperator(TIMEDIFF, PPLBuiltinOperators.TIMEDIFF);
295335

296336
// Register implementation.
297337
// Note, make the implementation an individual class if too complex.

core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import java.util.stream.Collectors;
1010
import java.util.stream.IntStream;
1111
import org.apache.calcite.rel.type.RelDataType;
12-
import org.apache.calcite.sql.type.FamilyOperandTypeChecker;
12+
import org.apache.calcite.sql.type.CompositeOperandTypeChecker;
1313
import org.apache.calcite.sql.type.ImplicitCastOperandTypeChecker;
1414
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
1515
import org.apache.calcite.sql.type.SqlTypeFamily;
@@ -25,7 +25,8 @@ private static boolean validateOperands(
2525
return true; // Skip checking if sizes do not match because some arguments may be optional
2626
}
2727
for (int i = 0; i < operandTypes.size(); i++) {
28-
SqlTypeName paramType = UserDefinedFunctionUtils.convertRelDataTypeToSqlTypeName(operandTypes.get(i));
28+
SqlTypeName paramType =
29+
UserDefinedFunctionUtils.convertRelDataTypeToSqlTypeName(operandTypes.get(i));
2930
SqlTypeFamily funcTypeFamily = funcTypeFamilies.get(i);
3031
if (paramType.getFamily() == SqlTypeFamily.IGNORE || funcTypeFamily == SqlTypeFamily.IGNORE) {
3132
continue;
@@ -61,8 +62,7 @@ public PPLFamilyTypeCheckerWrapper(ImplicitCastOperandTypeChecker typeChecker) {
6162
@Override
6263
public boolean checkOperandTypes(List<RelDataType> types) {
6364
if (innerTypeChecker instanceof SqlOperandTypeChecker sqlOperandTypeChecker
64-
&& !sqlOperandTypeChecker.getOperandCountRange().isValidCount(types.size()))
65-
return false;
65+
&& !sqlOperandTypeChecker.getOperandCountRange().isValidCount(types.size())) return false;
6666
List<SqlTypeFamily> families =
6767
IntStream.range(0, types.size())
6868
.mapToObj(innerTypeChecker::getOperandSqlTypeFamily)
@@ -71,6 +71,39 @@ public boolean checkOperandTypes(List<RelDataType> types) {
7171
}
7272
}
7373

74+
/** Currently only support OR compositions of family type checkers. */
75+
class PPLCompositeTypeChecker implements PPLTypeChecker {
76+
private final List<? extends SqlOperandTypeChecker> allowedRules;
77+
78+
public PPLCompositeTypeChecker(CompositeOperandTypeChecker typeChecker) {
79+
allowedRules = typeChecker.getRules();
80+
}
81+
82+
private static boolean validateWithFamilyTypeChecker(
83+
SqlOperandTypeChecker checker, List<RelDataType> types) {
84+
if (checker instanceof ImplicitCastOperandTypeChecker familyTypeChecker) {
85+
List<SqlTypeFamily> families =
86+
IntStream.range(0, types.size())
87+
.mapToObj(familyTypeChecker::getOperandSqlTypeFamily)
88+
.toList();
89+
return validateOperands(families, types);
90+
}
91+
throw new IllegalArgumentException(
92+
"Currently only compositions of ImplicitCastOperandTypeChecker are supported");
93+
}
94+
95+
@Override
96+
public boolean checkOperandTypes(List<RelDataType> types) {
97+
boolean operandCountValid =
98+
allowedRules.stream()
99+
.anyMatch(rule -> rule.getOperandCountRange().isValidCount(types.size()));
100+
if (!operandCountValid) {
101+
return false;
102+
}
103+
return allowedRules.stream().anyMatch(rule -> validateWithFamilyTypeChecker(rule, types));
104+
}
105+
}
106+
74107
/** Creates a checker that passes if each operand is a member of a corresponding family */
75108
static PPLFamilyTypeChecker family(SqlTypeFamily... families) {
76109
return new PPLFamilyTypeChecker(families);
@@ -79,4 +112,15 @@ static PPLFamilyTypeChecker family(SqlTypeFamily... families) {
79112
static PPLFamilyTypeCheckerWrapper familyWrapper(ImplicitCastOperandTypeChecker typeChecker) {
80113
return new PPLFamilyTypeCheckerWrapper(typeChecker);
81114
}
115+
116+
static PPLCompositeTypeChecker compositeWrapper(CompositeOperandTypeChecker typeChecker) {
117+
for (SqlOperandTypeChecker rule : typeChecker.getRules()) {
118+
if (!(rule instanceof ImplicitCastOperandTypeChecker)) {
119+
throw new IllegalArgumentException(
120+
"Currently only compositions of ImplicitCastOperandTypeChecker are supported, found:"
121+
+ rule.getClass().getName());
122+
}
123+
}
124+
return new PPLCompositeTypeChecker(typeChecker);
125+
}
82126
}

core/src/main/java/org/opensearch/sql/expression/function/SpanFunctionImpl.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public SqlReturnTypeInference getReturnTypeInference() {
4343
}
4444

4545
@Override
46-
public UDFTypeChecker getOperandTypeChecker() {
46+
public UDFOperandMetadata getOperandMetadata() {
4747
// TODO: Implement a proper type checker for SPAN function
4848
return null;
4949
}

0 commit comments

Comments
 (0)