Skip to content

Commit 5ad1942

Browse files
committed
fix(flows): preserve transparent config on live session reconnect
When session resumption triggers a reconnect, ADK previously forced transparent=True on the configuration. This caused a ValueError when using the Gemini API backend, as the google-genai SDK explicitly rejects transparent=True. This change ensures we only set transparent=True when creating a fresh SessionResumptionConfig, preserving the user's configuration if provided. Close #5675 Change-Id: Ifc75506f347655c95ee4194c74bae64b479c744a
1 parent afb0a64 commit 5ad1942

2 files changed

Lines changed: 158 additions & 13 deletions

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
from ...auth.auth_tool import AuthConfig
4040
from ...events.event import Event
4141
from ...models.base_llm_connection import BaseLlmConnection
42+
from ...models.google_llm import Gemini
43+
from ...models.google_llm import GoogleLLMVariant
4244
from ...models.llm_request import LlmRequest
4345
from ...models.llm_response import LlmResponse
4446
from ...telemetry import tracing
@@ -517,7 +519,18 @@ async def run_live(
517519
llm_request.live_connect_config.session_resumption.handle = (
518520
invocation_context.live_session_resumption_handle
519521
)
520-
llm_request.live_connect_config.session_resumption.transparent = True
522+
523+
# Only set transparent=True for Vertex AI backend, as the Gemini API
524+
# backend explicitly rejects it.
525+
if (
526+
isinstance(llm, Gemini)
527+
and llm._api_backend == GoogleLLMVariant.VERTEX_AI # pylint: disable=protected-access
528+
):
529+
session_resumption = (
530+
llm_request.live_connect_config.session_resumption
531+
)
532+
if session_resumption.transparent is None:
533+
session_resumption.transparent = True
521534

522535
logger.info(
523536
'Establishing live connection for agent: %s',

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 144 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from unittest import mock
1818
from unittest.mock import AsyncMock
1919

20+
from google.adk.agents.live_request_queue import LiveRequestQueue
2021
from google.adk.agents.llm_agent import Agent
2122
from google.adk.agents.run_config import RunConfig
2223
from google.adk.events.event import Event
@@ -30,6 +31,7 @@
3031
from google.adk.tools.google_search_tool import GoogleSearchTool
3132
from google.genai import types
3233
import pytest
34+
from websockets.exceptions import ConnectionClosed
3335

3436
from ... import testing_utils
3537

@@ -490,8 +492,6 @@ async def call(self, **kwargs):
490492
@pytest.mark.asyncio
491493
async def test_run_live_reconnects_on_connection_closed():
492494
"""Test that run_live reconnects when ConnectionClosed occurs."""
493-
from google.adk.agents.live_request_queue import LiveRequestQueue
494-
from websockets.exceptions import ConnectionClosed
495495

496496
real_model = Gemini()
497497
mock_connection = mock.AsyncMock()
@@ -558,7 +558,6 @@ async def mock_receive_2():
558558
@pytest.mark.asyncio
559559
async def test_run_live_reconnects_on_api_error():
560560
"""Test that run_live reconnects when APIError occurs."""
561-
from google.adk.agents.live_request_queue import LiveRequestQueue
562561
from google.genai.errors import APIError
563562

564563
real_model = Gemini()
@@ -626,7 +625,6 @@ async def mock_receive_2():
626625
@pytest.mark.asyncio
627626
async def test_run_live_skips_send_history_on_resumption():
628627
"""Test that run_live skips send_history when resuming a session."""
629-
from google.adk.agents.live_request_queue import LiveRequestQueue
630628

631629
real_model = Gemini()
632630
mock_connection = mock.AsyncMock()
@@ -684,7 +682,6 @@ async def mock_receive():
684682
@pytest.mark.asyncio
685683
async def test_live_session_resumption_go_away():
686684
"""Test that go_away triggers reconnection."""
687-
from google.adk.agents.live_request_queue import LiveRequestQueue
688685

689686
real_model = Gemini()
690687
mock_connection = mock.AsyncMock()
@@ -743,8 +740,6 @@ async def mock_receive_2():
743740
@pytest.mark.asyncio
744741
async def test_run_live_no_reconnect_without_handle():
745742
"""Test that run_live does not reconnect when handle is missing."""
746-
from google.adk.agents.live_request_queue import LiveRequestQueue
747-
from websockets.exceptions import ConnectionClosed
748743

749744
real_model = Gemini()
750745
mock_connection = mock.AsyncMock()
@@ -786,8 +781,6 @@ async def mock_receive():
786781
@pytest.mark.asyncio
787782
async def test_run_live_reconnect_limit():
788783
"""Test that run_live stops reconnecting after 5 attempts."""
789-
from google.adk.agents.live_request_queue import LiveRequestQueue
790-
from websockets.exceptions import ConnectionClosed
791784

792785
real_model = Gemini()
793786

@@ -843,9 +836,7 @@ async def mock_receive():
843836
@pytest.mark.asyncio
844837
async def test_run_live_reconnect_reset_attempt():
845838
"""Test that attempt counter is reset on successful communication."""
846-
from google.adk.agents.live_request_queue import LiveRequestQueue
847839
from google.adk.flows.llm_flows.base_llm_flow import DEFAULT_MAX_RECONNECT_ATTEMPTS
848-
from websockets.exceptions import ConnectionClosed
849840

850841
real_model = Gemini()
851842

@@ -987,7 +978,6 @@ async def mock_receive():
987978
@pytest.mark.asyncio
988979
async def test_run_live_clears_resumption_handle_on_transfer():
989980
"""Test that run_live clears session resumption handles when transferring to another agent."""
990-
from google.adk.agents.live_request_queue import LiveRequestQueue
991981

992982
agent = Agent(name='test_agent')
993983
invocation_context = await testing_utils.create_invocation_context(
@@ -1129,3 +1119,145 @@ async def test_postprocess_async_yields_grounding_metadata_only():
11291119

11301120
assert len(events) == 1
11311121
assert events[0].grounding_metadata == grounding_metadata
1122+
1123+
1124+
@pytest.mark.asyncio
1125+
async def test_run_live_reconnect_does_not_set_transparent():
1126+
"""Test that run_live reconnect does not set transparent=True."""
1127+
1128+
real_model = Gemini()
1129+
mock_connection = mock.AsyncMock()
1130+
1131+
async def mock_receive():
1132+
yield LlmResponse(
1133+
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
1134+
new_handle='test_handle'
1135+
)
1136+
)
1137+
raise ConnectionClosed(None, None)
1138+
1139+
mock_connection.receive = mock.Mock(side_effect=mock_receive)
1140+
1141+
agent = Agent(name='test_agent', model=real_model)
1142+
invocation_context = await testing_utils.create_invocation_context(
1143+
agent=agent
1144+
)
1145+
invocation_context.live_request_queue = LiveRequestQueue()
1146+
invocation_context.run_config = RunConfig()
1147+
1148+
flow = BaseLlmFlowForTesting()
1149+
1150+
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
1151+
1152+
async def mock_preprocess(ctx, req):
1153+
req.live_connect_config.session_resumption = (
1154+
ctx.run_config.session_resumption
1155+
)
1156+
yield Event(id=Event.new_id(), author='test')
1157+
1158+
with mock.patch.object(
1159+
flow, '_preprocess_async', side_effect=mock_preprocess
1160+
):
1161+
mock_connection_2 = mock.AsyncMock()
1162+
1163+
class StopTestError(Exception):
1164+
pass
1165+
1166+
async def mock_receive_2():
1167+
yield LlmResponse(
1168+
content=types.Content(parts=[types.Part.from_text(text='hi')])
1169+
)
1170+
raise StopTestError('stop')
1171+
1172+
mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2)
1173+
1174+
mock_aenter = mock.AsyncMock()
1175+
mock_aenter.side_effect = [mock_connection, mock_connection_2]
1176+
1177+
with mock.patch(
1178+
'google.adk.models.google_llm.Gemini.connect'
1179+
) as mock_connect:
1180+
mock_connect.return_value.__aenter__ = mock_aenter
1181+
1182+
try:
1183+
async for _ in flow.run_live(invocation_context):
1184+
pass
1185+
except StopTestError:
1186+
pass
1187+
1188+
assert mock_connect.call_count == 2
1189+
second_call_req = mock_connect.call_args_list[1][0][0]
1190+
session_resump = second_call_req.live_connect_config.session_resumption
1191+
assert session_resump.transparent is None
1192+
1193+
1194+
@pytest.mark.asyncio
1195+
async def test_run_live_reconnect_sets_transparent_for_vertex():
1196+
"""Test that run_live reconnect sets transparent=True for vertex backend."""
1197+
1198+
real_model = Gemini(
1199+
model='projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash-exp'
1200+
)
1201+
mock_connection = mock.AsyncMock()
1202+
1203+
async def mock_receive():
1204+
yield LlmResponse(
1205+
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
1206+
new_handle='test_handle'
1207+
)
1208+
)
1209+
raise ConnectionClosed(None, None)
1210+
1211+
mock_connection.receive = mock.Mock(side_effect=mock_receive)
1212+
1213+
agent = Agent(name='test_agent', model=real_model)
1214+
invocation_context = await testing_utils.create_invocation_context(
1215+
agent=agent
1216+
)
1217+
invocation_context.live_request_queue = LiveRequestQueue()
1218+
invocation_context.run_config = RunConfig()
1219+
1220+
flow = BaseLlmFlowForTesting()
1221+
1222+
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
1223+
1224+
async def mock_preprocess(ctx, req):
1225+
req.live_connect_config.session_resumption = (
1226+
ctx.run_config.session_resumption
1227+
)
1228+
yield Event(id=Event.new_id(), author='test')
1229+
1230+
with mock.patch.object(
1231+
flow, '_preprocess_async', side_effect=mock_preprocess
1232+
):
1233+
mock_connection_2 = mock.AsyncMock()
1234+
1235+
class StopTestError(Exception):
1236+
pass
1237+
1238+
async def mock_receive_2():
1239+
yield LlmResponse(
1240+
content=types.Content(parts=[types.Part.from_text(text='hi')])
1241+
)
1242+
raise StopTestError('stop')
1243+
1244+
mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2)
1245+
1246+
mock_aenter = mock.AsyncMock()
1247+
mock_aenter.side_effect = [mock_connection, mock_connection_2]
1248+
1249+
with mock.patch(
1250+
'google.adk.models.google_llm.Gemini.connect'
1251+
) as mock_connect:
1252+
mock_connect.return_value.__aenter__ = mock_aenter
1253+
1254+
try:
1255+
async for _ in flow.run_live(invocation_context):
1256+
pass
1257+
except StopTestError:
1258+
pass
1259+
1260+
assert mock_connect.call_count == 2
1261+
second_call_req = mock_connect.call_args_list[1][0][0]
1262+
session_resump = second_call_req.live_connect_config.session_resumption
1263+
assert session_resump.transparent

0 commit comments

Comments
 (0)