diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 3a23133de..585d9a7d3 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -828,12 +828,16 @@ async def stream_async( events = self._run_loop(messages, merged_state, structured_output_model, structured_output_prompt) async for event in events: + # Snapshot the event data before prepare() merges invocation_state + # into the dict. The callback_handler receives the full merged dict + # for backward compatibility, but stream_async() callers only see + # the serializable event fields. + event_data = event.as_dict() event.prepare(invocation_state=merged_state) if event.is_callback_event: - as_dict = event.as_dict() - callback_handler(**as_dict) - yield as_dict + callback_handler(**event.as_dict()) + yield event_data result = AgentResult(*event["stop"]) callback_handler(result=result) diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 02c367ccc..9ab34dea6 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -46,6 +46,24 @@ def mock_sleep(): "request_state": {}, } +# Keys that prepare() merges from invocation_state. stream_async() no longer includes +# these in yielded events; callback_handler still receives them for backward compat. +_INVOCATION_STATE_KEYS = frozenset(any_props.keys()) | frozenset( + { + "event_loop_parent_cycle_id", + "messages", + "model", + "system_prompt", + "tool_config", + } +) + + +def _strip_state(events: list[dict], *user_keys: str) -> list[dict]: + """Return events with invocation_state fields removed (matches what stream_async() yields).""" + keys_to_remove = _INVOCATION_STATE_KEYS | set(user_keys) + return [{k: v for k, v in e.items() if k not in keys_to_remove} for e in events] + @pytest.mark.asyncio async def test_stream_e2e_success(alist): @@ -317,7 +335,10 @@ async def test_stream_e2e_success(alist): ), }, ] - assert tru_events == exp_events + # stream_async() yields events without invocation_state; callback_handler receives + # the full merged dict. Verify both independently. + exp_yield_events = _strip_state(exp_events, "arg1") + assert tru_events == exp_yield_events exp_calls = [call(**event) for event in exp_events] act_calls = mock_callback.call_args_list @@ -381,7 +402,8 @@ async def test_stream_e2e_throttle_and_redact(alist, mock_sleep): ), }, ] - assert tru_events == exp_events + exp_yield_events = _strip_state(exp_events, "arg1") + assert tru_events == exp_yield_events exp_calls = [call(**event) for event in exp_events] act_calls = mock_callback.call_args_list @@ -459,7 +481,8 @@ async def test_stream_e2e_reasoning_redacted_content(alist): ), }, ] - assert tru_events == exp_events + exp_yield_events = _strip_state(exp_events) + assert tru_events == exp_yield_events exp_calls = [call(**event) for event in exp_events] act_calls = mock_callback.call_args_list @@ -514,7 +537,8 @@ async def test_event_loop_cycle_text_response_throttling_early_end( {"force_stop": True, "force_stop_reason": "ThrottlingException | ConverseStream"}, ] - assert tru_events == exp_events + exp_yield_events = _strip_state(exp_events, "arg1") + assert tru_events == exp_yield_events exp_calls = [call(**event) for event in exp_events] act_calls = mock_callback.call_args_list diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 5a3cce11c..91d821c4f 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1080,8 +1080,11 @@ async def test_event_loop(*args, **kwargs): stream = agent.stream_async("test message", callback_handler=mock_callback) tru_events = await alist(stream) + + # stream_async() yields events without invocation_state merged in; invocation_state + # is only passed to the callback_handler for backward compat. exp_events = [ - {"init_event_loop": True, "callback_handler": mock_callback}, + {"init_event_loop": True}, {"data": "First chunk"}, {"data": "Second chunk"}, {"complete": True, "data": "Final chunk"}, @@ -1096,8 +1099,24 @@ async def test_event_loop(*args, **kwargs): ] assert tru_events == exp_events - exp_calls = [unittest.mock.call(**event) for event in exp_events] - mock_callback.assert_has_calls(exp_calls) + # The callback_handler receives the fully-merged dict (including invocation_state). + exp_callback_calls = [ + unittest.mock.call(**{"init_event_loop": True, "callback_handler": mock_callback}), + unittest.mock.call(**{"data": "First chunk"}), + unittest.mock.call(**{"data": "Second chunk"}), + unittest.mock.call(**{"complete": True, "data": "Final chunk"}), + unittest.mock.call( + **{ + "result": AgentResult( + stop_reason="stop", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics={}, + state={}, + ) + } + ), + ] + mock_callback.assert_has_calls(exp_callback_calls) @pytest.mark.asyncio @@ -1196,7 +1215,7 @@ async def check_invocation_state(**kwargs): tru_events = await alist(stream) exp_events = [ - {"init_event_loop": True, "some_value": "a_value"}, + {"init_event_loop": True}, { "result": AgentResult( stop_reason="stop", @@ -1211,6 +1230,35 @@ async def check_invocation_state(**kwargs): assert mock_event_loop_cycle.call_count == 1 +@pytest.mark.asyncio +async def test_stream_async_does_not_yield_invocation_state(mock_event_loop_cycle, alist): + """stream_async() must not include invocation_state in yielded events. + + Non-serializable objects passed via invocation_state were previously merged + into every ModelStreamEvent by prepare(), causing repr() serialization of + ~131 KB Agent/Span objects on the wire (issue #1928). + """ + + class _NotSerializable: + pass + + not_serializable = _NotSerializable() + + async def test_event_loop(*args, **kwargs): + yield ModelStreamEvent({"data": "hello", "delta": {"text": "hello"}}) + yield EventLoopStopEvent("end_turn", {"role": "assistant", "content": []}, {}, {}) + + mock_event_loop_cycle.side_effect = test_event_loop + + agent = Agent() + events = await alist(agent.stream_async("hi", invocation_state={"obj": not_serializable})) + + stream_events = [e for e in events if "data" in e] + assert len(stream_events) == 1 + assert "obj" not in stream_events[0], "invocation_state must not appear in yielded stream events" + assert stream_events[0] == {"data": "hello", "delta": {"text": "hello"}} + + @pytest.mark.asyncio async def test_stream_async_raises_exceptions(mock_event_loop_cycle): mock_event_loop_cycle.side_effect = ValueError("Test exception")