Skip to content

Commit bc8d2e4

Browse files
ML command supports category_field parameter (#3909) (#5022)
(cherry picked from commit 661cb8d) Signed-off-by: Binlong Gao <gbinlong@amazon.com> Signed-off-by: Lantao Jin <ltjin@amazon.com> Co-authored-by: gaobinlong <gbinlong@amazon.com>
1 parent a2fe0be commit bc8d2e4

2 files changed

Lines changed: 60 additions & 17 deletions

File tree

opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLOperator.java

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,19 @@
55

66
package org.opensearch.sql.opensearch.planner.physical;
77

8-
import java.util.ArrayList;
8+
import static org.opensearch.sql.utils.MLCommonsConstants.CATEGORY_FIELD;
9+
910
import java.util.Collections;
1011
import java.util.HashMap;
1112
import java.util.Iterator;
1213
import java.util.List;
1314
import java.util.Map;
15+
import java.util.stream.Collectors;
16+
1417
import lombok.EqualsAndHashCode;
1518
import lombok.Getter;
1619
import lombok.RequiredArgsConstructor;
20+
import org.apache.commons.lang3.tuple.Pair;
1721
import org.opensearch.client.node.NodeClient;
1822
import org.opensearch.ml.common.dataframe.DataFrame;
1923
import org.opensearch.ml.common.dataframe.Row;
@@ -42,28 +46,40 @@ public class MLOperator extends MLCommonsOperatorActions {
4246
@Override
4347
public void open() {
4448
super.open();
45-
DataFrame inputDataFrame = generateInputDataset(input);
4649
Map<String, Object> args = processArgs(arguments);
4750

48-
MLOutput mlOutput = getMLOutput(inputDataFrame, args, nodeClient);
49-
final Iterator<Row> inputRowIter = inputDataFrame.iterator();
51+
// Check if category_field is provided
52+
String categoryField =
53+
arguments.containsKey(CATEGORY_FIELD)
54+
? (String) arguments.get(CATEGORY_FIELD).getValue()
55+
: null;
56+
5057
// Only need to check train here, as action should be already checked in ml client.
5158
final boolean isPrediction = ((String) args.get("action")).equals("train") ? false : true;
52-
// For train, only one row to return.
53-
final Iterator<String> trainIter =
54-
new ArrayList<String>() {
55-
{
56-
add("train");
57-
}
58-
}.iterator();
59-
final Iterator<Row> resultRowIter =
60-
isPrediction ? ((MLPredictionOutput) mlOutput).getPredictionResult().iterator() : null;
59+
final Iterator<String> trainIter = Collections.singletonList("train").iterator();
60+
61+
// For prediction mode, handle both categorized and non-categorized cases
62+
List<Pair<DataFrame, DataFrame>> inputDataFrames =
63+
generateCategorizedInputDataset(input, categoryField);
64+
List<MLOutput> mlOutputs =
65+
inputDataFrames.stream()
66+
.map(pair -> getMLOutput(pair.getRight(), args, nodeClient))
67+
.collect(Collectors.toList());
68+
Iterator<Pair<DataFrame, DataFrame>> inputDataFramesIter = inputDataFrames.iterator();
69+
Iterator<MLOutput> mlOutputIter = mlOutputs.iterator();
70+
6171
iterator =
62-
new Iterator<ExprValue>() {
72+
new Iterator<>() {
73+
private DataFrame inputDataFrame = null;
74+
private Iterator<Row> inputRowIter = null;
75+
private MLOutput mlOutput = null;
76+
private Iterator<Row> resultRowIter = null;
77+
6378
@Override
6479
public boolean hasNext() {
6580
if (isPrediction) {
66-
return inputRowIter.hasNext();
81+
return (inputRowIter != null && inputRowIter.hasNext())
82+
|| inputDataFramesIter.hasNext();
6783
} else {
6884
boolean res = trainIter.hasNext();
6985
if (res) {
@@ -75,8 +91,19 @@ public boolean hasNext() {
7591

7692
@Override
7793
public ExprValue next() {
78-
return buildPPLResult(
79-
isPrediction, inputRowIter, inputDataFrame, mlOutput, resultRowIter);
94+
if (isPrediction) {
95+
if (inputRowIter == null || !inputRowIter.hasNext()) {
96+
Pair<DataFrame, DataFrame> pair = inputDataFramesIter.next();
97+
inputDataFrame = pair.getLeft();
98+
inputRowIter = inputDataFrame.iterator();
99+
mlOutput = mlOutputIter.next();
100+
resultRowIter = ((MLPredictionOutput) mlOutput).getPredictionResult().iterator();
101+
}
102+
return buildPPLResult(true, inputRowIter, inputDataFrame, mlOutput, resultRowIter);
103+
} else {
104+
// train case
105+
return buildPPLResult(false, null, null, mlOutputs.get(0), null);
106+
}
80107
}
81108
};
82109
}

opensearch/src/test/java/org/opensearch/sql/opensearch/planner/physical/MLOperatorTest.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import static org.mockito.Mockito.when;
1616
import static org.opensearch.sql.utils.MLCommonsConstants.ACTION;
1717
import static org.opensearch.sql.utils.MLCommonsConstants.ALGO;
18+
import static org.opensearch.sql.utils.MLCommonsConstants.CATEGORY_FIELD;
1819
import static org.opensearch.sql.utils.MLCommonsConstants.KMEANS;
1920
import static org.opensearch.sql.utils.MLCommonsConstants.PREDICT;
2021
import static org.opensearch.sql.utils.MLCommonsConstants.TRAIN;
@@ -144,6 +145,21 @@ public void testOpenPredict() {
144145
}
145146
}
146147

148+
@Test
149+
public void testOpenPredictWithCategoryField() {
150+
setUpPredict();
151+
// Add category_field parameter
152+
arguments.put(CATEGORY_FIELD, AstDSL.stringLiteral("region"));
153+
154+
try (MockedStatic<MLClient> mlClientMockedStatic = Mockito.mockStatic(MLClient.class)) {
155+
when(MLClient.getMLClient(any(NodeClient.class))).thenReturn(machineLearningNodeClient);
156+
mlOperator.open();
157+
assertTrue(mlOperator.hasNext());
158+
assertNotNull(mlOperator.next());
159+
assertFalse(mlOperator.hasNext());
160+
}
161+
}
162+
147163
@Test
148164
public void testOpenTrain() {
149165
setUpTrain();

0 commit comments

Comments
 (0)