Skip to content

Commit 82bf1fe

Browse files
committed
AD operator implementation in Calcite
Signed-off-by: Songkan Tang <songkant@amazon.com>
1 parent c0f5680 commit 82bf1fe

19 files changed

Lines changed: 793 additions & 6 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
import java.util.stream.Stream;
4444
import lombok.AllArgsConstructor;
4545
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
46+
import org.apache.calcite.plan.Convention;
47+
import org.apache.calcite.plan.RelOptCluster;
4648
import org.apache.calcite.plan.RelOptTable;
4749
import org.apache.calcite.plan.ViewExpanders;
4850
import org.apache.calcite.rel.RelNode;
@@ -145,6 +147,7 @@
145147
import org.opensearch.sql.ast.tree.UnresolvedPlan;
146148
import org.opensearch.sql.ast.tree.Values;
147149
import org.opensearch.sql.ast.tree.Window;
150+
import org.opensearch.sql.calcite.plan.LogicalAD;
148151
import org.opensearch.sql.calcite.plan.LogicalSystemLimit;
149152
import org.opensearch.sql.calcite.plan.LogicalSystemLimit.SystemLimitType;
150153
import org.opensearch.sql.calcite.plan.OpenSearchConstants;
@@ -2269,14 +2272,28 @@ private String findTimestampField(RelDataType rowType) {
22692272
return null;
22702273
}
22712274

2272-
/*
2273-
* Unsupported Commands of PPL with Calcite for OpenSearch 3.0.0-beta
2274-
*/
22752275
@Override
22762276
public RelNode visitAD(AD node, CalcitePlanContext context) {
2277-
throw new CalciteUnsupportedException("AD command is unsupported in Calcite");
2277+
visitChildren(node, context);
2278+
2279+
RelNode child = context.relBuilder.build();
2280+
RelOptCluster cluster = context.relBuilder.getCluster();
2281+
Map<String, Object> arguments =
2282+
node.getArguments().entrySet().stream()
2283+
.map(entry -> Pair.of(entry.getKey(), entry.getValue().getValue()))
2284+
.collect(Collectors.toMap(Pair::getKey, Pair::getValue));
2285+
2286+
LogicalAD ad =
2287+
new LogicalAD(cluster, cluster.traitSet().replace(Convention.NONE), child, arguments);
2288+
2289+
context.relBuilder.push(ad);
2290+
2291+
return context.relBuilder.peek();
22782292
}
22792293

2294+
/*
2295+
* Unsupported Commands of PPL with Calcite for OpenSearch 3.0.0-beta
2296+
*/
22802297
@Override
22812298
public RelNode visitCloseCursor(CloseCursor closeCursor, CalcitePlanContext context) {
22822299
throw new CalciteUnsupportedException("Close cursor operation is unsupported in Calcite");
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite.plan;
7+
8+
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS;
9+
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE;
10+
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE;
11+
import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD;
12+
13+
import com.google.common.collect.ImmutableMap;
14+
import java.util.List;
15+
import java.util.Map;
16+
import lombok.Getter;
17+
import org.apache.calcite.plan.RelOptCluster;
18+
import org.apache.calcite.plan.RelOptCost;
19+
import org.apache.calcite.plan.RelOptPlanner;
20+
import org.apache.calcite.plan.RelTraitSet;
21+
import org.apache.calcite.rel.RelNode;
22+
import org.apache.calcite.rel.RelShuttle;
23+
import org.apache.calcite.rel.RelWriter;
24+
import org.apache.calcite.rel.SingleRel;
25+
import org.apache.calcite.rel.metadata.RelMetadataQuery;
26+
import org.apache.calcite.rel.type.RelDataType;
27+
import org.apache.calcite.rel.type.RelDataTypeFactory;
28+
import org.apache.calcite.sql.type.SqlTypeName;
29+
30+
public abstract class AD extends SingleRel {
31+
32+
@Getter private final ImmutableMap<String, Object> arguments;
33+
private final boolean isTimeSeries;
34+
35+
private static final String DUPLICATE_RCF_SCORE = RCF_SCORE + "1";
36+
private static final String DUPLICATE_RCF_ANOMALY_GRADE = RCF_ANOMALY_GRADE + "1";
37+
private static final String DUPLICATE_RCF_ANOMALOUS = RCF_ANOMALOUS + "1";
38+
39+
/**
40+
* Creates an AD operator
41+
*
42+
* @param cluster Cluster this relational expression belongs to
43+
* @param traits collation traits of the operator, usually NONE for ad
44+
* @param input Input relational expression
45+
* @param arguments an argument mapping of parameter keys and values
46+
*/
47+
protected AD(
48+
RelOptCluster cluster, RelTraitSet traits, RelNode input, Map<String, Object> arguments) {
49+
super(cluster, traits, input);
50+
this.arguments = ImmutableMap.copyOf(arguments);
51+
this.isTimeSeries = arguments.containsKey(TIME_FIELD);
52+
}
53+
54+
@Override
55+
public RelNode accept(RelShuttle shuttle) {
56+
return copy(traitSet, getInputs());
57+
}
58+
59+
@Override
60+
public RelWriter explainTerms(RelWriter pw) {
61+
return super.explainTerms(pw).item("arguments", arguments);
62+
}
63+
64+
@Override
65+
public final RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) {
66+
return copy(traitSet, sole(inputs));
67+
}
68+
69+
public abstract AD copy(RelTraitSet traitSet, RelNode input);
70+
71+
@Override
72+
protected RelDataType deriveRowType() {
73+
RelDataType inputRowType = getInput().getRowType();
74+
RelDataTypeFactory typeFactory = getCluster().getTypeFactory();
75+
String scoreName =
76+
inputRowType.getFieldNames().contains(RCF_SCORE) ? DUPLICATE_RCF_SCORE : RCF_SCORE;
77+
if (isTimeSeries) {
78+
String anomalyGradeName =
79+
inputRowType.getFieldNames().contains(RCF_ANOMALY_GRADE)
80+
? DUPLICATE_RCF_ANOMALY_GRADE
81+
: RCF_ANOMALY_GRADE;
82+
return typeFactory
83+
.builder()
84+
.kind(inputRowType.getStructKind())
85+
.addAll(inputRowType.getFieldList())
86+
.add(scoreName, SqlTypeName.DOUBLE)
87+
.add(anomalyGradeName, SqlTypeName.DOUBLE)
88+
.build();
89+
} else {
90+
String anomalousName =
91+
inputRowType.getFieldNames().contains(RCF_ANOMALOUS)
92+
? DUPLICATE_RCF_ANOMALOUS
93+
: RCF_ANOMALOUS;
94+
return typeFactory
95+
.builder()
96+
.kind(inputRowType.getStructKind())
97+
.addAll(inputRowType.getFieldList())
98+
.add(scoreName, SqlTypeName.DOUBLE)
99+
.add(anomalousName, SqlTypeName.BOOLEAN)
100+
.build();
101+
}
102+
}
103+
104+
@Override
105+
public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
106+
double dRows = mq.getRowCount(getInput());
107+
double dCpu = 0; // Assume it's remote cluster AD algorithm cost
108+
double dIo = dRows * 2; // nodeClient request round trip network IO cost
109+
return planner.getCostFactory().makeCost(dRows, dCpu, dIo);
110+
}
111+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite.plan;
7+
8+
import java.util.Map;
9+
import org.apache.calcite.plan.RelOptCluster;
10+
import org.apache.calcite.plan.RelTraitSet;
11+
import org.apache.calcite.rel.RelNode;
12+
13+
public class LogicalAD extends AD {
14+
15+
/**
16+
* Creates a LogicalAD operator
17+
*
18+
* @param cluster Cluster this relational expression belongs to
19+
* @param traits collation traits of the operator, usually NONE for ad
20+
* @param input Input relational expression
21+
* @param arguments an argument mapping of parameter keys and values
22+
*/
23+
public LogicalAD(
24+
RelOptCluster cluster, RelTraitSet traits, RelNode input, Map<String, Object> arguments) {
25+
super(cluster, traits, input, arguments);
26+
}
27+
28+
@Override
29+
public final LogicalAD copy(RelTraitSet traitSet, RelNode input) {
30+
return new LogicalAD(getCluster(), traitSet, input, getArguments());
31+
}
32+
}

docs/user/ppl/cmd/ad.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ ad (deprecated by ml command)
1212
Description
1313
===========
1414
| The ``ad`` command applies Random Cut Forest (RCF) algorithm in the ml-commons plugin on the search result returned by a PPL command. Based on the input, the command uses two types of RCF algorithms: fixed in time RCF for processing time-series data, batch RCF for processing non-time-series data.
15+
| The input may include an optional time-series field and an optional category field; all other data fields must contain double values.
1516
1617

1718
Syntax
@@ -109,4 +110,4 @@ PPL query::
109110

110111
Limitations
111112
===========
112-
The ``ad`` command can only work with ``plugins.calcite.enabled=false``.
113+
The ``ad`` command can only process double type values.

integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ACCOUNT;
99
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK;
1010
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK_WITH_NULL_VALUES;
11+
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_DATATYPE_NUMERIC;
1112
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_LOGS;
1213
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_NESTED_SIMPLE;
1314
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_OTEL_LOGS;
@@ -42,6 +43,7 @@ public void init() throws Exception {
4243
loadIndex(Index.WORKER);
4344
loadIndex(Index.WORK_INFORMATION);
4445
loadIndex(Index.WEBLOG);
46+
loadIndex(Index.DATA_TYPE_NUMERIC);
4547
}
4648

4749
@Override
@@ -1788,4 +1790,38 @@ public void testInternalItemAccessOnStructs() throws IOException {
17881790
+ " info.dummy_sub_field",
17891791
TEST_INDEX_WEBLOGS)));
17901792
}
1793+
1794+
1795+
@Test
1796+
public void testAD() throws IOException {
1797+
String expected = loadExpectedPlan("ad.yaml");
1798+
assertYamlEqualsIgnoreId(
1799+
expected,
1800+
explainQueryYaml(
1801+
String.format("source=%s | fields double_number | ad", TEST_INDEX_DATATYPE_NUMERIC)));
1802+
}
1803+
1804+
@Test
1805+
public void testADWithCategory() throws IOException {
1806+
String expected = loadExpectedPlan("ad_category.yaml");
1807+
assertYamlEqualsIgnoreId(
1808+
expected,
1809+
explainQueryYaml(
1810+
String.format(
1811+
"source=%s | stats max(double_number) as max by integer_number | fields"
1812+
+ " integer_number, max | ad category_field='integer_number'",
1813+
TEST_INDEX_DATATYPE_NUMERIC)));
1814+
}
1815+
1816+
@Test
1817+
public void testADWithTimeSeries() throws IOException {
1818+
String expected = loadExpectedPlan("ad_time_series.yaml");
1819+
assertYamlEqualsIgnoreId(
1820+
expected,
1821+
explainQueryYaml(
1822+
String.format(
1823+
"source=%s | where integer_number < 10 | fields long_number, integer_number,"
1824+
+ " double_number | ad category_field='integer_number' time_field='timestamp'",
1825+
TEST_INDEX_DATATYPE_NUMERIC)));
1826+
}
17911827
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
calcite:
2+
logical: |
3+
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
4+
LogicalAD(arguments=[{}])
5+
LogicalProject(double_number=[$0])
6+
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]])
7+
physical: |
8+
EnumerableLimit(fetch=[10000])
9+
EnumerableAD(arguments=[{}])
10+
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]], PushDownContext=[[PROJECT->[double_number]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"timeout":"1m","_source":{"includes":["double_number"],"excludes":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
calcite:
2+
logical: |
3+
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
4+
LogicalAD(arguments=[{category_field=integer_number}])
5+
LogicalAggregate(group=[{0}], max=[MAX($1)])
6+
LogicalProject(integer_number=[$2], double_number=[$0])
7+
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]])
8+
physical: |
9+
EnumerableLimit(fetch=[10000])
10+
EnumerableAD(arguments=[{category_field=integer_number}])
11+
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},max=MAX($1))], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"integer_number":{"terms":{"field":"integer_number","missing_bucket":true,"missing_order":"first","order":"asc"}}}]},"aggregations":{"max":{"max":{"field":"double_number"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
calcite:
2+
logical: |
3+
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
4+
LogicalAD(arguments=[{time_field=timestamp, category_field=integer_number}])
5+
LogicalProject(long_number=[$1], integer_number=[$2], double_number=[$0])
6+
LogicalFilter(condition=[<($2, 10)])
7+
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]])
8+
physical: |
9+
EnumerableLimit(fetch=[10000])
10+
EnumerableAD(arguments=[{time_field=timestamp, category_field=integer_number}])
11+
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]], PushDownContext=[[PROJECT->[double_number, long_number, integer_number], FILTER-><($2, 10), PROJECT->[long_number, integer_number, double_number]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"timeout":"1m","query":{"range":{"integer_number":{"from":null,"to":10,"include_lower":true,"include_upper":false,"boost":1.0}}},"_source":{"includes":["long_number","integer_number","double_number"],"excludes":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
calcite:
2+
logical: |
3+
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
4+
LogicalAD(arguments=[{}])
5+
LogicalProject(double_number=[$0])
6+
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]])
7+
physical: |
8+
EnumerableLimit(fetch=[10000])
9+
EnumerableAD(arguments=[{}])
10+
EnumerableCalc(expr#0..13=[{inputs}], double_number=[$t0])
11+
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]])
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
calcite:
2+
logical: |
3+
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
4+
LogicalAD(arguments=[{category_field=integer_number}])
5+
LogicalAggregate(group=[{0}], max=[MAX($1)])
6+
LogicalProject(integer_number=[$2], double_number=[$0])
7+
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]])
8+
physical: |
9+
EnumerableLimit(fetch=[10000])
10+
EnumerableAD(arguments=[{category_field=integer_number}])
11+
EnumerableAggregate(group=[{2}], max=[MAX($0)])
12+
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]])

0 commit comments

Comments
 (0)