Skip to content

Commit a8b774c

Browse files
committed
fix(flows): Reset reconnect attempts on connection success
Reset the reconnect attempt counter immediately when a live connection is established successfully, rather than waiting for the first event to be received. This prevents reconnection failures if the connection succeeds but no events are received before a transient disconnect occurs. Change-Id: I9a8002558525fad5715b9a0ac5cf2634a92b923f
1 parent fafafb3 commit a8b774c

2 files changed

Lines changed: 52 additions & 2 deletions

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,8 @@ async def run_live(
556556
invocation_context.agent.name,
557557
)
558558
async with llm.connect(llm_request) as llm_connection:
559+
# Reset attempt counter on successful connection.
560+
attempt = 1
559561
# Skip sending history if we are resuming a session. The server
560562
# already has the state associated with the resumption handle.
561563
if (
@@ -585,8 +587,6 @@ async def run_live(
585587
)
586588
) as agen:
587589
async for event in agen:
588-
# Reset attempt counter on successful communication.
589-
attempt = 1
590590
# Empty event means the queue is closed.
591591
if not event:
592592
break

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,3 +1402,53 @@ async def mock_receive():
14021402
call_req.live_connect_config.history_config.initial_history_in_client_content
14031403
is False
14041404
)
1405+
1406+
1407+
@pytest.mark.asyncio
1408+
async def test_run_live_reset_attempt_on_connection_success():
1409+
"""Test that attempt counter is reset immediately when connection succeeds."""
1410+
# Arrange
1411+
real_model = Gemini()
1412+
connection_cnt = 0
1413+
1414+
class StopTestError(Exception):
1415+
pass
1416+
1417+
async def mock_connect_impl(*args, **kwargs):
1418+
nonlocal connection_cnt
1419+
connection_cnt += 1
1420+
conn = mock.AsyncMock()
1421+
1422+
async def mock_receive():
1423+
if connection_cnt < 8:
1424+
# Raise ConnectionClosed immediately without yielding any events
1425+
raise ConnectionClosed(None, None)
1426+
else:
1427+
raise StopTestError('success')
1428+
1429+
conn.receive = mock.Mock(side_effect=mock_receive)
1430+
return conn
1431+
1432+
agent = Agent(name='test_agent', model=real_model)
1433+
invocation_context = await testing_utils.create_invocation_context(
1434+
agent=agent
1435+
)
1436+
invocation_context.live_request_queue = LiveRequestQueue()
1437+
# Pre-populate resumption handle so the flow attempts to reconnect
1438+
invocation_context.live_session_resumption_handle = 'test_handle'
1439+
1440+
flow = BaseLlmFlowForTesting()
1441+
1442+
# Act & Assert
1443+
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
1444+
with mock.patch(
1445+
'google.adk.models.google_llm.Gemini.connect'
1446+
) as mock_connect:
1447+
mock_connect.return_value.__aenter__.side_effect = mock_connect_impl
1448+
try:
1449+
async for _ in flow.run_live(invocation_context):
1450+
pass
1451+
except StopTestError:
1452+
pass
1453+
1454+
assert mock_connect.call_count == 8

0 commit comments

Comments
 (0)