@@ -417,7 +417,10 @@ async def _create_session(
417417 self , name : str , user_id : str , config : dict [str , Any ]
418418 ):
419419 self .last_create_session_config = config
420- new_session_id = '4'
420+ if 'session_id' in config :
421+ new_session_id = config ['session_id' ]
422+ else :
423+ new_session_id = '4'
421424 self .session_dict [new_session_id ] = {
422425 'name' : (
423426 'projects/test-project/locations/test-location/'
@@ -436,7 +439,7 @@ async def _create_session(
436439 + '/operations/111'
437440 ),
438441 'done' : True ,
439- 'response' : self .session_dict ['4' ],
442+ 'response' : self .session_dict [new_session_id ],
440443 })
441444
442445 async def _list_events (self , name : str , ** kwargs ):
@@ -880,15 +883,26 @@ async def test_create_session():
880883
881884@pytest .mark .asyncio
882885@pytest .mark .usefixtures ('mock_get_api_client' )
883- async def test_create_session_with_custom_session_id ():
886+ @pytest .mark .parametrize ('session_id' , ['1' , 'abc123' ])
887+ async def test_create_session_with_custom_session_id (
888+ mock_api_client_instance : MockAsyncClient , session_id : str
889+ ):
884890 session_service = mock_vertex_ai_session_service ()
885891
886- with pytest .raises (ValueError ) as excinfo :
887- await session_service .create_session (
888- app_name = '123' , user_id = 'user' , session_id = '1'
889- )
890- assert str (excinfo .value ) == (
891- 'User-provided Session id is not supported for VertexAISessionService.'
892+ mock_api_client_instance .event_dict [session_id ] = (
893+ [],
894+ None ,
895+ )
896+
897+ session = await session_service .create_session (
898+ app_name = '123' , user_id = 'user' , session_id = session_id
899+ )
900+ assert session .id == session_id
901+ assert session .app_name == '123'
902+ assert session .user_id == 'user'
903+ assert session .last_update_time is not None
904+ assert session == await session_service .get_session (
905+ app_name = '123' , user_id = 'user' , session_id = session_id
892906 )
893907
894908
0 commit comments