Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
import java.util.stream.Stream;
import lombok.AllArgsConstructor;
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
import org.apache.calcite.plan.Convention;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.ViewExpanders;
import org.apache.calcite.rel.RelNode;
Expand Down Expand Up @@ -144,6 +146,7 @@
import org.opensearch.sql.ast.tree.Values;
import org.opensearch.sql.ast.tree.Window;
import org.opensearch.sql.calcite.plan.AliasFieldsWrappable;
import org.opensearch.sql.calcite.plan.LogicalAD;
import org.opensearch.sql.calcite.plan.LogicalSystemLimit;
import org.opensearch.sql.calcite.plan.LogicalSystemLimit.SystemLimitType;
import org.opensearch.sql.calcite.plan.OpenSearchConstants;
Expand Down Expand Up @@ -2290,14 +2293,28 @@ private String findTimestampField(RelDataType rowType) {
return null;
}

/**
 * Translates the PPL {@code ad} command into a {@link LogicalAD} relational node.
 *
 * <p>The child plan is visited first and then taken from the builder with {@code build()}, so it
 * can be wired in as the single input of the new node. The node is created with {@link
 * Convention#NONE} and pushed back onto the builder so that later commands build on top of it.
 *
 * @param node AST node of the {@code ad} command; its arguments are unwrapped by calling {@code
 *     getValue()} on each entry value
 * @param context planning context that owns the {@code relBuilder}
 * @return the {@link LogicalAD} node that is now on top of the builder stack
 */
@Override
public RelNode visitAD(AD node, CalcitePlanContext context) {
  visitChildren(node, context);

  RelNode child = context.relBuilder.build();
  RelOptCluster cluster = context.relBuilder.getCluster();

  // Keep the parameter keys and replace each wrapped argument with its underlying value.
  Map<String, Object> arguments =
      node.getArguments().entrySet().stream()
          .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getValue()));

  LogicalAD ad =
      new LogicalAD(cluster, cluster.traitSet().replace(Convention.NONE), child, arguments);

  context.relBuilder.push(ad);

  return context.relBuilder.peek();
}

/*
* Unsupported Commands of PPL with Calcite for OpenSearch 3.0.0-beta
*/
@Override
public RelNode visitCloseCursor(CloseCursor closeCursor, CalcitePlanContext context) {
throw new CalciteUnsupportedException("Close cursor operation is unsupported in Calcite");
Expand Down
111 changes: 111 additions & 0 deletions core/src/main/java/org/opensearch/sql/calcite/plan/AD.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.plan;

import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE;
import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD;

import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
import lombok.Getter;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttle;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.type.SqlTypeName;

/**
 * Base relational operator for the PPL {@code ad} command.
 *
 * <p>The operator has a single input and appends two columns to the input row type: a score
 * column and, depending on whether the arguments contain {@code TIME_FIELD}, either an
 * anomaly-grade column (time-series mode) or an anomalous-flag column (batch mode). Concrete
 * subclasses only have to provide {@link #copy(RelTraitSet, RelNode)}.
 */
public abstract class AD extends SingleRel {

  // Fallback column names, used when the input already has a column with the regular name.
  private static final String DUPLICATE_RCF_SCORE = RCF_SCORE + "1";
  private static final String DUPLICATE_RCF_ANOMALY_GRADE = RCF_ANOMALY_GRADE + "1";
  private static final String DUPLICATE_RCF_ANOMALOUS = RCF_ANOMALOUS + "1";

  /** Immutable snapshot of the command arguments, keyed by parameter name. */
  @Getter private final ImmutableMap<String, Object> arguments;

  /** True when the arguments contain {@code TIME_FIELD}; selects the output column layout. */
  private final boolean isTimeSeries;

  /**
   * Creates an AD operator.
   *
   * @param cluster cluster this relational expression belongs to
   * @param traits trait set of the operator
   * @param input input relational expression
   * @param arguments mapping of parameter keys to values; copied into an immutable map
   */
  protected AD(
      RelOptCluster cluster, RelTraitSet traits, RelNode input, Map<String, Object> arguments) {
    super(cluster, traits, input);
    this.arguments = ImmutableMap.copyOf(arguments);
    this.isTimeSeries = arguments.containsKey(TIME_FIELD);
  }

  /**
   * Returns a copy of this node with the same traits and inputs.
   *
   * <p>NOTE(review): the shuttle argument is not consulted, so a shuttle is never applied to this
   * node or its input through this method -- confirm that this is intended.
   */
  @Override
  public RelNode accept(RelShuttle shuttle) {
    return copy(traitSet, getInputs());
  }

  /** Adds the argument map to the explain output under the {@code arguments} item. */
  @Override
  public RelWriter explainTerms(RelWriter pw) {
    RelWriter withInputTerms = super.explainTerms(pw);
    return withInputTerms.item("arguments", arguments);
  }

  /** Delegates to {@link #copy(RelTraitSet, RelNode)} using the sole element of {@code inputs}. */
  @Override
  public final RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) {
    return copy(traitSet, sole(inputs));
  }

  /**
   * Creates a copy of this operator with the given traits and input.
   *
   * @param traitSet trait set for the copy
   * @param input input relational expression for the copy
   * @return the copied operator
   */
  public abstract AD copy(RelTraitSet traitSet, RelNode input);

  /**
   * Derives the output row type: every input field, followed by a DOUBLE score column and then
   * either a DOUBLE anomaly-grade column (time-series mode) or a BOOLEAN anomalous column.
   *
   * <p>When the input already has a field with one of the regular output names, the name with a
   * {@code "1"} suffix is used instead.
   */
  @Override
  protected RelDataType deriveRowType() {
    final RelDataType inputType = getInput().getRowType();
    final RelDataTypeFactory typeFactory = getCluster().getTypeFactory();
    final List<String> takenNames = inputType.getFieldNames();

    final String scoreColumn = takenNames.contains(RCF_SCORE) ? DUPLICATE_RCF_SCORE : RCF_SCORE;

    // Only the name and type of the trailing column differ between the two modes.
    final String trailingColumn;
    final SqlTypeName trailingType;
    if (isTimeSeries) {
      trailingColumn =
          takenNames.contains(RCF_ANOMALY_GRADE) ? DUPLICATE_RCF_ANOMALY_GRADE : RCF_ANOMALY_GRADE;
      trailingType = SqlTypeName.DOUBLE;
    } else {
      trailingColumn =
          takenNames.contains(RCF_ANOMALOUS) ? DUPLICATE_RCF_ANOMALOUS : RCF_ANOMALOUS;
      trailingType = SqlTypeName.BOOLEAN;
    }

    return typeFactory
        .builder()
        .kind(inputType.getStructKind())
        .addAll(inputType.getFieldList())
        .add(scoreColumn, SqlTypeName.DOUBLE)
        .add(trailingColumn, trailingType)
        .build();
  }

  /**
   * Estimates the cost of this operator from the row count of its input: no CPU cost and an IO
   * cost of two units per input row.
   */
  @Override
  public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
    double inputRows = mq.getRowCount(getInput());
    // The AD algorithm is assumed to run on the remote cluster, so no local CPU cost is charged.
    double cpuCost = 0;
    // Models the nodeClient request round trip as network IO.
    double ioCost = inputRows * 2;
    return planner.getCostFactory().makeCost(inputRows, cpuCost, ioCost);
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.plan;

import java.util.Map;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;

/** Concrete {@link AD} operator that adds nothing beyond the copy factory required by the base. */
public class LogicalAD extends AD {

  /**
   * Creates a LogicalAD operator.
   *
   * @param cluster cluster this relational expression belongs to
   * @param traits trait set of the operator
   * @param input input relational expression
   * @param arguments mapping of parameter keys to values
   */
  public LogicalAD(
      RelOptCluster cluster, RelTraitSet traits, RelNode input, Map<String, Object> arguments) {
    super(cluster, traits, input, arguments);
  }

  /**
   * Creates a new LogicalAD on the same cluster with the same arguments.
   *
   * @param traitSet trait set for the copy
   * @param input input relational expression for the copy
   * @return the new LogicalAD instance
   */
  @Override
  public final LogicalAD copy(RelTraitSet traitSet, RelNode input) {
    final RelOptCluster owningCluster = getCluster();
    final Map<String, Object> sameArguments = getArguments();
    return new LogicalAD(owningCluster, traitSet, input, sameArguments);
  }
}
3 changes: 2 additions & 1 deletion docs/user/ppl/cmd/ad.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ ad (deprecated by ml command)
Description
===========
| The ``ad`` command applies the Random Cut Forest (RCF) algorithm in the ml-commons plugin to the search result returned by a PPL command. Based on the input, the command uses one of two RCF algorithms: fixed in time RCF for processing time-series data, or batch RCF for processing non-time-series data.
| The command accepts an optional time-series field and an optional category field; all other input fields must contain double values.


Syntax
Expand Down Expand Up @@ -109,4 +110,4 @@ PPL query::

Limitations
===========
The ``ad`` command can only work with ``plugins.calcite.enabled=false``.
The ``ad`` command can only process double type values.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ALIAS;
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK;
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK_WITH_NULL_VALUES;
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_DATATYPE_NUMERIC;
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_LOGS;
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_NESTED_SIMPLE;
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_OTEL_LOGS;
Expand Down Expand Up @@ -44,6 +45,7 @@ public void init() throws Exception {
loadIndex(Index.WORK_INFORMATION);
loadIndex(Index.WEBLOG);
loadIndex(Index.DATA_TYPE_ALIAS);
loadIndex(Index.DATA_TYPE_NUMERIC);
}

@Override
Expand Down Expand Up @@ -1966,4 +1968,37 @@ public void testAliasTypeField() throws IOException {
"source=%s | fields alias_col | where alias_col > 10 | stats avg(alias_col)",
TEST_INDEX_ALIAS)));
}

/** Checks the explain output of a bare {@code ad} command against the plan in ad.yaml. */
@Test
public void testAD() throws IOException {
  final String expectedPlan = loadExpectedPlan("ad.yaml");
  final String query =
      String.format("source=%s | fields double_number | ad", TEST_INDEX_DATATYPE_NUMERIC);
  assertYamlEqualsIgnoreId(expectedPlan, explainQueryYaml(query));
}

/**
 * Checks the explain output of {@code ad} with a {@code category_field} argument, applied on top
 * of an aggregation, against the plan in ad_category.yaml.
 */
@Test
public void testADWithCategory() throws IOException {
  final String expectedPlan = loadExpectedPlan("ad_category.yaml");
  final String query =
      String.format(
          "source=%s | stats max(double_number) as max by integer_number | fields"
              + " integer_number, max | ad category_field='integer_number'",
          TEST_INDEX_DATATYPE_NUMERIC);
  assertYamlEqualsIgnoreId(expectedPlan, explainQueryYaml(query));
}

/**
 * Checks the explain output of {@code ad} with both {@code category_field} and {@code time_field}
 * arguments, applied on top of a filter and projection, against the plan in ad_time_series.yaml.
 */
@Test
public void testADWithTimeSeries() throws IOException {
  final String expectedPlan = loadExpectedPlan("ad_time_series.yaml");
  final String query =
      String.format(
          "source=%s | where integer_number < 10 | fields long_number, integer_number,"
              + " double_number | ad category_field='integer_number' time_field='timestamp'",
          TEST_INDEX_DATATYPE_NUMERIC);
  assertYamlEqualsIgnoreId(expectedPlan, explainQueryYaml(query));
}
}
10 changes: 10 additions & 0 deletions integ-test/src/test/resources/expectedOutput/calcite/ad.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
calcite:
logical: |
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
LogicalAD(arguments=[{}])
LogicalProject(double_number=[$0])
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]])
physical: |
EnumerableLimit(fetch=[10000])
EnumerableAD(arguments=[{}])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]], PushDownContext=[[PROJECT->[double_number]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"timeout":"1m","_source":{"includes":["double_number"],"excludes":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
calcite:
logical: |
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
LogicalAD(arguments=[{category_field=integer_number}])
LogicalAggregate(group=[{0}], max=[MAX($1)])
LogicalProject(integer_number=[$2], double_number=[$0])
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]])
physical: |
EnumerableLimit(fetch=[10000])
EnumerableAD(arguments=[{category_field=integer_number}])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},max=MAX($1))], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"integer_number":{"terms":{"field":"integer_number","missing_bucket":true,"missing_order":"first","order":"asc"}}}]},"aggregations":{"max":{"max":{"field":"double_number"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
calcite:
logical: |
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
LogicalAD(arguments=[{time_field=timestamp, category_field=integer_number}])
LogicalProject(long_number=[$1], integer_number=[$2], double_number=[$0])
LogicalFilter(condition=[<($2, 10)])
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]])
physical: |
EnumerableLimit(fetch=[10000])
EnumerableAD(arguments=[{time_field=timestamp, category_field=integer_number}])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]], PushDownContext=[[PROJECT->[double_number, long_number, integer_number], FILTER-><($2, 10), PROJECT->[long_number, integer_number, double_number]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"timeout":"1m","query":{"range":{"integer_number":{"from":null,"to":10,"include_lower":true,"include_upper":false,"boost":1.0}}},"_source":{"includes":["long_number","integer_number","double_number"],"excludes":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
calcite:
logical: |
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
LogicalAD(arguments=[{}])
LogicalProject(double_number=[$0])
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]])
physical: |
EnumerableLimit(fetch=[10000])
EnumerableAD(arguments=[{}])
EnumerableCalc(expr#0..13=[{inputs}], double_number=[$t0])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]])
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
calcite:
logical: |
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
LogicalAD(arguments=[{category_field=integer_number}])
LogicalAggregate(group=[{0}], max=[MAX($1)])
LogicalProject(integer_number=[$2], double_number=[$0])
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]])
physical: |
EnumerableLimit(fetch=[10000])
EnumerableAD(arguments=[{category_field=integer_number}])
EnumerableAggregate(group=[{2}], max=[MAX($0)])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]])
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
calcite:
logical: |
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
LogicalAD(arguments=[{time_field=timestamp, category_field=integer_number}])
LogicalProject(long_number=[$1], integer_number=[$2], double_number=[$0])
LogicalFilter(condition=[<($2, 10)])
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]])
physical: |
EnumerableLimit(fetch=[10000])
EnumerableAD(arguments=[{time_field=timestamp, category_field=integer_number}])
EnumerableCalc(expr#0..13=[{inputs}], expr#14=[10], expr#15=[<($t2, $t14)], long_number=[$t1], integer_number=[$t2], double_number=[$t0], $condition=[$t15])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_datatypes_numeric]])
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.opensearch.sql.opensearch.executor.protector.ExecutionProtector;
import org.opensearch.sql.opensearch.functions.DistinctCountApproxAggFunction;
import org.opensearch.sql.opensearch.functions.GeoIpFunction;
import org.opensearch.sql.opensearch.storage.NodeClientHolder;
import org.opensearch.sql.opensearch.util.JdbcOpenSearchDataTypeConvertor;
import org.opensearch.sql.planner.physical.PhysicalPlan;
import org.opensearch.sql.storage.TableScanOperator;
Expand All @@ -76,6 +77,9 @@ public OpenSearchExecutionEngine(
ExecutionProtector executionProtector,
PlanSerializer planSerializer) {
this.client = client;
if (client.getNodeClient().isPresent()) {
NodeClientHolder.init(client.getNodeClient().get());
}
this.executionProtector = executionProtector;
this.planSerializer = planSerializer;
registerOpenSearchFunctions();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ protected ExprTupleValue buildResult(
* @param nodeClient node client
* @return ml-commons train and predict result
*/
protected MLPredictionOutput getMLPredictionResult(
public static MLPredictionOutput getMLPredictionResult(
FunctionName functionName,
MLAlgoParams mlAlgoParams,
DataFrame inputDataFrame,
Expand Down
Loading
Loading