Skip to content

Commit a39932a

Browse files
committed
feat(sessions): add get_user_state(app_name, user_id) to BaseSessionService
1 parent 03d6208 commit a39932a

6 files changed

Lines changed: 266 additions & 1 deletion

File tree

src/google/adk/sessions/base_session_service.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,46 @@ async def delete_session(
111111
) -> None:
112112
"""Deletes a session."""
113113

114+
async def get_user_state(
115+
self, *, app_name: str, user_id: str
116+
) -> dict[str, Any]:
117+
"""Returns the user-scoped state for the given app and user.
118+
119+
User state is keyed by ``(app_name, user_id)`` and shared across all
120+
sessions of the same user within the same app. The returned dictionary
121+
uses raw keys **without** the ``user:`` prefix (e.g. ``"my_key"`` rather
122+
than ``"user:my_key"``).
123+
124+
This method exists so that callers can read user state without holding an
125+
active ``session_id``. A common use case is bootstrapping context at the
126+
start of a new session before calling ``create_session``, which would
127+
otherwise require an expensive ``list_sessions`` call just to access
128+
user-scoped data.
129+
130+
Returns an empty dict when no user state has been stored for this
131+
``(app_name, user_id)`` combination.
132+
133+
Args:
134+
app_name: The name of the app.
135+
user_id: The ID of the user.
136+
137+
Returns:
138+
A dictionary of raw (un-prefixed) user-scoped key/value pairs, or an
139+
empty dict when no user state exists.
140+
141+
Raises:
142+
NotImplementedError: When the concrete ``BaseSessionService``
143+
implementation does not support reading user state independently of a
144+
session. Callers should catch this, then enumerate sessions via
145+
``list_sessions`` and call ``get_session`` on each result to access
146+
the merged state, or accept that user state is unavailable.
147+
"""
148+
raise NotImplementedError(
149+
f'{type(self).__name__} does not support get_user_state. '
150+
'To read user state, enumerate sessions via list_sessions and '
151+
'call get_session on each result to access the merged state.'
152+
)
153+
114154
async def append_event(self, session: Session, event: Event) -> Event:
115155
"""Appends an event to a session object."""
116156
if event.partial:

src/google/adk/sessions/database_session_service.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ async def _with_session_lock(
318318
else:
319319
self._session_lock_ref_count[lock_key] = remaining
320320

321-
async def _prepare_tables(self):
321+
async def _prepare_tables(self) -> None:
322322
"""Ensure database tables are ready for use.
323323
324324
This method is called lazily before each database operation. It checks the
@@ -616,6 +616,22 @@ async def delete_session(
616616
await sql_session.execute(stmt)
617617
await sql_session.commit()
618618

619+
@override
620+
async def get_user_state(
621+
self, *, app_name: str, user_id: str
622+
) -> dict[str, Any]:
623+
await self._prepare_tables()
624+
schema = self._get_schema_classes()
625+
async with self._rollback_on_exception_session(
626+
read_only=True
627+
) as sql_session:
628+
storage_user_state = await sql_session.get(
629+
schema.StorageUserState, (app_name, user_id)
630+
)
631+
if storage_user_state is None:
632+
return {}
633+
return dict(storage_user_state.state or {})
634+
619635
@override
620636
async def append_event(self, session: Session, event: Event) -> Event:
621637
await self._prepare_tables()

src/google/adk/sessions/in_memory_session_service.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,12 @@ def _delete_session_impl(
312312

313313
self.sessions[app_name][user_id].pop(session_id)
314314

315+
@override
316+
async def get_user_state(
317+
self, *, app_name: str, user_id: str
318+
) -> dict[str, Any]:
319+
return dict(self.user_state.get(app_name, {}).get(user_id, {}))
320+
315321
@override
316322
async def append_event(self, session: Session, event: Event) -> Event:
317323
if event.partial:

src/google/adk/sessions/sqlite_session_service.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,13 @@ async def delete_session(
359359
)
360360
await db.commit()
361361

362+
@override
363+
async def get_user_state(
364+
self, *, app_name: str, user_id: str
365+
) -> dict[str, Any]:
366+
async with self._get_db_connection() as db:
367+
return await self._get_user_state(db, app_name, user_id)
368+
362369
@override
363370
async def append_event(self, session: Session, event: Event) -> Event:
364371
if event.partial:

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,27 @@ async def delete_session(
269269
logger.error('Error deleting session %s: %s', session_id, e)
270270
raise
271271

272+
@override
273+
async def get_user_state(
274+
self, *, app_name: str, user_id: str
275+
) -> dict[str, Any]:
276+
"""Not supported by the Vertex AI Agent Engine backend.
277+
278+
The Vertex AI Agent Engine API does not expose user state independently of
279+
a session. To read user state, enumerate sessions via ``list_sessions``
280+
and call ``get_session`` on each result to access the merged state.
281+
282+
Raises:
283+
NotImplementedError: Always, because the Vertex AI Agent Engine API does
284+
not provide a way to query user state without a session.
285+
"""
286+
raise NotImplementedError(
287+
'VertexAiSessionService does not support get_user_state. '
288+
'The Vertex AI Agent Engine API does not expose user state '
289+
'independently of a session. To read user state, enumerate sessions '
290+
'via list_sessions and call get_session on each result.'
291+
)
292+
272293
@override
273294
async def append_event(self, session: Session, event: Event) -> Event:
274295
# Update the in-memory session.
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for BaseSessionService.get_user_state across concrete implementations."""
16+
17+
import enum
18+
19+
from google.adk.events.event import Event
20+
from google.adk.events.event_actions import EventActions
21+
from google.adk.sessions.base_session_service import BaseSessionService
22+
from google.adk.sessions.database_session_service import DatabaseSessionService
23+
from google.adk.sessions.in_memory_session_service import InMemorySessionService
24+
from google.adk.sessions.sqlite_session_service import SqliteSessionService
25+
from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService
26+
import pytest
27+
28+
_APP = 'test-app'
29+
_OTHER_APP = 'other-app'
30+
_USER = 'user-42'
31+
_OTHER_USER = 'user-99'
32+
33+
34+
class SessionServiceType(enum.Enum):
35+
IN_MEMORY = 'IN_MEMORY'
36+
DATABASE = 'DATABASE'
37+
SQLITE = 'SQLITE'
38+
39+
40+
def _make_service(
41+
service_type: SessionServiceType, tmp_path=None
42+
) -> BaseSessionService:
43+
if service_type == SessionServiceType.DATABASE:
44+
return DatabaseSessionService('sqlite+aiosqlite:///:memory:')
45+
if service_type == SessionServiceType.SQLITE:
46+
return SqliteSessionService(str(tmp_path / 'sqlite.db'))
47+
return InMemorySessionService()
48+
49+
50+
@pytest.fixture(
51+
params=[
52+
SessionServiceType.IN_MEMORY,
53+
SessionServiceType.DATABASE,
54+
SessionServiceType.SQLITE,
55+
]
56+
)
57+
async def session_service(request, tmp_path):
58+
"""Provides a session service and closes database backends on teardown."""
59+
service = _make_service(request.param, tmp_path)
60+
yield service
61+
if isinstance(service, DatabaseSessionService):
62+
await service.close()
63+
64+
65+
async def test_get_user_state_returns_empty_dict_when_no_state_exists(
66+
session_service,
67+
):
68+
"""Returns {} when (app_name, user_id) has never had state written."""
69+
state = await session_service.get_user_state(app_name=_APP, user_id=_USER)
70+
assert state == {}
71+
72+
73+
async def test_get_user_state_returns_state_written_via_append_event(
74+
session_service,
75+
):
76+
"""State written with the user: prefix is returned without the prefix."""
77+
session = await session_service.create_session(app_name=_APP, user_id=_USER)
78+
await session_service.append_event(
79+
session,
80+
Event(
81+
author='system',
82+
actions=EventActions(
83+
state_delta={'user:profile': {'name': 'Alice'}, 'session_key': 1}
84+
),
85+
),
86+
)
87+
88+
state = await session_service.get_user_state(app_name=_APP, user_id=_USER)
89+
90+
assert state == {'profile': {'name': 'Alice'}}
91+
assert 'session_key' not in state
92+
93+
94+
async def test_get_user_state_is_not_visible_across_users(session_service):
95+
"""User state is scoped to (app_name, user_id) — other users see {}."""
96+
session = await session_service.create_session(app_name=_APP, user_id=_USER)
97+
await session_service.append_event(
98+
session,
99+
Event(
100+
author='system',
101+
actions=EventActions(state_delta={'user:secret': 'only-for-user-42'}),
102+
),
103+
)
104+
105+
other_state = await session_service.get_user_state(
106+
app_name=_APP, user_id=_OTHER_USER
107+
)
108+
assert other_state == {}
109+
110+
111+
async def test_get_user_state_is_not_visible_across_apps(session_service):
112+
"""User state is scoped to (app_name, user_id) — other apps see {}."""
113+
session = await session_service.create_session(app_name=_APP, user_id=_USER)
114+
await session_service.append_event(
115+
session,
116+
Event(
117+
author='system',
118+
actions=EventActions(state_delta={'user:data': 'only-app-a'}),
119+
),
120+
)
121+
122+
other_state = await session_service.get_user_state(
123+
app_name=_OTHER_APP, user_id=_USER
124+
)
125+
assert other_state == {}
126+
127+
128+
async def test_get_user_state_available_before_session_is_created(
129+
session_service,
130+
):
131+
"""Core use case: user state is readable without an active session_id."""
132+
first_session = await session_service.create_session(
133+
app_name=_APP, user_id=_USER
134+
)
135+
await session_service.append_event(
136+
first_session,
137+
Event(
138+
author='system',
139+
actions=EventActions(state_delta={'user:ctx': {'v': 1}}),
140+
),
141+
)
142+
143+
# Simulate a brand-new session_id (not yet created) — get_user_state must
144+
# still return the persisted user state.
145+
state = await session_service.get_user_state(app_name=_APP, user_id=_USER)
146+
assert state == {'ctx': {'v': 1}}
147+
148+
149+
async def test_get_user_state_reflects_latest_write(session_service):
150+
"""Subsequent writes overwrite earlier values under the same key."""
151+
session = await session_service.create_session(app_name=_APP, user_id=_USER)
152+
await session_service.append_event(
153+
session,
154+
Event(
155+
author='system',
156+
actions=EventActions(state_delta={'user:counter': 1}),
157+
),
158+
)
159+
await session_service.append_event(
160+
session,
161+
Event(
162+
author='system',
163+
actions=EventActions(state_delta={'user:counter': 2}),
164+
),
165+
)
166+
167+
state = await session_service.get_user_state(app_name=_APP, user_id=_USER)
168+
assert state['counter'] == 2
169+
170+
171+
async def test_vertex_ai_session_service_raises_not_implemented():
172+
"""VertexAiSessionService raises NotImplementedError for get_user_state."""
173+
service = VertexAiSessionService(project='proj', location='us-central1')
174+
with pytest.raises(NotImplementedError):
175+
await service.get_user_state(app_name=_APP, user_id=_USER)

0 commit comments

Comments
 (0)