Skip to content

Commit ca8baf1

Browse files
wukathcopybara-github
authored andcommitted
fix: Reset retry attempt counter on successful connection
When an idle live session connection is resumed and subsequently dropped, the retry counter was not being reset since no model messages were actively received. Resetting the retry counter immediately upon successful connection handshake prevents reconnect starvation. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 929601570
1 parent e2676fc commit ca8baf1

2 files changed

Lines changed: 43 additions & 30 deletions

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,9 @@ async def run_live(
580580
invocation_context.agent.name,
581581
)
582582
async with llm.connect(llm_request) as llm_connection:
583+
# Reset retry count to allow the maximum reconnect attempts for
584+
# subsequent connection drops.
585+
attempt = 1
583586
# Skip sending history if we are resuming a session. The server
584587
# already has the state associated with the resumption handle.
585588
if (
@@ -609,8 +612,6 @@ async def run_live(
609612
)
610613
) as agen:
611614
async for event in agen:
612-
# Reset attempt counter on successful communication.
613-
attempt = 1
614615
# Empty event means the queue is closed.
615616
if not event:
616617
break

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from google.adk.events.event import Event
2525
from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback
2626
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
27+
from google.adk.models.base_llm_connection import BaseLlmConnection
2728
from google.adk.models.google_llm import Gemini
2829
from google.adk.models.llm_request import LlmRequest
2930
from google.adk.models.llm_response import LlmResponse
@@ -915,21 +916,22 @@ async def test_run_live_reconnect_limit():
915916
async def mock_connect_impl(*args, **kwargs):
916917
nonlocal connection_cnt
917918
connection_cnt += 1
918-
conn = mock.AsyncMock()
919+
if connection_cnt > 1:
920+
raise ConnectionClosed(None, None)
921+
922+
conn = mock.create_autospec(BaseLlmConnection, instance=True)
919923

920924
async def mock_receive():
921-
if connection_cnt == 1:
922-
# Yield handle only on the first connection.
923-
yield LlmResponse(
924-
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
925-
new_handle='test_handle'
926-
),
927-
turn_complete=True,
928-
)
925+
yield LlmResponse(
926+
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
927+
new_handle='test_handle'
928+
),
929+
turn_complete=True,
930+
)
929931
# All subsequent receives (and all receives on later connections) fail.
930932
raise ConnectionClosed(None, None)
931933

932-
conn.receive = mock.Mock(side_effect=mock_receive)
934+
conn.receive.side_effect = mock_receive
933935
return conn
934936

935937
agent = Agent(name='test_agent', model=real_model)
@@ -961,7 +963,7 @@ async def mock_receive():
961963

962964
@pytest.mark.asyncio
963965
async def test_run_live_reconnect_reset_attempt():
964-
"""Test that attempt counter is reset on successful communication."""
966+
"""Test that attempt counter is reset on successful connection establishment."""
965967
from google.adk.flows.llm_flows.base_llm_flow import DEFAULT_MAX_RECONNECT_ATTEMPTS
966968

967969
real_model = Gemini()
@@ -971,22 +973,28 @@ async def test_run_live_reconnect_reset_attempt():
971973
async def mock_connect_impl(*args, **kwargs):
972974
nonlocal connection_cnt
973975
connection_cnt += 1
974-
conn = mock.AsyncMock()
976+
# Establish connection successfully on attempts 1, 2, and 5
977+
if connection_cnt in (1, 2, 5):
978+
conn = mock.create_autospec(BaseLlmConnection, instance=True)
975979

976-
async def mock_receive():
977-
if connection_cnt <= 2:
978-
# Yield handle on the first two connections.
979-
yield LlmResponse(
980-
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
981-
new_handle='test_handle'
982-
),
983-
turn_complete=True,
984-
)
985-
# All subsequent receives fail.
986-
raise ConnectionClosed(None, None)
980+
async def mock_receive():
981+
if connection_cnt == 1:
982+
yield LlmResponse(
983+
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
984+
new_handle='test_handle'
985+
),
986+
turn_complete=True,
987+
)
988+
else:
989+
if False:
990+
yield
991+
raise ConnectionClosed(None, None)
987992

988-
conn.receive = mock.Mock(side_effect=mock_receive)
989-
return conn
993+
conn.receive.side_effect = mock_receive
994+
return conn
995+
else:
996+
# Failed connection establishments on other attempts
997+
raise ConnectionClosed(None, None)
990998

991999
agent = Agent(name='test_agent', model=real_model)
9921000
invocation_context = await testing_utils.create_invocation_context(
@@ -1008,9 +1016,13 @@ async def mock_receive():
10081016
async for _ in flow.run_live(invocation_context):
10091017
pass
10101018

1011-
# We expect 2 successful attempts + DEFAULT_MAX_RECONNECT_ATTEMPTS failed attempts
1012-
# Total calls = 2 + 5 = 7
1013-
assert mock_connect.call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 2
1019+
# Connection 1: succeeds (resets to 1), yields handle, receive raises ConnectionClosed.
1020+
# Connection 2: succeeds (resets to 1), receive raises ConnectionClosed.
1021+
# Connection 3: fails (attempt becomes 2)
1022+
# Connection 4: fails (attempt becomes 3)
1023+
# Connection 5: succeeds (resets to 1), receive raises ConnectionClosed.
1024+
# Connection 6-10: fail. Connection 10 has attempt = 6 > DEFAULT_MAX_RECONNECT_ATTEMPTS (5), so raises and terminates.
1025+
assert mock_connect.call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 5
10141026

10151027

10161028
@pytest.mark.asyncio

0 commit comments

Comments
 (0)