Skip to content

Commit fd518ff

Browse files
committed
fix(agent): filter AIMessage state updates from streaming output
LangGraph's stream_mode="messages" emits both AIMessageChunk (incremental tokens) and AIMessage (final state update) from the agent node. The _stream_fn was accepting both via isinstance(msg, (AIMessage, AIMessageChunk)), causing the full accumulated response to be emitted as a final chunk after all the individual tokens had already been streamed. Clients saw the complete response duplicated at the end of the SSE stream. Filter to only AIMessageChunk so the state update is excluded. Adds a regression test that confirms AIMessage objects are emitted by the graph stream (the duplicate source) and that filtering to AIMessageChunk excludes them. Signed-off-by: Myles Shannon <mshannon@nvidia.com>
1 parent ee7ab31 commit fd518ff

2 files changed

Lines changed: 41 additions & 2 deletions

File tree

packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/tool_calling_agent/register.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ class ToolCallAgentWorkflowConfig(AgentBaseConfig, name="tool_calling_agent"):
101101

102102
@register_function(config_type=ToolCallAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
103103
async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, builder: Builder):
104-
from langchain_core.messages import AIMessage
105104
from langchain_core.messages import AIMessageChunk
106105
from langchain_core.messages import trim_messages
107106
from langchain_core.messages.base import BaseMessage
@@ -219,7 +218,7 @@ async def _stream_fn(chat_request_or_message: ChatRequestOrMessage) -> AsyncGene
219218
state,
220219
config={'recursion_limit': (config.max_iterations + 1) * 2},
221220
stream_mode="messages"):
222-
if not isinstance(msg, (AIMessage, AIMessageChunk)):
221+
if not isinstance(msg, AIMessageChunk):
223222
continue
224223
if metadata.get("langgraph_node") != "agent":
225224
continue

packages/nvidia_nat_langchain/tests/agent/test_tool_calling.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,46 @@ async def test_graph_astream_yields_message_chunks(mock_tool_graph):
367367
assert len(combined_content) > 0, "Expected non-empty content from streamed agent messages"
368368

369369

370+
async def test_stream_fn_no_duplicate_content(mock_tool_graph):
371+
"""Regression: streaming must not duplicate the previous assistant message as a final chunk.
372+
373+
When stream=true, _stream_fn uses graph.astream(stream_mode="messages") which emits
374+
both AIMessageChunk (incremental tokens) and AIMessage (state update). Accepting
375+
AIMessage causes the accumulated response to appear twice in the output. The fix
376+
filters to AIMessageChunk only. This test exercises the same graph.astream path and
377+
asserts that the filtering logic in _stream_fn would prevent duplicates.
378+
"""
379+
from langchain_core.messages import AIMessageChunk
380+
381+
prior_reply = "Hi there!"
382+
mock_state = ToolCallAgentGraphState(messages=[
383+
HumanMessage(content="hello"),
384+
AIMessage(content=prior_reply),
385+
HumanMessage(content="what can you do?"),
386+
])
387+
388+
chunk_contents = []
389+
full_contents = []
390+
async for msg, metadata in mock_tool_graph.astream(
391+
mock_state, config={"recursion_limit": 5}, stream_mode="messages"):
392+
if metadata.get("langgraph_node") != "agent":
393+
continue
394+
if isinstance(msg, AIMessageChunk) and isinstance(msg.content, str) and msg.content:
395+
chunk_contents.append(msg.content)
396+
if hasattr(msg, "content") and isinstance(msg.content, str) and msg.content:
397+
full_contents.append(msg.content)
398+
399+
chunk_response = "".join(chunk_contents)
400+
full_response = "".join(full_contents)
401+
402+
assert prior_reply in full_response, (
403+
"AIMessage state update with prior reply should appear in unfiltered stream"
404+
)
405+
assert prior_reply not in chunk_response, (
406+
f"AIMessageChunk-only stream must not contain prior assistant reply: {chunk_response!r}"
407+
)
408+
409+
370410
def test_tool_call_chunk_serialization():
371411
"""Test that ChatResponseChunk with tool_calls in ChoiceDelta serializes to OpenAI-compatible SSE format."""
372412
chunk = ChatResponseChunk(

0 commit comments

Comments (0)