Skip to content

Commit d6a117f

Browse files
committed
fix: preserve nontransparent live resumption
1 parent 4006fe4 commit d6a117f

2 files changed

Lines changed: 79 additions & 3 deletions

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -532,11 +532,15 @@ async def run_live(
532532
attempt += 1
533533
if not llm_request.live_connect_config:
534534
llm_request.live_connect_config = types.LiveConnectConfig()
535-
if not llm_request.live_connect_config.session_resumption:
535+
session_resumption = (
536+
llm_request.live_connect_config.session_resumption
537+
)
538+
if not session_resumption:
539+
session_resumption = types.SessionResumptionConfig()
536540
llm_request.live_connect_config.session_resumption = (
537-
types.SessionResumptionConfig()
541+
session_resumption
538542
)
539-
llm_request.live_connect_config.session_resumption.handle = (
543+
session_resumption.handle = (
540544
invocation_context.live_session_resumption_handle
541545
)
542546

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,78 @@ async def mock_receive_2():
747747
assert invocation_context.live_session_resumption_handle == 'test_handle'
748748

749749

750+
@pytest.mark.asyncio
751+
async def test_run_live_reconnect_preserves_nontransparent_resumption():
752+
"""Test that reconnect does not force transparent resumption."""
753+
from google.adk.agents.live_request_queue import LiveRequestQueue
754+
from websockets.exceptions import ConnectionClosed
755+
756+
real_model = Gemini()
757+
mock_connection = mock.AsyncMock()
758+
759+
async def mock_receive():
760+
yield LlmResponse(
761+
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
762+
new_handle='test_handle'
763+
)
764+
)
765+
raise ConnectionClosed(None, None)
766+
767+
mock_connection.receive = mock.Mock(side_effect=mock_receive)
768+
769+
agent = Agent(name='test_agent', model=real_model)
770+
invocation_context = await testing_utils.create_invocation_context(
771+
agent=agent
772+
)
773+
invocation_context.live_request_queue = LiveRequestQueue()
774+
775+
flow = BaseLlmFlowForTesting()
776+
777+
async def mock_preprocess(ctx, req):
778+
req.live_connect_config.session_resumption = types.SessionResumptionConfig(
779+
transparent=False
780+
)
781+
if False:
782+
yield
783+
784+
with mock.patch.object(
785+
flow, '_preprocess_async', side_effect=mock_preprocess
786+
):
787+
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
788+
mock_connection_2 = mock.AsyncMock()
789+
790+
class StopError(Exception):
791+
pass
792+
793+
async def mock_receive_2():
794+
yield LlmResponse(
795+
content=types.Content(parts=[types.Part.from_text(text='hi')])
796+
)
797+
raise StopError('stop')
798+
799+
mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2)
800+
801+
mock_aenter = mock.AsyncMock()
802+
mock_aenter.side_effect = [mock_connection, mock_connection_2]
803+
804+
with mock.patch(
805+
'google.adk.models.google_llm.Gemini.connect'
806+
) as mock_connect:
807+
mock_connect.return_value.__aenter__ = mock_aenter
808+
809+
try:
810+
async for _ in flow.run_live(invocation_context):
811+
pass
812+
except StopError:
813+
pass
814+
815+
reconnect_request = mock_connect.call_args_list[1].args[0]
816+
assert (
817+
reconnect_request.live_connect_config.session_resumption.transparent
818+
is False
819+
)
820+
821+
750822
@pytest.mark.asyncio
751823
async def test_run_live_skips_send_history_on_resumption():
752824
"""Test that run_live skips send_history when resuming a session."""

0 commit comments

Comments
 (0)