Skip to content

Commit 364a96d

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Add custom session id functionality to vertex ai session service
PiperOrigin-RevId: 891919371
1 parent 6ee0362 commit 364a96d

2 files changed

Lines changed: 25 additions & 16 deletions

File tree

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,16 +116,11 @@ async def create_session(
116116
Returns:
117117
The created session.
118118
"""
119-
120-
if session_id:
121-
raise ValueError(
122-
'User-provided Session id is not supported for'
123-
' VertexAISessionService.'
124-
)
125-
126119
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
127120

128121
config = {'session_state': state} if state else {}
122+
if session_id:
123+
config['session_id'] = session_id
129124
config.update(kwargs)
130125
async with self._get_api_client() as api_client:
131126
api_response = await api_client.agent_engines.sessions.create(

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,10 @@ async def _create_session(
417417
self, name: str, user_id: str, config: dict[str, Any]
418418
):
419419
self.last_create_session_config = config
420-
new_session_id = '4'
420+
if 'session_id' in config:
421+
new_session_id = config['session_id']
422+
else:
423+
new_session_id = '4'
421424
self.session_dict[new_session_id] = {
422425
'name': (
423426
'projects/test-project/locations/test-location/'
@@ -436,7 +439,7 @@ async def _create_session(
436439
+ '/operations/111'
437440
),
438441
'done': True,
439-
'response': self.session_dict['4'],
442+
'response': self.session_dict[new_session_id],
440443
})
441444

442445
async def _list_events(self, name: str, **kwargs):
@@ -880,15 +883,26 @@ async def test_create_session():
880883

881884
@pytest.mark.asyncio
882885
@pytest.mark.usefixtures('mock_get_api_client')
883-
async def test_create_session_with_custom_session_id():
886+
@pytest.mark.parametrize('session_id', ['1', 'abc123'])
887+
async def test_create_session_with_custom_session_id(
888+
mock_api_client_instance: MockAsyncClient, session_id: str
889+
):
884890
session_service = mock_vertex_ai_session_service()
885891

886-
with pytest.raises(ValueError) as excinfo:
887-
await session_service.create_session(
888-
app_name='123', user_id='user', session_id='1'
889-
)
890-
assert str(excinfo.value) == (
891-
'User-provided Session id is not supported for VertexAISessionService.'
892+
mock_api_client_instance.event_dict[session_id] = (
893+
[],
894+
None,
895+
)
896+
897+
session = await session_service.create_session(
898+
app_name='123', user_id='user', session_id=session_id
899+
)
900+
assert session.id == session_id
901+
assert session.app_name == '123'
902+
assert session.user_id == 'user'
903+
assert session.last_update_time is not None
904+
assert session == await session_service.get_session(
905+
app_name='123', user_id='user', session_id=session_id
892906
)
893907

894908

0 commit comments

Comments
 (0)