Skip to content

Commit 8ab0c89

Browse files
committed
Add reverse op for compare ip to support pushdown
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 484f49e commit 8ab0c89

7 files changed

Lines changed: 131 additions & 18 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/utils/PPLReturnTypes.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ private PPLReturnTypes() {}
2424
ReturnTypes.explicit(UserDefinedFunctionUtils.NULLABLE_TIME_UDT);
2525
public static final SqlReturnTypeInference TIMESTAMP_FORCE_NULLABLE =
2626
ReturnTypes.explicit(UserDefinedFunctionUtils.NULLABLE_TIMESTAMP_UDT);
27+
public static final SqlReturnTypeInference IP_FORCE_NULLABLE =
28+
ReturnTypes.explicit(UserDefinedFunctionUtils.NULLABLE_IP_UDT);
2729
public static SqlReturnTypeInference INTEGER_FORCE_NULLABLE =
2830
ReturnTypes.INTEGER.andThen(SqlTypeTransforms.FORCE_NULLABLE);
2931
public static SqlReturnTypeInference STRING_FORCE_NULLABLE =

core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ public class UserDefinedFunctionUtils {
5454
TYPE_FACTORY.createUDT(ExprUDT.EXPR_TIMESTAMP, true);
5555
public static final RelDataType NULLABLE_STRING =
5656
TYPE_FACTORY.createTypeWithNullability(TYPE_FACTORY.createSqlType(SqlTypeName.VARCHAR), true);
57+
public static final RelDataType NULLABLE_IP_UDT = TYPE_FACTORY.createUDT(EXPR_IP, true);
5758

5859
public static RelDataType nullablePatternAggList =
5960
createArrayType(

core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import java.lang.reflect.InvocationTargetException;
1313
import java.lang.reflect.Method;
1414
import java.util.List;
15+
import java.util.concurrent.atomic.AtomicReference;
1516
import java.util.function.Supplier;
1617
import org.apache.calcite.adapter.enumerable.NullPolicy;
1718
import org.apache.calcite.adapter.enumerable.RexImpTable;
@@ -109,12 +110,19 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
109110

110111
// IP comparing functions
111112
public static final SqlOperator NOT_EQUALS_IP =
112-
CompareIpFunction.notEquals().toUDF("NOT_EQUALS_IP");
113-
public static final SqlOperator EQUALS_IP = CompareIpFunction.equals().toUDF("EQUALS_IP");
114-
public static final SqlOperator GREATER_IP = CompareIpFunction.greater().toUDF("GREATER_IP");
115-
public static final SqlOperator GTE_IP = CompareIpFunction.greaterOrEquals().toUDF("GTE_IP");
116-
public static final SqlOperator LESS_IP = CompareIpFunction.less().toUDF("LESS_IP");
117-
public static final SqlOperator LTE_IP = CompareIpFunction.lessOrEquals().toUDF("LTE_IP");
113+
CompareIpFunction.notEquals()
114+
.withReverse(lookupOperator("NOT_EQUALS_IP"))
115+
.toUDF("NOT_EQUALS_IP");
116+
public static final SqlOperator EQUALS_IP =
117+
CompareIpFunction.equals().withReverse(lookupOperator("EQUALS_IP")).toUDF("EQUALS_IP");
118+
public static final SqlOperator GREATER_IP =
119+
CompareIpFunction.greater().withReverse(lookupOperator("LESS_IP")).toUDF("GREATER_IP");
120+
public static final SqlOperator GTE_IP =
121+
CompareIpFunction.greaterOrEquals().withReverse(lookupOperator("LTE_IP")).toUDF("GTE_IP");
122+
public static final SqlOperator LESS_IP =
123+
CompareIpFunction.less().withReverse(lookupOperator("GREATER_IP")).toUDF("LESS_IP");
124+
public static final SqlOperator LTE_IP =
125+
CompareIpFunction.lessOrEquals().withReverse(lookupOperator("GTE_IP")).toUDF("LTE_IP");
118126

119127
// Condition function
120128
public static final SqlOperator EARLIEST = new EarliestFunction().toUDF("EARLIEST");
@@ -382,4 +390,12 @@ private static Expression invokeCalciteImplementor(
382390
method.setAccessible(true);
383391
return (Expression) method.invoke(rexCallImplementor, translator, call, List.of(field));
384392
}
393+
394+
private static Supplier<SqlOperator> lookupOperator(String name) {
395+
return () -> {
396+
AtomicReference<SqlOperator> ref = new AtomicReference<>();
397+
INSTANCE.get().lookUpOperators(name, false, ref::set);
398+
return ref.get();
399+
};
400+
}
385401
}

core/src/main/java/org/opensearch/sql/expression/function/UserDefinedFunctionBuilder.java

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,17 @@
66
package org.opensearch.sql.expression.function;
77

88
import java.util.Collections;
9+
import java.util.function.Supplier;
910
import org.apache.calcite.schema.ImplementableFunction;
1011
import org.apache.calcite.sql.SqlIdentifier;
1112
import org.apache.calcite.sql.SqlKind;
13+
import org.apache.calcite.sql.SqlOperator;
14+
import org.apache.calcite.sql.SqlSyntax;
1215
import org.apache.calcite.sql.parser.SqlParserPos;
1316
import org.apache.calcite.sql.type.InferTypes;
1417
import org.apache.calcite.sql.type.SqlReturnTypeInference;
1518
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
19+
import org.checkerframework.checker.nullness.qual.Nullable;
1620

1721
/**
1822
* The interface helps to construct a SqlUserDefinedFunction
@@ -32,6 +36,18 @@ public interface UserDefinedFunctionBuilder {
3236

3337
UDFOperandMetadata getOperandMetadata();
3438

39+
default SqlKind getKind() {
40+
return SqlKind.OTHER_FUNCTION;
41+
}
42+
43+
default SqlSyntax getSqlSyntax() {
44+
return SqlSyntax.FUNCTION;
45+
}
46+
47+
default Supplier<SqlOperator> getReverse() {
48+
return null;
49+
}
50+
3551
default SqlUserDefinedFunction toUDF(String functionName) {
3652
return toUDF(functionName, true);
3753
}
@@ -50,7 +66,7 @@ default SqlUserDefinedFunction toUDF(String functionName, boolean isDeterministi
5066
new SqlIdentifier(Collections.singletonList(functionName), null, SqlParserPos.ZERO, null);
5167
return new SqlUserDefinedFunction(
5268
udfLtrimIdentifier,
53-
SqlKind.OTHER_FUNCTION,
69+
getKind(),
5470
getReturnTypeInference(),
5571
InferTypes.ANY_NULLABLE,
5672
getOperandMetadata(),
@@ -59,6 +75,19 @@ default SqlUserDefinedFunction toUDF(String functionName, boolean isDeterministi
5975
public boolean isDeterministic() {
6076
return isDeterministic;
6177
}
78+
79+
@Override
80+
public @Nullable SqlOperator reverse() {
81+
if (getReverse() == null) {
82+
return null;
83+
}
84+
return getReverse().get();
85+
}
86+
87+
@Override
88+
public SqlSyntax getSyntax() {
89+
return getSqlSyntax();
90+
}
6291
};
6392
}
6493
}

core/src/main/java/org/opensearch/sql/expression/function/udf/ip/CompareIpFunction.java

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,17 @@
66
package org.opensearch.sql.expression.function.udf.ip;
77

88
import java.util.List;
9+
import java.util.function.Supplier;
910
import org.apache.calcite.adapter.enumerable.NotNullImplementor;
1011
import org.apache.calcite.adapter.enumerable.NullPolicy;
1112
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
1213
import org.apache.calcite.linq4j.tree.ConstantExpression;
1314
import org.apache.calcite.linq4j.tree.Expression;
1415
import org.apache.calcite.linq4j.tree.Expressions;
1516
import org.apache.calcite.rex.RexCall;
17+
import org.apache.calcite.sql.SqlKind;
18+
import org.apache.calcite.sql.SqlOperator;
19+
import org.apache.calcite.sql.SqlSyntax;
1620
import org.apache.calcite.sql.type.ReturnTypes;
1721
import org.apache.calcite.sql.type.SqlReturnTypeInference;
1822
import org.opensearch.sql.data.model.ExprIpValue;
@@ -32,33 +36,80 @@
3236
* </ul>
3337
*/
3438
public class CompareIpFunction extends ImplementorUDF {
39+
private Supplier<SqlOperator> reverse;
3540

3641
private CompareIpFunction(ComparisonType comparisonType) {
3742
super(new CompareImplementor(comparisonType), NullPolicy.ANY);
43+
reverse = null;
44+
}
45+
46+
@Override
47+
public SqlSyntax getSqlSyntax() {
48+
return SqlSyntax.BINARY;
3849
}
3950

4051
public static CompareIpFunction less() {
41-
return new CompareIpFunction(ComparisonType.LESS);
52+
return new CompareIpFunction(ComparisonType.LESS) {
53+
@Override
54+
public SqlKind getKind() {
55+
return SqlKind.LESS_THAN;
56+
}
57+
};
58+
}
59+
60+
@Override
61+
public Supplier<SqlOperator> getReverse() {
62+
return reverse;
63+
}
64+
65+
public CompareIpFunction withReverse(Supplier<SqlOperator> supplier) {
66+
this.reverse = supplier;
67+
return this;
4268
}
4369

4470
public static CompareIpFunction greater() {
45-
return new CompareIpFunction(ComparisonType.GREATER);
71+
return new CompareIpFunction(ComparisonType.GREATER) {
72+
@Override
73+
public SqlKind getKind() {
74+
return SqlKind.GREATER_THAN;
75+
}
76+
};
4677
}
4778

4879
public static CompareIpFunction lessOrEquals() {
49-
return new CompareIpFunction(ComparisonType.LESS_OR_EQUAL);
80+
return new CompareIpFunction(ComparisonType.LESS_OR_EQUAL) {
81+
@Override
82+
public SqlKind getKind() {
83+
return SqlKind.LESS_THAN_OR_EQUAL;
84+
}
85+
};
5086
}
5187

5288
public static CompareIpFunction greaterOrEquals() {
53-
return new CompareIpFunction(ComparisonType.GREATER_OR_EQUAL);
89+
return new CompareIpFunction(ComparisonType.GREATER_OR_EQUAL) {
90+
@Override
91+
public SqlKind getKind() {
92+
return SqlKind.GREATER_THAN_OR_EQUAL;
93+
}
94+
};
5495
}
5596

5697
public static CompareIpFunction equals() {
57-
return new CompareIpFunction(ComparisonType.EQUALS);
98+
return new CompareIpFunction(ComparisonType.EQUALS) {
99+
@Override
100+
public SqlKind getKind() {
101+
return SqlKind.EQUALS;
102+
}
103+
};
58104
}
59105

60106
public static CompareIpFunction notEquals() {
61-
return new CompareIpFunction(ComparisonType.NOT_EQUALS);
107+
return new CompareIpFunction(ComparisonType.NOT_EQUALS) {
108+
@Override
109+
public SqlKind getKind() {
110+
return SqlKind.NOT_EQUALS;
111+
}
112+
};
62113
}
63114

64115
@Override
@@ -88,10 +139,10 @@ public Expression implement(
88139
translatedOperands.get(0),
89140
translatedOperands.get(1));
90141

91-
return generateComparisonExpression(compareResult, comparisonType);
142+
return evalCompareResult(compareResult, comparisonType);
92143
}
93144

94-
private static Expression generateComparisonExpression(
145+
private static Expression evalCompareResult(
95146
Expression compareResult, ComparisonType comparisonType) {
96147
final ConstantExpression zero = Expressions.constant(0);
97148
return switch (comparisonType) {

core/src/main/java/org/opensearch/sql/expression/function/udf/ip/IPFunction.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
import org.apache.calcite.linq4j.tree.Expression;
1313
import org.apache.calcite.linq4j.tree.Expressions;
1414
import org.apache.calcite.rex.RexCall;
15-
import org.apache.calcite.sql.type.ReturnTypes;
1615
import org.apache.calcite.sql.type.SqlReturnTypeInference;
1716
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
17+
import org.opensearch.sql.calcite.utils.PPLReturnTypes;
1818
import org.opensearch.sql.data.model.ExprIpValue;
1919
import org.opensearch.sql.data.type.ExprCoreType;
2020
import org.opensearch.sql.data.type.ExprType;
@@ -46,8 +46,7 @@ public UDFOperandMetadata getOperandMetadata() {
4646

4747
@Override
4848
public SqlReturnTypeInference getReturnTypeInference() {
49-
return ReturnTypes.explicit(
50-
OpenSearchTypeFactory.TYPE_FACTORY.createUDT(OpenSearchTypeFactory.ExprUDT.EXPR_IP, true));
49+
return PPLReturnTypes.IP_FORCE_NULLABLE;
5150
}
5251

5352
public static class CastImplementor

opensearch/src/main/java/org/opensearch/sql/opensearch/request/PredicateAnalyzer.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,10 @@
8484
import org.opensearch.index.query.ScriptQueryBuilder;
8585
import org.opensearch.script.Script;
8686
import org.opensearch.sql.calcite.plan.OpenSearchConstants;
87+
import org.opensearch.sql.calcite.type.ExprIPType;
8788
import org.opensearch.sql.calcite.type.ExprSqlType;
8889
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT;
90+
import org.opensearch.sql.data.model.ExprIpValue;
8991
import org.opensearch.sql.data.model.ExprTimestampValue;
9092
import org.opensearch.sql.data.type.ExprCoreType;
9193
import org.opensearch.sql.data.type.ExprType;
@@ -335,6 +337,8 @@ public Expression visitCall(RexCall call) {
335337
}
336338
};
337339
case FUNCTION:
340+
if (call.getOperator().getName().equalsIgnoreCase("IP")) {}
341+
338342
return visitRelevanceFunc(call);
339343
default:
340344
String message =
@@ -1348,6 +1352,11 @@ private static String timestampValueForPushDown(String value) {
13481352
// https://github.com/opensearch-project/sql/pull/3442
13491353
}
13501354

1355+
private static String ipValueForPushDown(String value) {
1356+
ExprIpValue exprIpValue = new ExprIpValue(value);
1357+
return exprIpValue.value();
1358+
}
1359+
13511360
public static class ScriptQueryExpression extends QueryExpression {
13521361
private final String code;
13531362
private RexNode analyzedNode;
@@ -1539,6 +1548,8 @@ Object value() {
15391548
return timestampValueForPushDown(RexLiteral.stringValue(literal));
15401549
} else if (isString()) {
15411550
return RexLiteral.stringValue(literal);
1551+
} else if (isIp()) {
1552+
return ipValueForPushDown(RexLiteral.stringValue(literal));
15421553
} else {
15431554
return rawValue();
15441555
}
@@ -1575,6 +1586,10 @@ public boolean isTimestamp() {
15751586
return false;
15761587
}
15771588

1589+
public boolean isIp() {
1590+
return literal.getType() instanceof ExprIPType;
1591+
}
1592+
15781593
long longValue() {
15791594
return ((Number) literal.getValue()).longValue();
15801595
}

0 commit comments

Comments
 (0)