Skip to content

Commit 16d496b

Browse files
committed
fix(flows): Reset reconnect attempts on connection success
Reset the reconnect attempt counter immediately when a live connection is established successfully, rather than waiting for the first event to be received. This prevents reconnection failures if the connection succeeds but no events are received before a transient disconnect occurs.

Change-Id: I9a8002558525fad5715b9a0ac5cf2634a92b923f
1 parent fafafb3 commit 16d496b

2 files changed

Lines changed: 38 additions & 28 deletions

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,8 @@ async def run_live(
556556
invocation_context.agent.name,
557557
)
558558
async with llm.connect(llm_request) as llm_connection:
559+
# Reset attempt counter on successful connection.
560+
attempt = 1
559561
# Skip sending history if we are resuming a session. The server
560562
# already has the state associated with the resumption handle.
561563
if (
@@ -585,8 +587,6 @@ async def run_live(
585587
)
586588
) as agen:
587589
async for event in agen:
588-
# Reset attempt counter on successful communication.
589-
attempt = 1
590590
# Empty event means the queue is closed.
591591
if not event:
592592
break

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -790,17 +790,17 @@ async def test_run_live_reconnect_limit():
790790
async def mock_connect_impl(*args, **kwargs):
791791
nonlocal connection_cnt
792792
connection_cnt += 1
793+
if connection_cnt > 1:
794+
raise ConnectionClosed(None, None)
793795
conn = mock.AsyncMock()
794796

795797
async def mock_receive():
796-
if connection_cnt == 1:
797-
# Yield handle only on the first connection.
798-
yield LlmResponse(
799-
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
800-
new_handle='test_handle'
801-
),
802-
turn_complete=True,
803-
)
798+
yield LlmResponse(
799+
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
800+
new_handle='test_handle'
801+
),
802+
turn_complete=True,
803+
)
804804
# All subsequent receives (and all receives on later connections) fail.
805805
raise ConnectionClosed(None, None)
806806

@@ -836,7 +836,7 @@ async def mock_receive():
836836

837837
@pytest.mark.asyncio
838838
async def test_run_live_reconnect_reset_attempt():
839-
"""Test that attempt counter is reset on successful communication."""
839+
"""Test that attempt counter is reset on successful connection establishment."""
840840
from google.adk.flows.llm_flows.base_llm_flow import DEFAULT_MAX_RECONNECT_ATTEMPTS
841841

842842
real_model = Gemini()
@@ -846,22 +846,28 @@ async def test_run_live_reconnect_reset_attempt():
846846
async def mock_connect_impl(*args, **kwargs):
847847
nonlocal connection_cnt
848848
connection_cnt += 1
849-
conn = mock.AsyncMock()
849+
# Establish connection successfully on attempts 1, 2, and 5
850+
if connection_cnt in (1, 2, 5):
851+
conn = mock.AsyncMock()
850852

851-
async def mock_receive():
852-
if connection_cnt <= 2:
853-
# Yield handle on the first two connections.
854-
yield LlmResponse(
855-
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
856-
new_handle='test_handle'
857-
),
858-
turn_complete=True,
859-
)
860-
# All subsequent receives fail.
861-
raise ConnectionClosed(None, None)
853+
async def mock_receive():
854+
if connection_cnt == 1:
855+
yield LlmResponse(
856+
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
857+
new_handle='test_handle'
858+
),
859+
turn_complete=True,
860+
)
861+
else:
862+
if False:
863+
yield
864+
raise ConnectionClosed(None, None)
862865

863-
conn.receive = mock.Mock(side_effect=mock_receive)
864-
return conn
866+
conn.receive = mock.Mock(side_effect=mock_receive)
867+
return conn
868+
else:
869+
# Failed connection establishments on other attempts
870+
raise ConnectionClosed(None, None)
865871

866872
agent = Agent(name='test_agent', model=real_model)
867873
invocation_context = await testing_utils.create_invocation_context(
@@ -883,9 +889,13 @@ async def mock_receive():
883889
async for _ in flow.run_live(invocation_context):
884890
pass
885891

886-
# We expect 2 successful attempts + DEFAULT_MAX_RECONNECT_ATTEMPTS failed attempts
887-
# Total calls = 2 + 5 = 7
888-
assert mock_connect.call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 2
892+
# Connection 1: succeeds (resets to 1), yields handle, receive raises ConnectionClosed.
893+
# Connection 2: succeeds (resets to 1), receive raises ConnectionClosed.
894+
# Connection 3: fails (attempt becomes 2)
895+
# Connection 4: fails (attempt becomes 3)
896+
# Connection 5: succeeds (resets to 1), receive raises ConnectionClosed.
897+
# Connection 6-10: fail. Connection 10 has attempt = 6 > DEFAULT_MAX_RECONNECT_ATTEMPTS (5), so raises and terminates.
898+
assert mock_connect.call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 5
889899

890900

891901
@pytest.mark.asyncio

0 commit comments

Comments
 (0)