Skip to content

Commit 9252a85

Browse files
committed
fix(redis): filter temp state from persistence and preserve in-memory
1 parent 090b407 commit 9252a85

2 files changed

Lines changed: 21 additions & 2 deletions

File tree

src/google/adk_community/sessions/redis_session_service.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,13 @@
4343

4444
def _session_serializer(obj: Session) -> bytes:
4545
"""Serialize ADK Session to JSON bytes."""
46-
return orjson.dumps(obj.model_dump(), default=_json_serializer)
46+
data = obj.model_dump()
47+
if "state" in data:
48+
data["state"] = {
49+
k: v for k, v in data["state"].items()
50+
if not k.startswith(State.TEMP_PREFIX)
51+
}
52+
return orjson.dumps(data, default=_json_serializer)
4753

4854

4955
class RedisKeys:

tests/unittests/sessions/test_redis_session_service.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ async def test_session_state_management(self, redis_service):
255255
assert session.state.get("app:key") == "app_value"
256256
assert session.state.get("user:key1") == "user_value"
257257
assert session.state.get("initial_key") == "updated_value"
258-
assert session.state.get("temp:key") is None # Temp state filtered
258+
assert session.state.get("temp:key") == "temp_value" # Ephemeral state kept in memory
259259

260260
pipeline_mock = redis_service.cache.pipeline.return_value
261261
pipe_mock = await pipeline_mock.__aenter__()
@@ -264,6 +264,19 @@ async def test_session_state_management(self, redis_service):
264264
"user:test_app:test_user", "key1", orjson.dumps("user_value")
265265
)
266266

267+
# Verify temp state was filtered out from the serialized session saved to Redis
268+
set_calls = pipe_mock.set.call_args_list
269+
session_bytes = None
270+
for call in set_calls:
271+
args = call[0]
272+
if args[0].startswith("session:"):
273+
session_bytes = args[1]
274+
break
275+
276+
assert session_bytes is not None
277+
stored_session = orjson.loads(session_bytes)
278+
assert "temp:key" not in stored_session.get("state", {})
279+
267280
@pytest.mark.asyncio
268281
async def test_append_event_with_bytes(self, redis_service):
269282
"""Test appending events with binary content and serialization roundtrip."""

0 commit comments

Comments
 (0)