Skip to content

Commit 8fb354f

Browse files
committed
fix: preserve nontransparent live resumption
1 parent 9670ce2 commit 8fb354f

2 files changed

Lines changed: 79 additions & 3 deletions

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -533,11 +533,15 @@ async def run_live(
533533
attempt += 1
534534
if not llm_request.live_connect_config:
535535
llm_request.live_connect_config = types.LiveConnectConfig()
536-
if not llm_request.live_connect_config.session_resumption:
536+
session_resumption = (
537+
llm_request.live_connect_config.session_resumption
538+
)
539+
if not session_resumption:
540+
session_resumption = types.SessionResumptionConfig()
537541
llm_request.live_connect_config.session_resumption = (
538-
types.SessionResumptionConfig()
542+
session_resumption
539543
)
540-
llm_request.live_connect_config.session_resumption.handle = (
544+
session_resumption.handle = (
541545
invocation_context.live_session_resumption_handle
542546
)
543547

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,78 @@ async def mock_receive_2():
748748
assert invocation_context.live_session_resumption_handle == 'test_handle'
749749

750750

751+
@pytest.mark.asyncio
752+
async def test_run_live_reconnect_preserves_nontransparent_resumption():
753+
"""Test that reconnect does not force transparent resumption."""
754+
from google.adk.agents.live_request_queue import LiveRequestQueue
755+
from websockets.exceptions import ConnectionClosed
756+
757+
real_model = Gemini()
758+
mock_connection = mock.AsyncMock()
759+
760+
async def mock_receive():
761+
yield LlmResponse(
762+
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
763+
new_handle='test_handle'
764+
)
765+
)
766+
raise ConnectionClosed(None, None)
767+
768+
mock_connection.receive = mock.Mock(side_effect=mock_receive)
769+
770+
agent = Agent(name='test_agent', model=real_model)
771+
invocation_context = await testing_utils.create_invocation_context(
772+
agent=agent
773+
)
774+
invocation_context.live_request_queue = LiveRequestQueue()
775+
776+
flow = BaseLlmFlowForTesting()
777+
778+
async def mock_preprocess(ctx, req):
779+
req.live_connect_config.session_resumption = types.SessionResumptionConfig(
780+
transparent=False
781+
)
782+
if False:
783+
yield
784+
785+
with mock.patch.object(
786+
flow, '_preprocess_async', side_effect=mock_preprocess
787+
):
788+
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
789+
mock_connection_2 = mock.AsyncMock()
790+
791+
class StopError(Exception):
792+
pass
793+
794+
async def mock_receive_2():
795+
yield LlmResponse(
796+
content=types.Content(parts=[types.Part.from_text(text='hi')])
797+
)
798+
raise StopError('stop')
799+
800+
mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2)
801+
802+
mock_aenter = mock.AsyncMock()
803+
mock_aenter.side_effect = [mock_connection, mock_connection_2]
804+
805+
with mock.patch(
806+
'google.adk.models.google_llm.Gemini.connect'
807+
) as mock_connect:
808+
mock_connect.return_value.__aenter__ = mock_aenter
809+
810+
try:
811+
async for _ in flow.run_live(invocation_context):
812+
pass
813+
except StopError:
814+
pass
815+
816+
reconnect_request = mock_connect.call_args_list[1].args[0]
817+
assert (
818+
reconnect_request.live_connect_config.session_resumption.transparent
819+
is False
820+
)
821+
822+
751823
@pytest.mark.asyncio
752824
async def test_run_live_skips_send_history_on_resumption():
753825
"""Test that run_live skips send_history when resuming a session."""

0 commit comments

Comments
 (0)