Skip to content

Commit 018065a

Browse files
committed
fix: preserve nontransparent live resumption
1 parent 61a3933 commit 018065a

2 files changed

Lines changed: 92 additions & 8 deletions

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -533,11 +533,15 @@ async def run_live(
533533
attempt += 1
534534
if not llm_request.live_connect_config:
535535
llm_request.live_connect_config = types.LiveConnectConfig()
536-
if not llm_request.live_connect_config.session_resumption:
536+
session_resumption = (
537+
llm_request.live_connect_config.session_resumption
538+
)
539+
if not session_resumption:
540+
session_resumption = types.SessionResumptionConfig()
537541
llm_request.live_connect_config.session_resumption = (
538-
types.SessionResumptionConfig()
542+
session_resumption
539543
)
540-
llm_request.live_connect_config.session_resumption.handle = (
544+
session_resumption.handle = (
541545
invocation_context.live_session_resumption_handle
542546
)
543547

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,78 @@ async def mock_receive_2():
747747
assert invocation_context.live_session_resumption_handle == 'test_handle'
748748

749749

750+
@pytest.mark.asyncio
751+
async def test_run_live_reconnect_preserves_nontransparent_resumption():
752+
"""Test that reconnect does not force transparent resumption."""
753+
from google.adk.agents.live_request_queue import LiveRequestQueue
754+
from websockets.exceptions import ConnectionClosed
755+
756+
real_model = Gemini()
757+
mock_connection = mock.AsyncMock()
758+
759+
async def mock_receive():
760+
yield LlmResponse(
761+
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
762+
new_handle='test_handle'
763+
)
764+
)
765+
raise ConnectionClosed(None, None)
766+
767+
mock_connection.receive = mock.Mock(side_effect=mock_receive)
768+
769+
agent = Agent(name='test_agent', model=real_model)
770+
invocation_context = await testing_utils.create_invocation_context(
771+
agent=agent
772+
)
773+
invocation_context.live_request_queue = LiveRequestQueue()
774+
775+
flow = BaseLlmFlowForTesting()
776+
777+
async def mock_preprocess(ctx, req):
778+
req.live_connect_config.session_resumption = types.SessionResumptionConfig(
779+
transparent=False
780+
)
781+
if False:
782+
yield
783+
784+
with mock.patch.object(
785+
flow, '_preprocess_async', side_effect=mock_preprocess
786+
):
787+
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
788+
mock_connection_2 = mock.AsyncMock()
789+
790+
class StopError(Exception):
791+
pass
792+
793+
async def mock_receive_2():
794+
yield LlmResponse(
795+
content=types.Content(parts=[types.Part.from_text(text='hi')])
796+
)
797+
raise StopError('stop')
798+
799+
mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2)
800+
801+
mock_aenter = mock.AsyncMock()
802+
mock_aenter.side_effect = [mock_connection, mock_connection_2]
803+
804+
with mock.patch(
805+
'google.adk.models.google_llm.Gemini.connect'
806+
) as mock_connect:
807+
mock_connect.return_value.__aenter__ = mock_aenter
808+
809+
try:
810+
async for _ in flow.run_live(invocation_context):
811+
pass
812+
except StopError:
813+
pass
814+
815+
reconnect_request = mock_connect.call_args_list[1].args[0]
816+
assert (
817+
reconnect_request.live_connect_config.session_resumption.transparent
818+
is False
819+
)
820+
821+
750822
@pytest.mark.asyncio
751823
async def test_run_live_skips_send_history_on_resumption():
752824
"""Test that run_live skips send_history when resuming a session."""
@@ -1390,7 +1462,7 @@ async def mock_receive_2():
13901462

13911463
@pytest.mark.asyncio
13921464
@pytest.mark.parametrize(
1393-
"api_backend",
1465+
'api_backend',
13941466
[
13951467
GoogleLLMVariant.GEMINI_API,
13961468
GoogleLLMVariant.VERTEX_AI,
@@ -1422,8 +1494,11 @@ async def mock_receive():
14221494
flow = BaseLlmFlowForTesting()
14231495

14241496
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
1497+
14251498
async def mock_preprocess(ctx, req):
1426-
req.contents = [types.Content(parts=[types.Part.from_text(text='history')])]
1499+
req.contents = [
1500+
types.Content(parts=[types.Part.from_text(text='history')])
1501+
]
14271502
yield Event(id=Event.new_id(), author='test')
14281503

14291504
with mock.patch.object(
@@ -1467,7 +1542,9 @@ async def test_run_live_respects_explicit_initial_history_in_client_content_fals
14671542
)
14681543
invocation_context.live_request_queue = LiveRequestQueue()
14691544
run_config = RunConfig(
1470-
history_config=types.HistoryConfig(initial_history_in_client_content=False)
1545+
history_config=types.HistoryConfig(
1546+
initial_history_in_client_content=False
1547+
)
14711548
)
14721549
invocation_context.run_config = run_config
14731550

@@ -1476,6 +1553,7 @@ async def test_run_live_respects_explicit_initial_history_in_client_content_fals
14761553
async def mock_preprocess(ctx, req):
14771554
req.contents = [types.Content(parts=[types.Part.from_text(text='history')])]
14781555
from google.adk.flows.llm_flows.basic import _build_basic_request
1556+
14791557
_build_basic_request(ctx, req)
14801558
yield Event(id=Event.new_id(), author='test')
14811559

@@ -1509,5 +1587,7 @@ async def mock_receive():
15091587
assert mock_connect.call_count == 1
15101588
call_req = mock_connect.call_args[0][0]
15111589
assert call_req.live_connect_config.history_config is not None
1512-
assert call_req.live_connect_config.history_config.initial_history_in_client_content is False
1513-
1590+
assert (
1591+
call_req.live_connect_config.history_config.initial_history_in_client_content
1592+
is False
1593+
)

0 commit comments

Comments
 (0)