Skip to content

Commit 5e39fa5

Browse files
committed
fix: filter temp-scoped state keys in SSE event generator
Temp state keys (prefix `temp:`) can contain non-serializable objects such as `FunctionTool` instances stored by `_call_llm_node`. `_trim_temp_delta_state()` already filters these for persistence in `append_event()`, but the SSE event generator serializes events before they reach that path, causing `PydanticSerializationError` when `streaming=true`. Apply the same `State.TEMP_PREFIX` filtering in the SSE generator before `model_dump_json()`. Fixes #5051
1 parent ccac461 commit 5e39fa5

2 files changed

Lines changed: 95 additions & 0 deletions

File tree

src/google/adk/cli/adk_web_server.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
from ..runners import Runner
9393
from ..sessions.base_session_service import BaseSessionService
9494
from ..sessions.session import Session
95+
from ..sessions.state import State
9596
from ..utils.agent_info import AgentInfo
9697
from ..utils.agent_info import get_agents_dict
9798
from ..utils.context_utils import Aclosing
@@ -1915,6 +1916,21 @@ async def event_generator():
19151916
events_to_stream = [content_event, artifact_event]
19161917

19171918
for event_to_stream in events_to_stream:
1919+
# Filter temp-scoped state keys before SSE serialization.
1920+
# Temp state (prefix "temp:") can contain non-serializable
1921+
# objects such as FunctionTool instances stored by
1922+
# _call_llm_node. _trim_temp_delta_state() handles this
1923+
# for persistence in append_event(), but SSE events are
1924+
# serialized before reaching that path.
1925+
if (
1926+
event_to_stream.actions
1927+
and event_to_stream.actions.state_delta
1928+
):
1929+
event_to_stream.actions.state_delta = {
1930+
k: v
1931+
for k, v in event_to_stream.actions.state_delta.items()
1932+
if not k.startswith(State.TEMP_PREFIX)
1933+
}
19181934
sse_event = event_to_stream.model_dump_json(
19191935
exclude_none=True,
19201936
by_alias=True,

tests/unittests/cli/test_fast_api.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,6 +1391,85 @@ async def run_async_raises(self, **kwargs):
13911391
assert sse_events == [{"error": "boom"}]
13921392

13931393

1394+
def test_agent_run_sse_filters_temp_state_keys(
1395+
test_app, create_test_session, monkeypatch
1396+
):
1397+
"""Test /run_sse strips temp-scoped state keys before serialization.
1398+
1399+
Temp state (e.g. ``temp:tools_dict``) can hold non-serializable objects
1400+
such as ``FunctionTool`` instances. ``_trim_temp_delta_state`` filters
1401+
these for persistence, but SSE events are serialized earlier. This test
1402+
verifies that the SSE generator applies the same filtering so that
1403+
``model_dump_json`` does not crash.
1404+
1405+
Regression test for https://github.com/google/adk-python/issues/5051
1406+
"""
1407+
info = create_test_session
1408+
1409+
# An object that is intentionally not JSON-serializable, mimicking
1410+
# the FunctionTool instances that _call_llm_node stores in temp state.
1411+
class _NotSerializable:
1412+
pass
1413+
1414+
async def run_async_with_temp_state(
1415+
self,
1416+
*,
1417+
user_id: str,
1418+
session_id: str,
1419+
invocation_id: Optional[str] = None,
1420+
new_message: Optional[types.Content] = None,
1421+
state_delta: Optional[dict[str, Any]] = None,
1422+
run_config: Optional[RunConfig] = None,
1423+
):
1424+
del user_id, session_id, invocation_id, new_message, state_delta, run_config
1425+
yield Event(
1426+
author="dummy agent",
1427+
invocation_id="invocation_id",
1428+
content=types.Content(
1429+
role="model", parts=[types.Part(text="hello")]
1430+
),
1431+
actions=EventActions(
1432+
state_delta={
1433+
"user_request": "hi",
1434+
"temp:tools_dict": {"greet": _NotSerializable()},
1435+
"temp:other": _NotSerializable(),
1436+
}
1437+
),
1438+
)
1439+
1440+
monkeypatch.setattr(Runner, "run_async", run_async_with_temp_state)
1441+
1442+
payload = {
1443+
"app_name": info["app_name"],
1444+
"user_id": info["user_id"],
1445+
"session_id": info["session_id"],
1446+
"new_message": {"role": "user", "parts": [{"text": "Hello agent"}]},
1447+
"streaming": True,
1448+
}
1449+
1450+
response = test_app.post("/run_sse", json=payload)
1451+
assert response.status_code == 200
1452+
1453+
sse_events = [
1454+
json.loads(line.removeprefix("data: "))
1455+
for line in response.text.splitlines()
1456+
if line.startswith("data: ")
1457+
]
1458+
1459+
assert len(sse_events) == 1
1460+
event_data = sse_events[0]
1461+
1462+
# Content should be intact.
1463+
assert event_data["content"]["parts"][0]["text"] == "hello"
1464+
1465+
# Non-temp state key should survive.
1466+
assert event_data["actions"]["stateDelta"]["user_request"] == "hi"
1467+
1468+
# Temp-scoped keys must be stripped.
1469+
assert "temp:tools_dict" not in event_data["actions"]["stateDelta"]
1470+
assert "temp:other" not in event_data["actions"]["stateDelta"]
1471+
1472+
13941473
def test_list_artifact_names(test_app, create_test_session):
13951474
"""Test listing artifact names for a session."""
13961475
info = create_test_session

0 commit comments

Comments
 (0)