|
30 | 30 | from google.adk.sessions.database_session_service import DatabaseSessionService |
31 | 31 | from google.adk.sessions.in_memory_session_service import InMemorySessionService |
32 | 32 | from google.adk.sessions.sqlite_session_service import SqliteSessionService |
| 33 | +from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService |
33 | 34 | from google.genai import types |
34 | 35 | import pytest |
35 | 36 | from sqlalchemy import delete |
@@ -1650,3 +1651,122 @@ async def tracking_fn(**kwargs): |
1650 | 1651 | finally: |
1651 | 1652 | database_session_service._select_required_state = original_fn |
1652 | 1653 | await service.close() |
| 1654 | + |
| 1655 | + |
| 1656 | +@pytest.mark.asyncio |
| 1657 | +async def test_get_user_state_returns_empty_dict_when_no_state_exists( |
| 1658 | + session_service, |
| 1659 | +): |
| 1660 | + state = await session_service.get_user_state(app_name='my_app', user_id='u1') |
| 1661 | + assert state == {} |
| 1662 | + |
| 1663 | + |
| 1664 | +@pytest.mark.asyncio |
| 1665 | +async def test_get_user_state_returns_state_written_via_append_event( |
| 1666 | + session_service, |
| 1667 | +): |
| 1668 | + session = await session_service.create_session( |
| 1669 | + app_name='my_app', user_id='u1' |
| 1670 | + ) |
| 1671 | + await session_service.append_event( |
| 1672 | + session, |
| 1673 | + Event( |
| 1674 | + author='system', |
| 1675 | + actions=EventActions( |
| 1676 | + state_delta={'user:profile': {'name': 'Alice'}, 'session_key': 1} |
| 1677 | + ), |
| 1678 | + ), |
| 1679 | + ) |
| 1680 | + |
| 1681 | + state = await session_service.get_user_state(app_name='my_app', user_id='u1') |
| 1682 | + |
| 1683 | + assert state == {'profile': {'name': 'Alice'}} |
| 1684 | + assert 'session_key' not in state |
| 1685 | + |
| 1686 | + |
| 1687 | +@pytest.mark.asyncio |
| 1688 | +async def test_get_user_state_is_not_visible_across_users(session_service): |
| 1689 | + session = await session_service.create_session( |
| 1690 | + app_name='my_app', user_id='u1' |
| 1691 | + ) |
| 1692 | + await session_service.append_event( |
| 1693 | + session, |
| 1694 | + Event( |
| 1695 | + author='system', |
| 1696 | + actions=EventActions(state_delta={'user:secret': 'only-for-u1'}), |
| 1697 | + ), |
| 1698 | + ) |
| 1699 | + |
| 1700 | + other_state = await session_service.get_user_state( |
| 1701 | + app_name='my_app', user_id='u2' |
| 1702 | + ) |
| 1703 | + assert other_state == {} |
| 1704 | + |
| 1705 | + |
| 1706 | +@pytest.mark.asyncio |
| 1707 | +async def test_get_user_state_is_not_visible_across_apps(session_service): |
| 1708 | + session = await session_service.create_session( |
| 1709 | + app_name='my_app', user_id='u1' |
| 1710 | + ) |
| 1711 | + await session_service.append_event( |
| 1712 | + session, |
| 1713 | + Event( |
| 1714 | + author='system', |
| 1715 | + actions=EventActions(state_delta={'user:data': 'only-app-a'}), |
| 1716 | + ), |
| 1717 | + ) |
| 1718 | + |
| 1719 | + other_state = await session_service.get_user_state( |
| 1720 | + app_name='other_app', user_id='u1' |
| 1721 | + ) |
| 1722 | + assert other_state == {} |
| 1723 | + |
| 1724 | + |
| 1725 | +@pytest.mark.asyncio |
| 1726 | +async def test_get_user_state_available_before_session_is_created( |
| 1727 | + session_service, |
| 1728 | +): |
| 1729 | + first_session = await session_service.create_session( |
| 1730 | + app_name='my_app', user_id='u1' |
| 1731 | + ) |
| 1732 | + await session_service.append_event( |
| 1733 | + first_session, |
| 1734 | + Event( |
| 1735 | + author='system', |
| 1736 | + actions=EventActions(state_delta={'user:ctx': {'v': 1}}), |
| 1737 | + ), |
| 1738 | + ) |
| 1739 | + |
| 1740 | + state = await session_service.get_user_state(app_name='my_app', user_id='u1') |
| 1741 | + assert state == {'ctx': {'v': 1}} |
| 1742 | + |
| 1743 | + |
| 1744 | +@pytest.mark.asyncio |
| 1745 | +async def test_get_user_state_reflects_latest_write(session_service): |
| 1746 | + session = await session_service.create_session( |
| 1747 | + app_name='my_app', user_id='u1' |
| 1748 | + ) |
| 1749 | + await session_service.append_event( |
| 1750 | + session, |
| 1751 | + Event( |
| 1752 | + author='system', |
| 1753 | + actions=EventActions(state_delta={'user:counter': 1}), |
| 1754 | + ), |
| 1755 | + ) |
| 1756 | + await session_service.append_event( |
| 1757 | + session, |
| 1758 | + Event( |
| 1759 | + author='system', |
| 1760 | + actions=EventActions(state_delta={'user:counter': 2}), |
| 1761 | + ), |
| 1762 | + ) |
| 1763 | + |
| 1764 | + state = await session_service.get_user_state(app_name='my_app', user_id='u1') |
| 1765 | + assert state['counter'] == 2 |
| 1766 | + |
| 1767 | + |
| 1768 | +@pytest.mark.asyncio |
| 1769 | +async def test_vertex_ai_session_service_raises_not_implemented_for_get_user_state(): |
| 1770 | + service = VertexAiSessionService(project='proj', location='us-central1') |
| 1771 | + with pytest.raises(NotImplementedError): |
| 1772 | + await service.get_user_state(app_name='my_app', user_id='u1') |
0 commit comments