Skip to content

Commit 661cb8d

Browse files
authored
ML command supports category_field parameter (#3909)
Signed-off-by: Binlong Gao <gbinlong@amazon.com>
1 parent 77633ef commit 661cb8d

2 files changed

Lines changed: 58 additions & 17 deletions

File tree

opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLOperator.java

Lines changed: 42 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,8 @@
55

66
package org.opensearch.sql.opensearch.planner.physical;
77

8-
import java.util.ArrayList;
8+
import static org.opensearch.sql.utils.MLCommonsConstants.CATEGORY_FIELD;
9+
910
import java.util.Collections;
1011
import java.util.HashMap;
1112
import java.util.Iterator;
@@ -14,6 +15,7 @@
1415
import lombok.EqualsAndHashCode;
1516
import lombok.Getter;
1617
import lombok.RequiredArgsConstructor;
18+
import org.apache.commons.lang3.tuple.Pair;
1719
import org.opensearch.ml.common.dataframe.DataFrame;
1820
import org.opensearch.ml.common.dataframe.Row;
1921
import org.opensearch.ml.common.output.MLOutput;
@@ -42,28 +44,40 @@ public class MLOperator extends MLCommonsOperatorActions {
4244
@Override
4345
public void open() {
4446
super.open();
45-
DataFrame inputDataFrame = generateInputDataset(input);
4647
Map<String, Object> args = processArgs(arguments);
4748

48-
MLOutput mlOutput = getMLOutput(inputDataFrame, args, nodeClient);
49-
final Iterator<Row> inputRowIter = inputDataFrame.iterator();
49+
// Check if category_field is provided
50+
String categoryField =
51+
arguments.containsKey(CATEGORY_FIELD)
52+
? (String) arguments.get(CATEGORY_FIELD).getValue()
53+
: null;
54+
5055
// Only need to check train here, as action should be already checked in ml client.
5156
final boolean isPrediction = ((String) args.get("action")).equals("train") ? false : true;
52-
// For train, only one row to return.
53-
final Iterator<String> trainIter =
54-
new ArrayList<String>() {
55-
{
56-
add("train");
57-
}
58-
}.iterator();
59-
final Iterator<Row> resultRowIter =
60-
isPrediction ? ((MLPredictionOutput) mlOutput).getPredictionResult().iterator() : null;
57+
final Iterator<String> trainIter = Collections.singletonList("train").iterator();
58+
59+
// For prediction mode, handle both categorized and non-categorized cases
60+
List<Pair<DataFrame, DataFrame>> inputDataFrames =
61+
generateCategorizedInputDataset(input, categoryField);
62+
List<MLOutput> mlOutputs =
63+
inputDataFrames.stream()
64+
.map(pair -> getMLOutput(pair.getRight(), args, nodeClient))
65+
.toList();
66+
Iterator<Pair<DataFrame, DataFrame>> inputDataFramesIter = inputDataFrames.iterator();
67+
Iterator<MLOutput> mlOutputIter = mlOutputs.iterator();
68+
6169
iterator =
62-
new Iterator<ExprValue>() {
70+
new Iterator<>() {
71+
private DataFrame inputDataFrame = null;
72+
private Iterator<Row> inputRowIter = null;
73+
private MLOutput mlOutput = null;
74+
private Iterator<Row> resultRowIter = null;
75+
6376
@Override
6477
public boolean hasNext() {
6578
if (isPrediction) {
66-
return inputRowIter.hasNext();
79+
return (inputRowIter != null && inputRowIter.hasNext())
80+
|| inputDataFramesIter.hasNext();
6781
} else {
6882
boolean res = trainIter.hasNext();
6983
if (res) {
@@ -75,8 +89,19 @@ public boolean hasNext() {
7589

7690
@Override
7791
public ExprValue next() {
78-
return buildPPLResult(
79-
isPrediction, inputRowIter, inputDataFrame, mlOutput, resultRowIter);
92+
if (isPrediction) {
93+
if (inputRowIter == null || !inputRowIter.hasNext()) {
94+
Pair<DataFrame, DataFrame> pair = inputDataFramesIter.next();
95+
inputDataFrame = pair.getLeft();
96+
inputRowIter = inputDataFrame.iterator();
97+
mlOutput = mlOutputIter.next();
98+
resultRowIter = ((MLPredictionOutput) mlOutput).getPredictionResult().iterator();
99+
}
100+
return buildPPLResult(true, inputRowIter, inputDataFrame, mlOutput, resultRowIter);
101+
} else {
102+
// train case
103+
return buildPPLResult(false, null, null, mlOutputs.getFirst(), null);
104+
}
80105
}
81106
};
82107
}

opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLOperatorTest.java

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
import static org.mockito.Mockito.when;
1616
import static org.opensearch.sql.utils.MLCommonsConstants.ACTION;
1717
import static org.opensearch.sql.utils.MLCommonsConstants.ALGO;
18+
import static org.opensearch.sql.utils.MLCommonsConstants.CATEGORY_FIELD;
1819
import static org.opensearch.sql.utils.MLCommonsConstants.KMEANS;
1920
import static org.opensearch.sql.utils.MLCommonsConstants.PREDICT;
2021
import static org.opensearch.sql.utils.MLCommonsConstants.TRAIN;
@@ -144,6 +145,21 @@ public void testOpenPredict() {
144145
}
145146
}
146147

148+
@Test
149+
public void testOpenPredictWithCategoryField() {
150+
setUpPredict();
151+
// Add category_field parameter
152+
arguments.put(CATEGORY_FIELD, AstDSL.stringLiteral("region"));
153+
154+
try (MockedStatic<MLClient> mlClientMockedStatic = Mockito.mockStatic(MLClient.class)) {
155+
when(MLClient.getMLClient(any(NodeClient.class))).thenReturn(machineLearningNodeClient);
156+
mlOperator.open();
157+
assertTrue(mlOperator.hasNext());
158+
assertNotNull(mlOperator.next());
159+
assertFalse(mlOperator.hasNext());
160+
}
161+
}
162+
147163
@Test
148164
public void testOpenTrain() {
149165
setUpTrain();

0 commit comments

Comments
 (0)