From a76cdffc4bd72e035c5b6fdd11b1b610f5897236 Mon Sep 17 00:00:00 2001 From: Zelys Date: Sat, 4 Apr 2026 18:23:00 -0500 Subject: [PATCH] Fix stream_async() yielding non-serializable invocation_state fields ModelStreamEvent.prepare() merged the full invocation_state dict into the event dict before it was yielded to callers. Because invocation_state can contain non-JSON-serializable objects (Agent instance, OTel spans), those objects appeared in every streamed event, causing repr() serialization on the wire and hitting payload limits in runtimes like AgentCore. The fix snapshots the event data before prepare() is called and yields that snapshot instead of the merged dict. The callback_handler still receives the fully-merged dict for backward compatibility. Fixes #1928 --- src/strands/agent/agent.py | 10 +++- .../strands/agent/hooks/test_agent_events.py | 32 +++++++++-- tests/strands/agent/test_agent.py | 56 +++++++++++++++++-- 3 files changed, 87 insertions(+), 11 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 3a23133de..585d9a7d3 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -828,12 +828,16 @@ async def stream_async( events = self._run_loop(messages, merged_state, structured_output_model, structured_output_prompt) async for event in events: + # Snapshot the event data before prepare() merges invocation_state + # into the dict. The callback_handler receives the full merged dict + # for backward compatibility, but stream_async() callers only see + # the serializable event fields. + event_data = event.as_dict() event.prepare(invocation_state=merged_state) if event.is_callback_event: - as_dict = event.as_dict() - callback_handler(**as_dict) - yield as_dict + callback_handler(**event.as_dict()) + yield event_data result = AgentResult(*event["stop"]) callback_handler(result=result) diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 02c367ccc..9ab34dea6 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -46,6 +46,24 @@ def mock_sleep(): "request_state": {}, } +# Keys that prepare() merges from invocation_state. stream_async() no longer includes +# these in yielded events; callback_handler still receives them for backward compat. +_INVOCATION_STATE_KEYS = frozenset(any_props.keys()) | frozenset( + { + "event_loop_parent_cycle_id", + "messages", + "model", + "system_prompt", + "tool_config", + } +) + + +def _strip_state(events: list[dict], *user_keys: str) -> list[dict]: + """Return events with invocation_state fields removed (matches what stream_async() yields).""" + keys_to_remove = _INVOCATION_STATE_KEYS | set(user_keys) + return [{k: v for k, v in e.items() if k not in keys_to_remove} for e in events] + @pytest.mark.asyncio async def test_stream_e2e_success(alist): @@ -317,7 +335,10 @@ async def test_stream_e2e_success(alist): ), }, ] - assert tru_events == exp_events + # stream_async() yields events without invocation_state; callback_handler receives + # the full merged dict. Verify both independently. + exp_yield_events = _strip_state(exp_events, "arg1") + assert tru_events == exp_yield_events exp_calls = [call(**event) for event in exp_events] act_calls = mock_callback.call_args_list @@ -381,7 +402,8 @@ async def test_stream_e2e_throttle_and_redact(alist, mock_sleep): ), }, ] - assert tru_events == exp_events + exp_yield_events = _strip_state(exp_events, "arg1") + assert tru_events == exp_yield_events exp_calls = [call(**event) for event in exp_events] act_calls = mock_callback.call_args_list @@ -459,7 +481,8 @@ async def test_stream_e2e_reasoning_redacted_content(alist): ), }, ] - assert tru_events == exp_events + exp_yield_events = _strip_state(exp_events) + assert tru_events == exp_yield_events exp_calls = [call(**event) for event in exp_events] act_calls = mock_callback.call_args_list @@ -514,7 +537,8 @@ async def test_event_loop_cycle_text_response_throttling_early_end( {"force_stop": True, "force_stop_reason": "ThrottlingException | ConverseStream"}, ] - assert tru_events == exp_events + exp_yield_events = _strip_state(exp_events, "arg1") + assert tru_events == exp_yield_events exp_calls = [call(**event) for event in exp_events] act_calls = mock_callback.call_args_list diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 5a3cce11c..91d821c4f 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1080,8 +1080,11 @@ async def test_event_loop(*args, **kwargs): stream = agent.stream_async("test message", callback_handler=mock_callback) tru_events = await alist(stream) + + # stream_async() yields events without invocation_state merged in; invocation_state + # is only passed to the callback_handler for backward compat. exp_events = [ - {"init_event_loop": True, "callback_handler": mock_callback}, + {"init_event_loop": True}, {"data": "First chunk"}, {"data": "Second chunk"}, {"complete": True, "data": "Final chunk"}, @@ -1096,8 +1099,24 @@ async def test_event_loop(*args, **kwargs): ] assert tru_events == exp_events - exp_calls = [unittest.mock.call(**event) for event in exp_events] - mock_callback.assert_has_calls(exp_calls) + # The callback_handler receives the fully-merged dict (including invocation_state). + exp_callback_calls = [ + unittest.mock.call(**{"init_event_loop": True, "callback_handler": mock_callback}), + unittest.mock.call(**{"data": "First chunk"}), + unittest.mock.call(**{"data": "Second chunk"}), + unittest.mock.call(**{"complete": True, "data": "Final chunk"}), + unittest.mock.call( + **{ + "result": AgentResult( + stop_reason="stop", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics={}, + state={}, + ) + } + ), + ] + mock_callback.assert_has_calls(exp_callback_calls) @pytest.mark.asyncio @@ -1196,7 +1215,7 @@ async def check_invocation_state(**kwargs): tru_events = await alist(stream) exp_events = [ - {"init_event_loop": True, "some_value": "a_value"}, + {"init_event_loop": True}, { "result": AgentResult( stop_reason="stop", @@ -1211,6 +1230,35 @@ async def check_invocation_state(**kwargs): assert mock_event_loop_cycle.call_count == 1 +@pytest.mark.asyncio +async def test_stream_async_does_not_yield_invocation_state(mock_event_loop_cycle, alist): + """stream_async() must not include invocation_state in yielded events. + + Non-serializable objects passed via invocation_state were previously merged + into every ModelStreamEvent by prepare(), causing repr() serialization of + ~131 KB Agent/Span objects on the wire (issue #1928). + """ + + class _NotSerializable: + pass + + not_serializable = _NotSerializable() + + async def test_event_loop(*args, **kwargs): + yield ModelStreamEvent({"data": "hello", "delta": {"text": "hello"}}) + yield EventLoopStopEvent("end_turn", {"role": "assistant", "content": []}, {}, {}) + + mock_event_loop_cycle.side_effect = test_event_loop + + agent = Agent() + events = await alist(agent.stream_async("hi", invocation_state={"obj": not_serializable})) + + stream_events = [e for e in events if "data" in e] + assert len(stream_events) == 1 + assert "obj" not in stream_events[0], "invocation_state must not appear in yielded stream events" + assert stream_events[0] == {"data": "hello", "delta": {"text": "hello"}} + + @pytest.mark.asyncio async def test_stream_async_raises_exceptions(mock_event_loop_cycle): mock_event_loop_cycle.side_effect = ValueError("Test exception")