Skip to content

Commit 5a56f91

Browse files
author
Nishar
committed
fix: make temp: state keys accessible in session.state during invocation
Fixes #3047 temp: prefixed state keys (e.g. set via output_key='temp:...') were being dropped from session.state before lifecycle callbacks could read them. Root cause: BaseSessionService.__update_session_state() explicitly skipped keys starting with 'temp:', preventing them from ever reaching session.state. Changes: - base_session_service.py: Remove the temp: skip in __update_session_state() so all keys flow into the in-memory session.state. - in_memory_session_service.py: Strip temp: keys from event.actions.state_delta after updating the caller's session but before writing to the storage session, preventing temp keys from leaking into persisted state. - test_session_service.py: Update existing tests to reflect that temp: keys are now accessible in-memory, and add a dedicated regression test. Temp keys remain non-persistent: - DatabaseSessionService uses _extract_state_delta() which already filters temp: keys before writing to the database. - InMemorySessionService now strips temp: keys before updating shared state stores and the storage session. - Runner creates a deepcopy of session at invocation start, preventing temp keys from leaking across invocations.
1 parent c224626 commit 5a56f91

3 files changed

Lines changed: 56 additions & 8 deletions

File tree

src/google/adk/sessions/base_session_service.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,4 @@ def __update_session_state(self, session: Session, event: Event) -> None:
103103
"""Updates the session state based on the event."""
104104
if not event.actions or not event.actions.state_delta:
105105
return
106-
for key, value in event.actions.state_delta.items():
107-
if key.startswith(State.TEMP_PREFIX):
108-
continue
109-
session.state.update({key: value})
106+
session.state.update(event.actions.state_delta)

src/google/adk/sessions/in_memory_session_service.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,14 @@ async def append_event(self, session: Session, event: Event) -> Event:
259259
await super().append_event(session=session, event=event)
260260
session.last_update_time = event.timestamp
261261

262+
# Strip temp: keys before persisting to storage.
263+
if event.actions and event.actions.state_delta:
264+
event.actions.state_delta = {
265+
k: v
266+
for k, v in event.actions.state_delta.items()
267+
if not k.startswith(State.TEMP_PREFIX)
268+
}
269+
262270
# Update the storage session
263271
app_name = session.app_name
264272
user_id = session.user_id

tests/unittests/sessions/test_session_service.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,11 @@ async def test_session_state(service_type):
156156
)
157157
await session_service.append_event(session=session_11, event=event)
158158

159-
# User and app state is stored, temp state is filtered.
159+
# User and app state is stored, temp state is accessible in-memory.
160160
assert session_11.state.get('app:key') == 'value'
161161
assert session_11.state.get('key11') == 'value11_new'
162162
assert session_11.state.get('user:key1') == 'value1'
163-
assert not session_11.state.get('temp:key')
163+
assert session_11.state.get('temp:key') == 'temp'
164164

165165
session_12 = await session_service.get_session(
166166
app_name=app_name, user_id=user_id_1, session_id=session_id_12
@@ -218,11 +218,11 @@ async def test_create_new_session_will_merge_states(service_type):
218218
)
219219
await session_service.append_event(session=session_1, event=event)
220220

221-
# User and app state is stored, temp state is filtered.
221+
# User and app state is stored, temp state is accessible in-memory.
222222
assert session_1.state.get('app:key') == 'value'
223223
assert session_1.state.get('key1') == 'value1'
224224
assert session_1.state.get('user:key1') == 'value1'
225-
assert not session_1.state.get('temp:key')
225+
assert session_1.state.get('temp:key') == 'temp'
226226

227227
session_2 = await session_service.create_session(
228228
app_name=app_name, user_id=user_id, state={}, session_id=session_id_2
@@ -377,3 +377,46 @@ async def test_get_session_with_config(service_type):
377377
)
378378
events = session.events
379379
assert len(events) == num_test_events - after_timestamp + 1
380+
381+
382+
@pytest.mark.asyncio
383+
@pytest.mark.parametrize(
384+
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
385+
)
386+
async def test_temp_state_accessible_in_session_during_invocation(service_type):
387+
session_service = get_session_service(service_type)
388+
app_name = 'my_app'
389+
user_id = 'test_user'
390+
391+
session = await session_service.create_session(
392+
app_name=app_name, user_id=user_id
393+
)
394+
395+
event = Event(
396+
invocation_id='invocation_1',
397+
author='test_agent',
398+
content=types.Content(
399+
role='model', parts=[types.Part(text='Hello from agent')]
400+
),
401+
actions=EventActions(
402+
state_delta={
403+
'temp:agent_output': 'Hello from agent',
404+
'temp:oauth_token': 'bearer_abc123',
405+
'persistent_key': 'should_persist',
406+
}
407+
),
408+
)
409+
await session_service.append_event(session=session, event=event)
410+
411+
# temp: keys are accessible in-memory during the same invocation.
412+
assert session.state.get('temp:agent_output') == 'Hello from agent'
413+
assert session.state.get('temp:oauth_token') == 'bearer_abc123'
414+
assert session.state.get('persistent_key') == 'should_persist'
415+
416+
# temp: keys are not persisted to storage.
417+
refetched = await session_service.get_session(
418+
app_name=app_name, user_id=user_id, session_id=session.id
419+
)
420+
assert not refetched.state.get('temp:agent_output')
421+
assert not refetched.state.get('temp:oauth_token')
422+
assert refetched.state.get('persistent_key') == 'should_persist'

0 commit comments

Comments
 (0)