@@ -324,7 +324,10 @@ async def _create_session(
324324 self , name : str , user_id : str , config : dict [str , Any ]
325325 ):
326326 self .last_create_session_config = config
327- new_session_id = '4'
327+ if 'session_id' in config :
328+ new_session_id = config ['session_id' ]
329+ else :
330+ new_session_id = '4'
328331 self .session_dict [new_session_id ] = {
329332 'name' : (
330333 'projects/test-project/locations/test-location/'
@@ -343,7 +346,7 @@ async def _create_session(
343346 + '/operations/111'
344347 ),
345348 'done' : True ,
346- 'response' : self .session_dict ['4' ],
349+ 'response' : self .session_dict [new_session_id ],
347350 })
348351
349352 async def _list_events (self , name : str , ** kwargs ):
@@ -769,15 +772,26 @@ async def test_create_session():
769772
770773@pytest .mark .asyncio
771774@pytest .mark .usefixtures ('mock_get_api_client' )
772- async def test_create_session_with_custom_session_id ():
775+ @pytest .mark .parametrize ('session_id' , ['1' , 'abc123' ])
776+ async def test_create_session_with_custom_session_id (
777+ mock_api_client_instance : MockAsyncClient , session_id : str
778+ ):
773779 session_service = mock_vertex_ai_session_service ()
774780
775- with pytest .raises (ValueError ) as excinfo :
776- await session_service .create_session (
777- app_name = '123' , user_id = 'user' , session_id = '1'
778- )
779- assert str (excinfo .value ) == (
780- 'User-provided Session id is not supported for VertexAISessionService.'
781+ mock_api_client_instance .event_dict [session_id ] = (
782+ [],
783+ None ,
784+ )
785+
786+ session = await session_service .create_session (
787+ app_name = '123' , user_id = 'user' , session_id = session_id
788+ )
789+ assert session .id == session_id
790+ assert session .app_name == '123'
791+ assert session .user_id == 'user'
792+ assert session .last_update_time is not None
793+ assert session == await session_service .get_session (
794+ app_name = '123' , user_id = 'user' , session_id = session_id
781795 )
782796
783797
0 commit comments