Skip to content

Commit d8d2763

Browse files
committed
Reimplement atan, sqrt, strcmp, xor with SqlCall rewrite
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent e2d631d commit d8d2763

3 files changed

Lines changed: 54 additions & 15 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/validate/PplOpTable.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ private BuiltinFunctionName sqlFunctionNameToPplFunctionName(String name) {
101101
case "CONVERT" -> BuiltinFunctionName.CONV;
102102
case "ILIKE" -> BuiltinFunctionName.LIKE;
103103
case "CHAR_LENGTH" -> BuiltinFunctionName.LENGTH;
104+
case "NOT_EQUALS", "<>" -> BuiltinFunctionName.XOR;
104105
default -> BuiltinFunctionName.of(name).orElse(null);
105106
};
106107
}

core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ public SqlNode rewriteCall(SqlValidator validator, SqlCall call) {
130130
return super.rewriteCall(validator, call);
131131
}
132132
};
133-
public static final SqlOperator ATAN =
133+
public static final SqlFunction ATAN =
134134
new SqlFunction(
135135
"ATAN",
136136
SqlKind.OTHER_FUNCTION,
@@ -149,6 +149,24 @@ public SqlNode rewriteCall(SqlValidator validator, SqlCall call) {
149149
}
150150
};
151151

152+
public static final SqlFunction SQRT =
153+
new SqlFunction(
154+
"SQRT",
155+
SqlKind.OTHER_FUNCTION,
156+
ReturnTypes.DOUBLE_NULLABLE,
157+
null,
158+
OperandTypes.NUMERIC,
159+
SqlFunctionCategory.USER_DEFINED_FUNCTION) {
160+
@Override
161+
public SqlNode rewriteCall(SqlValidator validator, SqlCall call) {
162+
// Rewrite SQRT(x) to POWER(x, 0.5)
163+
return SqlStdOperatorTable.POWER.createCall(
164+
call.getParserPosition(),
165+
call.operand(0),
166+
SqlLiteral.createExactNumeric("0.5", call.getParserPosition()));
167+
}
168+
};
169+
152170
// String functions
153171
public static final SqlFunction TRIM =
154172
new SqlFunction(
@@ -235,6 +253,21 @@ public SqlNode rewriteCall(SqlValidator validator, SqlCall call) {
235253
// Condition function
236254
public static final SqlOperator EARLIEST = new EarliestFunction().toUDF("EARLIEST");
237255
public static final SqlOperator LATEST = new LatestFunction().toUDF("LATEST");
256+
public static final SqlFunction XOR =
257+
new SqlFunction(
258+
"XOR",
259+
SqlKind.OTHER_FUNCTION,
260+
ReturnTypes.BOOLEAN_NULLABLE,
261+
null,
262+
OperandTypes.BOOLEAN_BOOLEAN,
263+
SqlFunctionCategory.USER_DEFINED_FUNCTION) {
264+
@Override
265+
public SqlNode rewriteCall(SqlValidator validator, SqlCall call) {
266+
// Rewrite XOR(x, y) to NOT_EQUALS(x, y)
267+
return SqlStdOperatorTable.NOT_EQUALS.createCall(
268+
call.getParserPosition(), call.operand(0), call.operand(1));
269+
}
270+
};
238271

239272
// Datetime function
240273
public static final SqlOperator TIMESTAMP = new TimestampFunction().toUDF("TIMESTAMP");

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@
214214

215215
import com.google.common.collect.ImmutableList;
216216
import com.google.common.collect.ImmutableMap;
217-
import java.math.BigDecimal;
218217
import java.util.ArrayList;
219218
import java.util.Arrays;
220219
import java.util.HashMap;
@@ -332,6 +331,18 @@ default PPLTypeChecker getTypeChecker() {
332331
final AggBuilder aggBuilder = new AggBuilder();
333332
aggBuilder.populate();
334333
INSTANCE = new PPLFuncImpTable(builder, aggBuilder);
334+
335+
// Some operators are registered via register instead of registerOperator
336+
// We add them explicitly so that they can be found during validation
337+
var pplOps = PplOpTable.getInstance();
338+
pplOps.add(JSON_ARRAY, SqlStdOperatorTable.JSON_ARRAY);
339+
pplOps.add(JSON_OBJECT, SqlStdOperatorTable.JSON_OBJECT);
340+
pplOps.add(INTERNAL_ITEM, SqlStdOperatorTable.ITEM);
341+
// pplOps.add(TYPEOF, ... );
342+
pplOps.add(IF, SqlStdOperatorTable.CASE);
343+
pplOps.add(NULLIF, SqlStdOperatorTable.CASE);
344+
pplOps.add(IS_EMPTY, SqlStdOperatorTable.IS_EMPTY);
345+
pplOps.add(IS_BLANK, SqlStdOperatorTable.IS_EMPTY);
335346
}
336347

337348
/**
@@ -535,7 +546,9 @@ private static void registerToCatalogWithReplace(
535546
TRIM,
536547
SqlStdOperatorTable.TRIM,
537548
STRCMP,
538-
SqlLibraryOperators.STRCMP);
549+
SqlLibraryOperators.STRCMP,
550+
XOR,
551+
SqlStdOperatorTable.NOT_EQUALS);
539552
PplOpTable.getInstance().add(functionName, replacement.getOrDefault(functionName, operator));
540553
}
541554

@@ -720,6 +733,9 @@ void populate() {
720733
registerOperator(CRC32, PPLBuiltinOperators.CRC32);
721734
registerOperator(DIVIDE, PPLBuiltinOperators.DIVIDE);
722735
registerOperator(DIVIDEFUNCTION, PPLBuiltinOperators.DIVIDE);
736+
// SqlStdOperatorTable.SQRT is declared but not implemented. The call to SQRT in Calcite is
737+
// converted to POWER(x, 0.5).
738+
registerOperator(SQRT, PPLBuiltinOperators.SQRT);
723739
registerOperator(SHA2, PPLBuiltinOperators.SHA2);
724740
registerOperator(CIDRMATCH, PPLBuiltinOperators.CIDRMATCH);
725741
registerOperator(INTERNAL_GROK, PPLBuiltinOperators.GROK);
@@ -860,23 +876,12 @@ void populate() {
860876
OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER)
861877
.or(OperandTypes.family(SqlTypeFamily.MAP, SqlTypeFamily.ANY)),
862878
false));
863-
// SqlStdOperatorTable.SQRT is declared but not implemented. The call to SQRT in Calcite is
864-
// converted to POWER(x, 0.5).
865-
register(
866-
SQRT,
867-
createFunctionImpWithTypeChecker(
868-
(builder, arg) ->
869-
builder.makeCall(
870-
SqlStdOperatorTable.POWER,
871-
arg,
872-
builder.makeApproxLiteral(BigDecimal.valueOf(0.5))),
873-
PPLTypeChecker.family(SqlTypeFamily.NUMERIC)));
874879
register(
875880
TYPEOF,
876881
(FunctionImp1)
877882
(builder, arg) ->
878883
builder.makeLiteral(getLegacyTypeName(arg.getType(), QueryType.PPL)));
879-
register(XOR, new XOR_FUNC());
884+
registerOperator(XOR, PPLBuiltinOperators.XOR);
880885
// SqlStdOperatorTable.CASE.getOperandTypeChecker is null. We manually create a type checker
881886
// for it. The second and third operands are required to be of the same type. If not,
882887
// it will throw an IllegalArgumentException with information Can't find leastRestrictive type

0 commit comments

Comments
 (0)