diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 20093237d3..39a66244ca 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -556,6 +556,8 @@ async def run_live( invocation_context.agent.name, ) async with llm.connect(llm_request) as llm_connection: + # Reset attempt counter on successful connection. + attempt = 1 # Skip sending history if we are resuming a session. The server # already has the state associated with the resumption handle. if ( @@ -585,8 +587,6 @@ async def run_live( ) ) as agen: async for event in agen: - # Reset attempt counter on successful communication. - attempt = 1 # Empty event means the queue is closed. if not event: break diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index 7de544b4f1..b5c3f1a612 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -790,17 +790,17 @@ async def test_run_live_reconnect_limit(): async def mock_connect_impl(*args, **kwargs): nonlocal connection_cnt connection_cnt += 1 + if connection_cnt > 1: + raise ConnectionClosed(None, None) conn = mock.AsyncMock() async def mock_receive(): - if connection_cnt == 1: - # Yield handle only on the first connection. - yield LlmResponse( - live_session_resumption_update=types.LiveServerSessionResumptionUpdate( - new_handle='test_handle' - ), - turn_complete=True, - ) + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ), + turn_complete=True, + ) # All subsequent receives (and all receives on later connections) fail. raise ConnectionClosed(None, None) @@ -836,7 +836,7 @@ async def mock_receive(): @pytest.mark.asyncio async def test_run_live_reconnect_reset_attempt(): - """Test that attempt counter is reset on successful communication.""" + """Test that attempt counter is reset on successful connection establishment.""" from google.adk.flows.llm_flows.base_llm_flow import DEFAULT_MAX_RECONNECT_ATTEMPTS real_model = Gemini() @@ -846,22 +846,28 @@ async def test_run_live_reconnect_reset_attempt(): async def mock_connect_impl(*args, **kwargs): nonlocal connection_cnt connection_cnt += 1 - conn = mock.AsyncMock() + # Establish connection successfully on attempts 1, 2, and 5 + if connection_cnt in (1, 2, 5): + conn = mock.AsyncMock() - async def mock_receive(): - if connection_cnt <= 2: - # Yield handle on the first two connections. - yield LlmResponse( - live_session_resumption_update=types.LiveServerSessionResumptionUpdate( - new_handle='test_handle' - ), - turn_complete=True, - ) - # All subsequent receives fail. - raise ConnectionClosed(None, None) + async def mock_receive(): + if connection_cnt == 1: + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ), + turn_complete=True, + ) + else: + if False: + yield + raise ConnectionClosed(None, None) - conn.receive = mock.Mock(side_effect=mock_receive) - return conn + conn.receive = mock.Mock(side_effect=mock_receive) + return conn + else: + # Failed connection establishments on other attempts + raise ConnectionClosed(None, None) agent = Agent(name='test_agent', model=real_model) invocation_context = await testing_utils.create_invocation_context( @@ -883,9 +889,13 @@ async def mock_receive(): async for _ in flow.run_live(invocation_context): pass - # We expect 2 successful attempts + DEFAULT_MAX_RECONNECT_ATTEMPTS failed attempts - # Total calls = 2 + 5 = 7 - assert mock_connect.call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 2 + # Connection 1: succeeds (resets to 1), yields handle, receive raises ConnectionClosed. + # Connection 2: succeeds (resets to 1), receive raises ConnectionClosed. + # Connection 3: fails (attempt becomes 2) + # Connection 4: fails (attempt becomes 3) + # Connection 5: succeeds (resets to 1), receive raises ConnectionClosed. + # Connection 6-10: fail. Connection 10 has attempt = 6 > DEFAULT_MAX_RECONNECT_ATTEMPTS (5), so raises and terminates. + assert mock_connect.call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 5 @pytest.mark.asyncio