diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapper.java index e169abf0e5..2fc3799e5b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapper.java @@ -5,8 +5,6 @@ package org.opensearch.ml.engine.algorithms.agent; -import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_RUN_ID; -import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_THREAD_ID; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.util.ArrayList; @@ -20,7 +18,6 @@ import org.opensearch.ml.common.agui.AGUIInputConverter; import org.opensearch.ml.common.agui.BaseEvent; import org.opensearch.ml.common.agui.MessagesSnapshotEvent; -import org.opensearch.ml.common.agui.RunFinishedEvent; import org.opensearch.ml.common.agui.ToolCallResultEvent; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.execute.agent.Message; @@ -164,25 +161,9 @@ public void sendBackendToolResult(String toolCallId, String toolResult, String s } public void sendRunFinishedAndCloseStream(String sessionId, String parentInteractionId) { - try { - String threadId = parameters.get(AGUI_PARAM_THREAD_ID); - String runId = parameters.get(AGUI_PARAM_RUN_ID); - - // Ensure non-null values to avoid NPE in RunFinishedEvent.writeTo() - if (threadId == null) { - log.warn("AG-UI threadId is null, using generated value. This may cause frontend errors."); - threadId = "thread_" + System.nanoTime(); - } - if (runId == null) { - log.warn("AG-UI runId is null, using generated value. This may cause frontend errors."); - runId = "run_" + System.nanoTime(); - } - - BaseEvent runFinishedEvent = new RunFinishedEvent(threadId, runId, null); - sendAGUIEvent(runFinishedEvent, true); - } catch (Exception e) { - log.error("Failed to send run finished event and close stream", e); - } + // Send an empty completion chunk with is_last=true. + // RestMLExecuteStreamAction will emit RUN_FINISHED when it sees the final chunk. + sendCompletionChunk(sessionId, parentInteractionId); } public void sendMessagesSnapshot(List history, String memoryId, ActionListener listener) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java index 7e54f99322..3ec2bb11cb 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java @@ -6,9 +6,7 @@ package org.opensearch.ml.engine.algorithms.remote.streaming; import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_MESSAGE_ID; -import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_RUN_ID; import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TEXT_MESSAGE_STARTED; -import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_THREAD_ID; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS; @@ -21,7 +19,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import org.opensearch.ml.common.agui.BaseEvent; -import org.opensearch.ml.common.agui.RunFinishedEvent; import org.opensearch.ml.common.agui.TextMessageContentEvent; import org.opensearch.ml.common.agui.TextMessageEndEvent; import org.opensearch.ml.common.agui.TextMessageStartEvent; @@ -244,21 +241,12 @@ private void handleDoneEvent() { && "true".equalsIgnoreCase(parameters.get(AGUI_PARAM_TEXT_MESSAGE_STARTED)); if (textMessageStarted) { - // End any remaining text message parameters.put(AGUI_PARAM_TEXT_MESSAGE_STARTED, "false"); BaseEvent textMessageEndEvent = new TextMessageEndEvent(messageId); sendAGUIEvent(textMessageEndEvent, false, streamActionListener); log.debug("AG-UI: Sent TEXT_MESSAGE_END for messageId: {} at stream end", messageId); } - // Send RUN_FINISHED event - String threadId = parameters.get(AGUI_PARAM_THREAD_ID); - String runId = parameters.get(AGUI_PARAM_RUN_ID); - BaseEvent runFinishedEvent = new RunFinishedEvent(threadId, runId, null); - sendAGUIEvent(runFinishedEvent, false, streamActionListener); - log.debug("AG-UI: Sent RUN_FINISHED event at [DONE] - threadId={}, runId={}", threadId, runId); - - // Trigger agentListener callback to save assistant structured message streamActionListener.onResponse(createFinalAnswerResponse(accumulatedContent.toString())); } @@ -277,21 +265,12 @@ private void processStreamChunk(Map dataMap) { if (isAGUIAgent) { if (textMessageStarted) { - // End the current text message parameters.put(AGUI_PARAM_TEXT_MESSAGE_STARTED, "false"); BaseEvent textMessageEndEvent = new TextMessageEndEvent(messageId); sendAGUIEvent(textMessageEndEvent, false, streamActionListener); log.debug("AG-UI: Sent TEXT_MESSAGE_END for messageId: {}", messageId); } - // Send RUN_FINISHED event - String threadId = parameters.get(AGUI_PARAM_THREAD_ID); - String runId = parameters.get(AGUI_PARAM_RUN_ID); - BaseEvent runFinishedEvent = new RunFinishedEvent(threadId, runId, null); - sendAGUIEvent(runFinishedEvent, false, streamActionListener); - log.debug("AG-UI: Sent RUN_FINISHED event - threadId={}, runId={}", threadId, runId); - - // Trigger agentListener callback to save assistant structured message streamActionListener.onResponse(createFinalAnswerResponse(accumulatedContent.toString())); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java index ee262aba3d..b3a50c208c 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java @@ -486,11 +486,12 @@ private HttpChunk convertToHttpChunk(MLTaskResponse response, boolean isAGUIAgen } // Forward any content events (pass false — isLast is controlled by the outer chunk) - HttpChunk contentChunk = convertToAGUIEvent(content, false); - combinedSse.append(new String(BytesReference.toBytes(contentChunk.content()))); + AGUIEventResult contentResult = convertToAGUIEvent(content, false); + combinedSse.append(new String(BytesReference.toBytes(contentResult.chunk().content()))); - // RunFinished is the last AG-UI event, emitted only on the final chunk - if (isLast) { + // RunFinished is the last AG-UI event, emitted only on the final chunk. + // Skip if a RUN_ERROR was already emitted — RUN_ERROR is a terminal event. + if (isLast && !contentResult.hasRunError()) { BaseEvent runFinishedEvent = new RunFinishedEvent(threadId, runId, null); combinedSse.append("data: ").append(runFinishedEvent.toJsonString()).append("\n\n"); } @@ -571,7 +572,10 @@ private String extractTensorResult(MLTaskResponse response, String tensorName) { return null; } - private HttpChunk convertToAGUIEvent(String content, boolean isLast) { + private record AGUIEventResult(HttpChunk chunk, boolean hasRunError) { + } + + private AGUIEventResult convertToAGUIEvent(String content, boolean isLast) { log .debug( "RestMLExecuteStreamAction: convertToAGUIEvent() called - contentLength={}, isLast={}", @@ -580,6 +584,7 @@ private HttpChunk convertToAGUIEvent(String content, boolean isLast) { ); StringBuilder sseResponse = new StringBuilder(); + boolean hasRunError = false; if (content != null && !content.isEmpty()) { log.debug("RestMLExecuteStreamAction: Processing content: '{}'", content); @@ -590,17 +595,18 @@ private HttpChunk convertToAGUIEvent(String content, boolean isLast) { sseResponse.append("data: ").append(element).append("\n\n"); log.debug("RestMLExecuteStreamAction: Processing json element: '{}'", element); } else { - // catch unexpected content chunks such as Bedrock error log.warn("Unexpected content received - not valid JSON: {}", content); BaseEvent runErrorEvent = new RunErrorEvent("Unexpected chunk: " + content, null); sseResponse.append("data: ").append(runErrorEvent.toJsonString()).append("\n\n"); isLast = true; + hasRunError = true; } } catch (Exception e) { log.error("Failed to process AG-UI events chunk content {}", content, e); BaseEvent runErrorEvent = new RunErrorEvent("Unexpected error: " + e.getMessage(), null); sseResponse.append("data: ").append(runErrorEvent.toJsonString()).append("\n\n"); isLast = true; + hasRunError = true; } } else { log.warn("Received null or empty AG-UI content chunk"); @@ -608,7 +614,7 @@ private HttpChunk convertToAGUIEvent(String content, boolean isLast) { String finalSse = sseResponse.toString(); log.debug("RestMLExecuteStreamAction: Returning chunk - length={}", finalSse.length()); - return createHttpChunk(finalSse, isLast); + return new AGUIEventResult(createHttpChunk(finalSse, isLast), hasRunError); } @VisibleForTesting