Skip to content

Commit 4d165ef

Browse files
wukathcopybara-github
authored andcommitted
fix: exit connection cleanly on expected GoAway signal in bidi streaming
Receive the GoAway signal from the Gemini Live API, set a flag on the InvocationContext indicating reconnection is requested, and exit the receive generator cleanly instead of raising a ConnectionClosed exception. This avoids throwing expected session-recycling exceptions into custom client wrappers, which helps prevent false alarms in custom client log monitors. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 937586604
1 parent 0b79f8d commit 4d165ef

2 files changed

Lines changed: 26 additions & 5 deletions

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@
5656
# Prefix used by toolset auth credential IDs
5757
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_'
5858

59+
60+
class _ReconnectSentinel(Event):
61+
"""Internal sentinel event to signal a silent reconnection request."""
62+
63+
5964
if TYPE_CHECKING:
6065
from ...agents.llm_agent import LlmAgent
6166
from ...models.base_llm import BaseLlm
@@ -623,6 +628,7 @@ async def run_live(
623628
self._send_to_model(llm_connection, invocation_context)
624629
)
625630

631+
should_reconnect = False
626632
try:
627633
async with Aclosing(
628634
self._receive_from_model(
@@ -633,6 +639,9 @@ async def run_live(
633639
)
634640
) as agen:
635641
async for event in agen:
642+
if isinstance(event, _ReconnectSentinel):
643+
should_reconnect = True
644+
break
636645
# Empty event means the queue is closed.
637646
if not event:
638647
break
@@ -713,6 +722,9 @@ async def run_live(
713722
await send_task
714723
except asyncio.CancelledError:
715724
pass
725+
if should_reconnect:
726+
continue
727+
break
716728
except (ConnectionClosed, ConnectionClosedOK) as e:
717729
# If we have a session resumption handle, we attempt to reconnect.
718730
# This handle is updated dynamically during the session.
@@ -852,9 +864,9 @@ def get_author_for_event(llm_response: LlmResponse) -> str:
852864
if llm_response.go_away:
853865
logger.info(f'Received go away signal: {llm_response.go_away}')
854866
# The server signals that it will close the connection soon.
855-
# We proactively raise ConnectionClosed to trigger the reconnection
856-
# logic in run_live, which will use the latest session handle.
857-
raise ConnectionClosed(None, None)
867+
# We yield a sentinel event to request reconnection internally.
868+
yield _ReconnectSentinel()
869+
return
858870

859871
model_response_event = Event(
860872
id=Event.new_id(),

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from google.adk.agents.run_config import RunConfig
2525
from google.adk.events.event import Event
2626
from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback
27+
from google.adk.flows.llm_flows.base_llm_flow import _ReconnectSentinel
2728
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
2829
from google.adk.models.base_llm_connection import BaseLlmConnection
2930
from google.adk.models.google_llm import Gemini
@@ -939,15 +940,23 @@ async def mock_receive_2():
939940
) as mock_connect:
940941
mock_connect.return_value.__aenter__ = mock_aenter
941942

943+
yielded_events = []
942944
try:
943-
async for _ in flow.run_live(invocation_context):
944-
pass
945+
async for event in flow.run_live(invocation_context):
946+
yielded_events.append(event)
945947
except StopError:
946948
pass
947949

948950
# Verify that we attempted to connect twice (initial + reconnect after go_away).
949951
assert mock_connect.call_count == 2
950952

953+
# Verify that the internal _ReconnectSentinel is not leaked/yielded to the caller.
954+
assert not any(isinstance(e, _ReconnectSentinel) for e in yielded_events)
955+
956+
# Verify we yielded the expected response after reconnection.
957+
assert len(yielded_events) == 1
958+
assert yielded_events[0].content.parts[0].text == 'hi'
959+
951960

952961
@pytest.mark.asyncio
953962
async def test_run_live_no_reconnect_without_handle():

0 commit comments

Comments
 (0)