Skip to content

Commit 44f5422

Browse files
nicolasmotaDeanChensj
authored andcommitted
feat(sessions): implement get_user_state method in PerAgentDatabaseSessionService and add corresponding unit tests
- Added `get_user_state(app_name, user_id)` method to `PerAgentDatabaseSessionService` for retrieving user state. - Created unit tests to validate the functionality of `get_user_state`, ensuring it returns the correct state for existing users and an empty dictionary for non-existent users. - Removed obsolete `test_get_user_state.py` file as its functionality is now covered by the new tests.
1 parent 44997e6 commit 44f5422

4 files changed

Lines changed: 157 additions & 175 deletions

File tree

src/google/adk/cli/utils/local_storage.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import asyncio
1919
import logging
2020
from pathlib import Path
21+
from typing import Any
2122
from typing import Mapping
2223
from typing import Optional
2324

@@ -205,6 +206,13 @@ async def delete_session(
205206
app_name=app_name, user_id=user_id, session_id=session_id
206207
)
207208

209+
@override
210+
async def get_user_state(
211+
self, *, app_name: str, user_id: str
212+
) -> dict[str, Any]:
213+
service = await self._get_service(app_name)
214+
return await service.get_user_state(app_name=app_name, user_id=user_id)
215+
208216
@override
209217
async def append_event(self, session: Session, event: Event) -> Event:
210218
service = await self._get_service(session.app_name)

tests/unittests/cli/utils/test_local_storage.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from google.adk.cli.utils.local_storage import create_local_database_session_service
2020
from google.adk.cli.utils.local_storage import create_local_session_service
2121
from google.adk.cli.utils.local_storage import PerAgentDatabaseSessionService
22+
from google.adk.events.event import Event
23+
from google.adk.events.event_actions import EventActions
2224
from google.adk.sessions.sqlite_session_service import SqliteSessionService
2325
import pytest
2426

@@ -90,3 +92,28 @@ def test_create_local_database_session_service_returns_sqlite(
9092
service = create_local_database_session_service(base_dir=tmp_path)
9193

9294
assert isinstance(service, SqliteSessionService)
95+
96+
97+
@pytest.mark.asyncio
98+
async def test_per_agent_session_service_get_user_state(tmp_path: Path) -> None:
99+
agent_a = tmp_path / 'agent_a'
100+
agent_b = tmp_path / 'agent_b'
101+
agent_a.mkdir()
102+
agent_b.mkdir()
103+
104+
service = PerAgentDatabaseSessionService(agents_root=tmp_path)
105+
106+
session_a = await service.create_session(app_name='agent_a', user_id='user_a')
107+
await service.append_event(
108+
session_a,
109+
Event(
110+
author='system',
111+
actions=EventActions(state_delta={'user:profile': {'name': 'Alice'}}),
112+
),
113+
)
114+
115+
state_a = await service.get_user_state(app_name='agent_a', user_id='user_a')
116+
state_b = await service.get_user_state(app_name='agent_b', user_id='user_b')
117+
118+
assert state_a == {'profile': {'name': 'Alice'}}
119+
assert state_b == {}

tests/unittests/sessions/test_get_user_state.py

Lines changed: 0 additions & 175 deletions
This file was deleted.

tests/unittests/sessions/test_session_service.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from google.adk.sessions.database_session_service import DatabaseSessionService
3131
from google.adk.sessions.in_memory_session_service import InMemorySessionService
3232
from google.adk.sessions.sqlite_session_service import SqliteSessionService
33+
from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService
3334
from google.genai import types
3435
import pytest
3536
from sqlalchemy import delete
@@ -1650,3 +1651,124 @@ async def tracking_fn(**kwargs):
16501651
finally:
16511652
database_session_service._select_required_state = original_fn
16521653
await service.close()
1654+
1655+
1656+
@pytest.mark.asyncio
1657+
async def test_get_user_state_returns_empty_dict_when_no_state_exists(
1658+
session_service,
1659+
):
1660+
state = await session_service.get_user_state(
1661+
app_name='my_app', user_id='u1'
1662+
)
1663+
assert state == {}
1664+
1665+
1666+
@pytest.mark.asyncio
1667+
async def test_get_user_state_returns_state_written_via_append_event(
1668+
session_service,
1669+
):
1670+
session = await session_service.create_session(
1671+
app_name='my_app', user_id='u1'
1672+
)
1673+
await session_service.append_event(
1674+
session,
1675+
Event(
1676+
author='system',
1677+
actions=EventActions(
1678+
state_delta={'user:profile': {'name': 'Alice'}, 'session_key': 1}
1679+
),
1680+
),
1681+
)
1682+
1683+
state = await session_service.get_user_state(app_name='my_app', user_id='u1')
1684+
1685+
assert state == {'profile': {'name': 'Alice'}}
1686+
assert 'session_key' not in state
1687+
1688+
1689+
@pytest.mark.asyncio
1690+
async def test_get_user_state_is_not_visible_across_users(session_service):
1691+
session = await session_service.create_session(
1692+
app_name='my_app', user_id='u1'
1693+
)
1694+
await session_service.append_event(
1695+
session,
1696+
Event(
1697+
author='system',
1698+
actions=EventActions(state_delta={'user:secret': 'only-for-u1'}),
1699+
),
1700+
)
1701+
1702+
other_state = await session_service.get_user_state(
1703+
app_name='my_app', user_id='u2'
1704+
)
1705+
assert other_state == {}
1706+
1707+
1708+
@pytest.mark.asyncio
1709+
async def test_get_user_state_is_not_visible_across_apps(session_service):
1710+
session = await session_service.create_session(
1711+
app_name='my_app', user_id='u1'
1712+
)
1713+
await session_service.append_event(
1714+
session,
1715+
Event(
1716+
author='system',
1717+
actions=EventActions(state_delta={'user:data': 'only-app-a'}),
1718+
),
1719+
)
1720+
1721+
other_state = await session_service.get_user_state(
1722+
app_name='other_app', user_id='u1'
1723+
)
1724+
assert other_state == {}
1725+
1726+
1727+
@pytest.mark.asyncio
1728+
async def test_get_user_state_available_before_session_is_created(
1729+
session_service,
1730+
):
1731+
first_session = await session_service.create_session(
1732+
app_name='my_app', user_id='u1'
1733+
)
1734+
await session_service.append_event(
1735+
first_session,
1736+
Event(
1737+
author='system',
1738+
actions=EventActions(state_delta={'user:ctx': {'v': 1}}),
1739+
),
1740+
)
1741+
1742+
state = await session_service.get_user_state(app_name='my_app', user_id='u1')
1743+
assert state == {'ctx': {'v': 1}}
1744+
1745+
1746+
@pytest.mark.asyncio
1747+
async def test_get_user_state_reflects_latest_write(session_service):
1748+
session = await session_service.create_session(
1749+
app_name='my_app', user_id='u1'
1750+
)
1751+
await session_service.append_event(
1752+
session,
1753+
Event(
1754+
author='system',
1755+
actions=EventActions(state_delta={'user:counter': 1}),
1756+
),
1757+
)
1758+
await session_service.append_event(
1759+
session,
1760+
Event(
1761+
author='system',
1762+
actions=EventActions(state_delta={'user:counter': 2}),
1763+
),
1764+
)
1765+
1766+
state = await session_service.get_user_state(app_name='my_app', user_id='u1')
1767+
assert state['counter'] == 2
1768+
1769+
1770+
@pytest.mark.asyncio
1771+
async def test_vertex_ai_session_service_raises_not_implemented_for_get_user_state():
1772+
service = VertexAiSessionService(project='proj', location='us-central1')
1773+
with pytest.raises(NotImplementedError):
1774+
await service.get_user_state(app_name='my_app', user_id='u1')

0 commit comments

Comments
 (0)