55
66package org .opensearch .sql .opensearch .planner .physical ;
77
8- import java .util .ArrayList ;
8+ import static org .opensearch .sql .utils .MLCommonsConstants .CATEGORY_FIELD ;
9+
910import java .util .Collections ;
1011import java .util .HashMap ;
1112import java .util .Iterator ;
1415import lombok .EqualsAndHashCode ;
1516import lombok .Getter ;
1617import lombok .RequiredArgsConstructor ;
18+ import org .apache .commons .lang3 .tuple .Pair ;
1719import org .opensearch .ml .common .dataframe .DataFrame ;
1820import org .opensearch .ml .common .dataframe .Row ;
1921import org .opensearch .ml .common .output .MLOutput ;
@@ -42,28 +44,40 @@ public class MLOperator extends MLCommonsOperatorActions {
4244 @ Override
4345 public void open () {
4446 super .open ();
45- DataFrame inputDataFrame = generateInputDataset (input );
4647 Map <String , Object > args = processArgs (arguments );
4748
48- MLOutput mlOutput = getMLOutput (inputDataFrame , args , nodeClient );
49- final Iterator <Row > inputRowIter = inputDataFrame .iterator ();
49+ // Check if category_field is provided
50+ String categoryField =
51+ arguments .containsKey (CATEGORY_FIELD )
52+ ? (String ) arguments .get (CATEGORY_FIELD ).getValue ()
53+ : null ;
54+
5055 // Only need to check train here, as action should be already checked in ml client.
5156 final boolean isPrediction = ((String ) args .get ("action" )).equals ("train" ) ? false : true ;
52- // For train, only one row to return.
53- final Iterator <String > trainIter =
54- new ArrayList <String >() {
55- {
56- add ("train" );
57- }
58- }.iterator ();
59- final Iterator <Row > resultRowIter =
60- isPrediction ? ((MLPredictionOutput ) mlOutput ).getPredictionResult ().iterator () : null ;
57+ final Iterator <String > trainIter = Collections .singletonList ("train" ).iterator ();
58+
59+ // For prediction mode, handle both categorized and non-categorized cases
60+ List <Pair <DataFrame , DataFrame >> inputDataFrames =
61+ generateCategorizedInputDataset (input , categoryField );
62+ List <MLOutput > mlOutputs =
63+ inputDataFrames .stream ()
64+ .map (pair -> getMLOutput (pair .getRight (), args , nodeClient ))
65+ .toList ();
66+ Iterator <Pair <DataFrame , DataFrame >> inputDataFramesIter = inputDataFrames .iterator ();
67+ Iterator <MLOutput > mlOutputIter = mlOutputs .iterator ();
68+
6169 iterator =
62- new Iterator <ExprValue >() {
70+ new Iterator <>() {
71+ private DataFrame inputDataFrame = null ;
72+ private Iterator <Row > inputRowIter = null ;
73+ private MLOutput mlOutput = null ;
74+ private Iterator <Row > resultRowIter = null ;
75+
6376 @ Override
6477 public boolean hasNext () {
6578 if (isPrediction ) {
66- return inputRowIter .hasNext ();
79+ return (inputRowIter != null && inputRowIter .hasNext ())
80+ || inputDataFramesIter .hasNext ();
6781 } else {
6882 boolean res = trainIter .hasNext ();
6983 if (res ) {
@@ -75,8 +89,19 @@ public boolean hasNext() {
7589
7690 @ Override
7791 public ExprValue next () {
78- return buildPPLResult (
79- isPrediction , inputRowIter , inputDataFrame , mlOutput , resultRowIter );
92+ if (isPrediction ) {
93+ if (inputRowIter == null || !inputRowIter .hasNext ()) {
94+ Pair <DataFrame , DataFrame > pair = inputDataFramesIter .next ();
95+ inputDataFrame = pair .getLeft ();
96+ inputRowIter = inputDataFrame .iterator ();
97+ mlOutput = mlOutputIter .next ();
98+ resultRowIter = ((MLPredictionOutput ) mlOutput ).getPredictionResult ().iterator ();
99+ }
100+ return buildPPLResult (true , inputRowIter , inputDataFrame , mlOutput , resultRowIter );
101+ } else {
102+ // train case
103+ return buildPPLResult (false , null , null , mlOutputs .getFirst (), null );
104+ }
80105 }
81106 };
82107 }
0 commit comments