Skip to content

Commit dcde8ca

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Add custom session id functionality to vertex ai session service
PiperOrigin-RevId: 891919371
1 parent 8850916 commit dcde8ca

File tree

2 files changed

+25
-16
lines changed

2 files changed

+25
-16
lines changed

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,16 +115,11 @@ async def create_session(
115115
Returns:
116116
The created session.
117117
"""
118-
119-
if session_id:
120-
raise ValueError(
121-
'User-provided Session id is not supported for'
122-
' VertexAISessionService.'
123-
)
124-
125118
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
126119

127120
config = {'session_state': state} if state else {}
121+
if session_id:
122+
config['session_id'] = session_id
128123
config.update(kwargs)
129124
async with self._get_api_client() as api_client:
130125
api_response = await api_client.agent_engines.sessions.create(

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,10 @@ async def _create_session(
324324
self, name: str, user_id: str, config: dict[str, Any]
325325
):
326326
self.last_create_session_config = config
327-
new_session_id = '4'
327+
if 'session_id' in config:
328+
new_session_id = config['session_id']
329+
else:
330+
new_session_id = '4'
328331
self.session_dict[new_session_id] = {
329332
'name': (
330333
'projects/test-project/locations/test-location/'
@@ -343,7 +346,7 @@ async def _create_session(
343346
+ '/operations/111'
344347
),
345348
'done': True,
346-
'response': self.session_dict['4'],
349+
'response': self.session_dict[new_session_id],
347350
})
348351

349352
async def _list_events(self, name: str, **kwargs):
@@ -769,15 +772,26 @@ async def test_create_session():
769772

770773
@pytest.mark.asyncio
771774
@pytest.mark.usefixtures('mock_get_api_client')
772-
async def test_create_session_with_custom_session_id():
775+
@pytest.mark.parametrize('session_id', ['1', 'abc123'])
776+
async def test_create_session_with_custom_session_id(
777+
mock_api_client_instance: MockAsyncClient, session_id: str
778+
):
773779
session_service = mock_vertex_ai_session_service()
774780

775-
with pytest.raises(ValueError) as excinfo:
776-
await session_service.create_session(
777-
app_name='123', user_id='user', session_id='1'
778-
)
779-
assert str(excinfo.value) == (
780-
'User-provided Session id is not supported for VertexAISessionService.'
781+
mock_api_client_instance.event_dict[session_id] = (
782+
[],
783+
None,
784+
)
785+
786+
session = await session_service.create_session(
787+
app_name='123', user_id='user', session_id=session_id
788+
)
789+
assert session.id == session_id
790+
assert session.app_name == '123'
791+
assert session.user_id == 'user'
792+
assert session.last_update_time is not None
793+
assert session == await session_service.get_session(
794+
app_name='123', user_id='user', session_id=session_id
781795
)
782796

783797

0 commit comments

Comments
 (0)