Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,8 @@ async def run_live(
invocation_context.agent.name,
)
async with llm.connect(llm_request) as llm_connection:
# Reset attempt counter on successful connection.
attempt = 1
# Skip sending history if we are resuming a session. The server
# already has the state associated with the resumption handle.
if (
Expand Down Expand Up @@ -585,8 +587,6 @@ async def run_live(
)
) as agen:
async for event in agen:
# Reset attempt counter on successful communication.
attempt = 1
# Empty event means the queue is closed.
if not event:
break
Expand Down
62 changes: 36 additions & 26 deletions tests/unittests/flows/llm_flows/test_base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,17 +790,17 @@ async def test_run_live_reconnect_limit():
async def mock_connect_impl(*args, **kwargs):
nonlocal connection_cnt
connection_cnt += 1
if connection_cnt > 1:
raise ConnectionClosed(None, None)
conn = mock.AsyncMock()

async def mock_receive():
if connection_cnt == 1:
# Yield handle only on the first connection.
yield LlmResponse(
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
new_handle='test_handle'
),
turn_complete=True,
)
yield LlmResponse(
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
new_handle='test_handle'
),
turn_complete=True,
)
# All subsequent receives (and all receives on later connections) fail.
raise ConnectionClosed(None, None)

Expand Down Expand Up @@ -836,7 +836,7 @@ async def mock_receive():

@pytest.mark.asyncio
async def test_run_live_reconnect_reset_attempt():
"""Test that attempt counter is reset on successful communication."""
"""Test that attempt counter is reset on successful connection establishment."""
from google.adk.flows.llm_flows.base_llm_flow import DEFAULT_MAX_RECONNECT_ATTEMPTS

real_model = Gemini()
Expand All @@ -846,22 +846,28 @@ async def test_run_live_reconnect_reset_attempt():
async def mock_connect_impl(*args, **kwargs):
nonlocal connection_cnt
connection_cnt += 1
conn = mock.AsyncMock()
# Establish connection successfully on attempts 1, 2, and 5
if connection_cnt in (1, 2, 5):
conn = mock.AsyncMock()

async def mock_receive():
if connection_cnt <= 2:
# Yield handle on the first two connections.
yield LlmResponse(
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
new_handle='test_handle'
),
turn_complete=True,
)
# All subsequent receives fail.
raise ConnectionClosed(None, None)
async def mock_receive():
if connection_cnt == 1:
yield LlmResponse(
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
new_handle='test_handle'
),
turn_complete=True,
)
else:
if False:
yield
raise ConnectionClosed(None, None)

conn.receive = mock.Mock(side_effect=mock_receive)
return conn
conn.receive = mock.Mock(side_effect=mock_receive)
return conn
else:
# Failed connection establishments on other attempts
raise ConnectionClosed(None, None)

agent = Agent(name='test_agent', model=real_model)
invocation_context = await testing_utils.create_invocation_context(
Expand All @@ -883,9 +889,13 @@ async def mock_receive():
async for _ in flow.run_live(invocation_context):
pass

# We expect 2 successful attempts + DEFAULT_MAX_RECONNECT_ATTEMPTS failed attempts
# Total calls = 2 + 5 = 7
assert mock_connect.call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 2
# Connection 1: succeeds (resets to 1), yields handle, receive raises ConnectionClosed.
# Connection 2: succeeds (resets to 1), receive raises ConnectionClosed.
# Connection 3: fails (attempt becomes 2)
# Connection 4: fails (attempt becomes 3)
# Connection 5: succeeds (resets to 1), receive raises ConnectionClosed.
# Connection 6-10: fail. Connection 10 has attempt = 6 > DEFAULT_MAX_RECONNECT_ATTEMPTS (5), so raises and terminates.
assert mock_connect.call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 5


@pytest.mark.asyncio
Expand Down
Loading