@@ -747,6 +747,78 @@ async def mock_receive_2():
747747 assert invocation_context .live_session_resumption_handle == 'test_handle'
748748
749749
750+ @pytest .mark .asyncio
751+ async def test_run_live_reconnect_preserves_nontransparent_resumption ():
752+ """Test that reconnect does not force transparent resumption."""
753+ from google .adk .agents .live_request_queue import LiveRequestQueue
754+ from websockets .exceptions import ConnectionClosed
755+
756+ real_model = Gemini ()
757+ mock_connection = mock .AsyncMock ()
758+
759+ async def mock_receive ():
760+ yield LlmResponse (
761+ live_session_resumption_update = types .LiveServerSessionResumptionUpdate (
762+ new_handle = 'test_handle'
763+ )
764+ )
765+ raise ConnectionClosed (None , None )
766+
767+ mock_connection .receive = mock .Mock (side_effect = mock_receive )
768+
769+ agent = Agent (name = 'test_agent' , model = real_model )
770+ invocation_context = await testing_utils .create_invocation_context (
771+ agent = agent
772+ )
773+ invocation_context .live_request_queue = LiveRequestQueue ()
774+
775+ flow = BaseLlmFlowForTesting ()
776+
777+ async def mock_preprocess (ctx , req ):
778+ req .live_connect_config .session_resumption = types .SessionResumptionConfig (
779+ transparent = False
780+ )
781+ if False :
782+ yield
783+
784+ with mock .patch .object (
785+ flow , '_preprocess_async' , side_effect = mock_preprocess
786+ ):
787+ with mock .patch .object (flow , '_send_to_model' , new_callable = AsyncMock ):
788+ mock_connection_2 = mock .AsyncMock ()
789+
790+ class StopError (Exception ):
791+ pass
792+
793+ async def mock_receive_2 ():
794+ yield LlmResponse (
795+ content = types .Content (parts = [types .Part .from_text (text = 'hi' )])
796+ )
797+ raise StopError ('stop' )
798+
799+ mock_connection_2 .receive = mock .Mock (side_effect = mock_receive_2 )
800+
801+ mock_aenter = mock .AsyncMock ()
802+ mock_aenter .side_effect = [mock_connection , mock_connection_2 ]
803+
804+ with mock .patch (
805+ 'google.adk.models.google_llm.Gemini.connect'
806+ ) as mock_connect :
807+ mock_connect .return_value .__aenter__ = mock_aenter
808+
809+ try :
810+ async for _ in flow .run_live (invocation_context ):
811+ pass
812+ except StopError :
813+ pass
814+
815+ reconnect_request = mock_connect .call_args_list [1 ].args [0 ]
816+ assert (
817+ reconnect_request .live_connect_config .session_resumption .transparent
818+ is False
819+ )
820+
821+
750822@pytest .mark .asyncio
751823async def test_run_live_skips_send_history_on_resumption ():
752824 """Test that run_live skips send_history when resuming a session."""
@@ -1390,7 +1462,7 @@ async def mock_receive_2():
13901462
13911463@pytest .mark .asyncio
13921464@pytest .mark .parametrize (
1393- " api_backend" ,
1465+ ' api_backend' ,
13941466 [
13951467 GoogleLLMVariant .GEMINI_API ,
13961468 GoogleLLMVariant .VERTEX_AI ,
@@ -1422,8 +1494,11 @@ async def mock_receive():
14221494 flow = BaseLlmFlowForTesting ()
14231495
14241496 with mock .patch .object (flow , '_send_to_model' , new_callable = AsyncMock ):
1497+
14251498 async def mock_preprocess (ctx , req ):
1426- req .contents = [types .Content (parts = [types .Part .from_text (text = 'history' )])]
1499+ req .contents = [
1500+ types .Content (parts = [types .Part .from_text (text = 'history' )])
1501+ ]
14271502 yield Event (id = Event .new_id (), author = 'test' )
14281503
14291504 with mock .patch .object (
@@ -1467,7 +1542,9 @@ async def test_run_live_respects_explicit_initial_history_in_client_content_fals
14671542 )
14681543 invocation_context .live_request_queue = LiveRequestQueue ()
14691544 run_config = RunConfig (
1470- history_config = types .HistoryConfig (initial_history_in_client_content = False )
1545+ history_config = types .HistoryConfig (
1546+ initial_history_in_client_content = False
1547+ )
14711548 )
14721549 invocation_context .run_config = run_config
14731550
@@ -1476,6 +1553,7 @@ async def test_run_live_respects_explicit_initial_history_in_client_content_fals
14761553 async def mock_preprocess (ctx , req ):
14771554 req .contents = [types .Content (parts = [types .Part .from_text (text = 'history' )])]
14781555 from google .adk .flows .llm_flows .basic import _build_basic_request
1556+
14791557 _build_basic_request (ctx , req )
14801558 yield Event (id = Event .new_id (), author = 'test' )
14811559
@@ -1509,5 +1587,7 @@ async def mock_receive():
15091587 assert mock_connect .call_count == 1
15101588 call_req = mock_connect .call_args [0 ][0 ]
15111589 assert call_req .live_connect_config .history_config is not None
1512- assert call_req .live_connect_config .history_config .initial_history_in_client_content is False
1513-
1590+ assert (
1591+ call_req .live_connect_config .history_config .initial_history_in_client_content
1592+ is False
1593+ )
0 commit comments