2424from google .adk .events .event import Event
2525from google .adk .flows .llm_flows .base_llm_flow import _handle_after_model_callback
2626from google .adk .flows .llm_flows .base_llm_flow import BaseLlmFlow
27+ from google .adk .models .base_llm_connection import BaseLlmConnection
2728from google .adk .models .google_llm import Gemini
2829from google .adk .models .llm_request import LlmRequest
2930from google .adk .models .llm_response import LlmResponse
@@ -915,21 +916,22 @@ async def test_run_live_reconnect_limit():
915916 async def mock_connect_impl (* args , ** kwargs ):
916917 nonlocal connection_cnt
917918 connection_cnt += 1
918- conn = mock .AsyncMock ()
919+ if connection_cnt > 1 :
920+ raise ConnectionClosed (None , None )
921+
922+ conn = mock .create_autospec (BaseLlmConnection , instance = True )
919923
920924 async def mock_receive ():
921- if connection_cnt == 1 :
922- # Yield handle only on the first connection.
923- yield LlmResponse (
924- live_session_resumption_update = types .LiveServerSessionResumptionUpdate (
925- new_handle = 'test_handle'
926- ),
927- turn_complete = True ,
928- )
925+ yield LlmResponse (
926+ live_session_resumption_update = types .LiveServerSessionResumptionUpdate (
927+ new_handle = 'test_handle'
928+ ),
929+ turn_complete = True ,
930+ )
929931 # All subsequent receives (and all receives on later connections) fail.
930932 raise ConnectionClosed (None , None )
931933
932- conn .receive = mock . Mock ( side_effect = mock_receive )
934+ conn .receive . side_effect = mock_receive
933935 return conn
934936
935937 agent = Agent (name = 'test_agent' , model = real_model )
@@ -961,7 +963,7 @@ async def mock_receive():
961963
962964@pytest .mark .asyncio
963965async def test_run_live_reconnect_reset_attempt ():
964- """Test that attempt counter is reset on successful communication ."""
966+ """Test that attempt counter is reset on successful connection establishment ."""
965967 from google .adk .flows .llm_flows .base_llm_flow import DEFAULT_MAX_RECONNECT_ATTEMPTS
966968
967969 real_model = Gemini ()
@@ -971,22 +973,28 @@ async def test_run_live_reconnect_reset_attempt():
971973 async def mock_connect_impl (* args , ** kwargs ):
972974 nonlocal connection_cnt
973975 connection_cnt += 1
974- conn = mock .AsyncMock ()
976+ # Establish connection successfully on attempts 1, 2, and 5
977+ if connection_cnt in (1 , 2 , 5 ):
978+ conn = mock .create_autospec (BaseLlmConnection , instance = True )
975979
976- async def mock_receive ():
977- if connection_cnt <= 2 :
978- # Yield handle on the first two connections.
979- yield LlmResponse (
980- live_session_resumption_update = types .LiveServerSessionResumptionUpdate (
981- new_handle = 'test_handle'
982- ),
983- turn_complete = True ,
984- )
985- # All subsequent receives fail.
986- raise ConnectionClosed (None , None )
980+ async def mock_receive ():
981+ if connection_cnt == 1 :
982+ yield LlmResponse (
983+ live_session_resumption_update = types .LiveServerSessionResumptionUpdate (
984+ new_handle = 'test_handle'
985+ ),
986+ turn_complete = True ,
987+ )
988+ else :
989+ if False :
990+ yield
991+ raise ConnectionClosed (None , None )
987992
988- conn .receive = mock .Mock (side_effect = mock_receive )
989- return conn
993+ conn .receive .side_effect = mock_receive
994+ return conn
995+ else :
996+ # Failed connection establishments on other attempts
997+ raise ConnectionClosed (None , None )
990998
991999 agent = Agent (name = 'test_agent' , model = real_model )
9921000 invocation_context = await testing_utils .create_invocation_context (
@@ -1008,9 +1016,13 @@ async def mock_receive():
10081016 async for _ in flow .run_live (invocation_context ):
10091017 pass
10101018
1011- # We expect 2 successful attempts + DEFAULT_MAX_RECONNECT_ATTEMPTS failed attempts
1012- # Total calls = 2 + 5 = 7
1013- assert mock_connect .call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 2
1019+ # Connection 1: succeeds (resets to 1), yields handle, receive raises ConnectionClosed.
1020+ # Connection 2: succeeds (resets to 1), receive raises ConnectionClosed.
1021+ # Connection 3: fails (attempt becomes 2)
1022+ # Connection 4: fails (attempt becomes 3)
1023+ # Connection 5: succeeds (resets to 1), receive raises ConnectionClosed.
1024+ # Connection 6-10: fail. Connection 10 has attempt = 6 > DEFAULT_MAX_RECONNECT_ATTEMPTS (5), so raises and terminates.
1025+ assert mock_connect .call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 5
10141026
10151027
10161028@pytest .mark .asyncio
0 commit comments