55
66package org .opensearch .sql .opensearch .planner .physical ;
77
8- import java .util .ArrayList ;
8+ import static org .opensearch .sql .utils .MLCommonsConstants .CATEGORY_FIELD ;
9+
910import java .util .Collections ;
1011import java .util .HashMap ;
1112import java .util .Iterator ;
1213import java .util .List ;
1314import java .util .Map ;
15+ import java .util .stream .Collectors ;
16+
1417import lombok .EqualsAndHashCode ;
1518import lombok .Getter ;
1619import lombok .RequiredArgsConstructor ;
20+ import org .apache .commons .lang3 .tuple .Pair ;
1721import org .opensearch .client .node .NodeClient ;
1822import org .opensearch .ml .common .dataframe .DataFrame ;
1923import org .opensearch .ml .common .dataframe .Row ;
@@ -42,28 +46,40 @@ public class MLOperator extends MLCommonsOperatorActions {
4246 @ Override
4347 public void open () {
4448 super .open ();
45- DataFrame inputDataFrame = generateInputDataset (input );
4649 Map <String , Object > args = processArgs (arguments );
4750
48- MLOutput mlOutput = getMLOutput (inputDataFrame , args , nodeClient );
49- final Iterator <Row > inputRowIter = inputDataFrame .iterator ();
51+ // Check if category_field is provided
52+ String categoryField =
53+ arguments .containsKey (CATEGORY_FIELD )
54+ ? (String ) arguments .get (CATEGORY_FIELD ).getValue ()
55+ : null ;
56+
5057 // Only need to check train here, as action should be already checked in ml client.
5158 final boolean isPrediction = ((String ) args .get ("action" )).equals ("train" ) ? false : true ;
52- // For train, only one row to return.
53- final Iterator <String > trainIter =
54- new ArrayList <String >() {
55- {
56- add ("train" );
57- }
58- }.iterator ();
59- final Iterator <Row > resultRowIter =
60- isPrediction ? ((MLPredictionOutput ) mlOutput ).getPredictionResult ().iterator () : null ;
59+ final Iterator <String > trainIter = Collections .singletonList ("train" ).iterator ();
60+
61+ // For prediction mode, handle both categorized and non-categorized cases
62+ List <Pair <DataFrame , DataFrame >> inputDataFrames =
63+ generateCategorizedInputDataset (input , categoryField );
64+ List <MLOutput > mlOutputs =
65+ inputDataFrames .stream ()
66+ .map (pair -> getMLOutput (pair .getRight (), args , nodeClient ))
67+ .collect (Collectors .toList ());
68+ Iterator <Pair <DataFrame , DataFrame >> inputDataFramesIter = inputDataFrames .iterator ();
69+ Iterator <MLOutput > mlOutputIter = mlOutputs .iterator ();
70+
6171 iterator =
62- new Iterator <ExprValue >() {
72+ new Iterator <>() {
73+ private DataFrame inputDataFrame = null ;
74+ private Iterator <Row > inputRowIter = null ;
75+ private MLOutput mlOutput = null ;
76+ private Iterator <Row > resultRowIter = null ;
77+
6378 @ Override
6479 public boolean hasNext () {
6580 if (isPrediction ) {
66- return inputRowIter .hasNext ();
81+ return (inputRowIter != null && inputRowIter .hasNext ())
82+ || inputDataFramesIter .hasNext ();
6783 } else {
6884 boolean res = trainIter .hasNext ();
6985 if (res ) {
@@ -75,8 +91,19 @@ public boolean hasNext() {
7591
7692 @ Override
7793 public ExprValue next () {
78- return buildPPLResult (
79- isPrediction , inputRowIter , inputDataFrame , mlOutput , resultRowIter );
94+ if (isPrediction ) {
95+ if (inputRowIter == null || !inputRowIter .hasNext ()) {
96+ Pair <DataFrame , DataFrame > pair = inputDataFramesIter .next ();
97+ inputDataFrame = pair .getLeft ();
98+ inputRowIter = inputDataFrame .iterator ();
99+ mlOutput = mlOutputIter .next ();
100+ resultRowIter = ((MLPredictionOutput ) mlOutput ).getPredictionResult ().iterator ();
101+ }
102+ return buildPPLResult (true , inputRowIter , inputDataFrame , mlOutput , resultRowIter );
103+ } else {
104+ // train case
105+ return buildPPLResult (false , null , null , mlOutputs .get (0 ), null );
106+ }
80107 }
81108 };
82109 }
0 commit comments