|
17 | 17 | from unittest import mock |
18 | 18 | from unittest.mock import AsyncMock |
19 | 19 |
|
| 20 | +from google.adk.agents.live_request_queue import LiveRequestQueue |
20 | 21 | from google.adk.agents.llm_agent import Agent |
21 | 22 | from google.adk.agents.run_config import RunConfig |
22 | 23 | from google.adk.events.event import Event |
|
30 | 31 | from google.adk.tools.google_search_tool import GoogleSearchTool |
31 | 32 | from google.genai import types |
32 | 33 | import pytest |
| 34 | +from websockets.exceptions import ConnectionClosed |
33 | 35 |
|
34 | 36 | from ... import testing_utils |
35 | 37 |
|
@@ -490,8 +492,6 @@ async def call(self, **kwargs): |
490 | 492 | @pytest.mark.asyncio |
491 | 493 | async def test_run_live_reconnects_on_connection_closed(): |
492 | 494 | """Test that run_live reconnects when ConnectionClosed occurs.""" |
493 | | - from google.adk.agents.live_request_queue import LiveRequestQueue |
494 | | - from websockets.exceptions import ConnectionClosed |
495 | 495 |
|
496 | 496 | real_model = Gemini() |
497 | 497 | mock_connection = mock.AsyncMock() |
@@ -558,7 +558,6 @@ async def mock_receive_2(): |
558 | 558 | @pytest.mark.asyncio |
559 | 559 | async def test_run_live_reconnects_on_api_error(): |
560 | 560 | """Test that run_live reconnects when APIError occurs.""" |
561 | | - from google.adk.agents.live_request_queue import LiveRequestQueue |
562 | 561 | from google.genai.errors import APIError |
563 | 562 |
|
564 | 563 | real_model = Gemini() |
@@ -626,7 +625,6 @@ async def mock_receive_2(): |
626 | 625 | @pytest.mark.asyncio |
627 | 626 | async def test_run_live_skips_send_history_on_resumption(): |
628 | 627 | """Test that run_live skips send_history when resuming a session.""" |
629 | | - from google.adk.agents.live_request_queue import LiveRequestQueue |
630 | 628 |
|
631 | 629 | real_model = Gemini() |
632 | 630 | mock_connection = mock.AsyncMock() |
@@ -684,7 +682,6 @@ async def mock_receive(): |
684 | 682 | @pytest.mark.asyncio |
685 | 683 | async def test_live_session_resumption_go_away(): |
686 | 684 | """Test that go_away triggers reconnection.""" |
687 | | - from google.adk.agents.live_request_queue import LiveRequestQueue |
688 | 685 |
|
689 | 686 | real_model = Gemini() |
690 | 687 | mock_connection = mock.AsyncMock() |
@@ -743,8 +740,6 @@ async def mock_receive_2(): |
743 | 740 | @pytest.mark.asyncio |
744 | 741 | async def test_run_live_no_reconnect_without_handle(): |
745 | 742 | """Test that run_live does not reconnect when handle is missing.""" |
746 | | - from google.adk.agents.live_request_queue import LiveRequestQueue |
747 | | - from websockets.exceptions import ConnectionClosed |
748 | 743 |
|
749 | 744 | real_model = Gemini() |
750 | 745 | mock_connection = mock.AsyncMock() |
@@ -786,8 +781,6 @@ async def mock_receive(): |
786 | 781 | @pytest.mark.asyncio |
787 | 782 | async def test_run_live_reconnect_limit(): |
788 | 783 | """Test that run_live stops reconnecting after 5 attempts.""" |
789 | | - from google.adk.agents.live_request_queue import LiveRequestQueue |
790 | | - from websockets.exceptions import ConnectionClosed |
791 | 784 |
|
792 | 785 | real_model = Gemini() |
793 | 786 |
|
@@ -843,9 +836,7 @@ async def mock_receive(): |
843 | 836 | @pytest.mark.asyncio |
844 | 837 | async def test_run_live_reconnect_reset_attempt(): |
845 | 838 | """Test that attempt counter is reset on successful communication.""" |
846 | | - from google.adk.agents.live_request_queue import LiveRequestQueue |
847 | 839 | from google.adk.flows.llm_flows.base_llm_flow import DEFAULT_MAX_RECONNECT_ATTEMPTS |
848 | | - from websockets.exceptions import ConnectionClosed |
849 | 840 |
|
850 | 841 | real_model = Gemini() |
851 | 842 |
|
@@ -987,7 +978,6 @@ async def mock_receive(): |
987 | 978 | @pytest.mark.asyncio |
988 | 979 | async def test_run_live_clears_resumption_handle_on_transfer(): |
989 | 980 | """Test that run_live clears session resumption handles when transferring to another agent.""" |
990 | | - from google.adk.agents.live_request_queue import LiveRequestQueue |
991 | 981 |
|
992 | 982 | agent = Agent(name='test_agent') |
993 | 983 | invocation_context = await testing_utils.create_invocation_context( |
@@ -1129,3 +1119,145 @@ async def test_postprocess_async_yields_grounding_metadata_only(): |
1129 | 1119 |
|
1130 | 1120 | assert len(events) == 1 |
1131 | 1121 | assert events[0].grounding_metadata == grounding_metadata |
| 1122 | + |
| 1123 | + |
| 1124 | +@pytest.mark.asyncio |
| 1125 | +async def test_run_live_reconnect_does_not_set_transparent(): |
| 1126 | + """Test that run_live reconnect does not set transparent=True.""" |
| 1127 | + |
| 1128 | + real_model = Gemini() |
| 1129 | + mock_connection = mock.AsyncMock() |
| 1130 | + |
| 1131 | + async def mock_receive(): |
| 1132 | + yield LlmResponse( |
| 1133 | + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( |
| 1134 | + new_handle='test_handle' |
| 1135 | + ) |
| 1136 | + ) |
| 1137 | + raise ConnectionClosed(None, None) |
| 1138 | + |
| 1139 | + mock_connection.receive = mock.Mock(side_effect=mock_receive) |
| 1140 | + |
| 1141 | + agent = Agent(name='test_agent', model=real_model) |
| 1142 | + invocation_context = await testing_utils.create_invocation_context( |
| 1143 | + agent=agent |
| 1144 | + ) |
| 1145 | + invocation_context.live_request_queue = LiveRequestQueue() |
| 1146 | + invocation_context.run_config = RunConfig() |
| 1147 | + |
| 1148 | + flow = BaseLlmFlowForTesting() |
| 1149 | + |
| 1150 | + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): |
| 1151 | + |
| 1152 | + async def mock_preprocess(ctx, req): |
| 1153 | + req.live_connect_config.session_resumption = ( |
| 1154 | + ctx.run_config.session_resumption |
| 1155 | + ) |
| 1156 | + yield Event(id=Event.new_id(), author='test') |
| 1157 | + |
| 1158 | + with mock.patch.object( |
| 1159 | + flow, '_preprocess_async', side_effect=mock_preprocess |
| 1160 | + ): |
| 1161 | + mock_connection_2 = mock.AsyncMock() |
| 1162 | + |
| 1163 | + class StopTestError(Exception): |
| 1164 | + pass |
| 1165 | + |
| 1166 | + async def mock_receive_2(): |
| 1167 | + yield LlmResponse( |
| 1168 | + content=types.Content(parts=[types.Part.from_text(text='hi')]) |
| 1169 | + ) |
| 1170 | + raise StopTestError('stop') |
| 1171 | + |
| 1172 | + mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2) |
| 1173 | + |
| 1174 | + mock_aenter = mock.AsyncMock() |
| 1175 | + mock_aenter.side_effect = [mock_connection, mock_connection_2] |
| 1176 | + |
| 1177 | + with mock.patch( |
| 1178 | + 'google.adk.models.google_llm.Gemini.connect' |
| 1179 | + ) as mock_connect: |
| 1180 | + mock_connect.return_value.__aenter__ = mock_aenter |
| 1181 | + |
| 1182 | + try: |
| 1183 | + async for _ in flow.run_live(invocation_context): |
| 1184 | + pass |
| 1185 | + except StopTestError: |
| 1186 | + pass |
| 1187 | + |
| 1188 | + assert mock_connect.call_count == 2 |
| 1189 | + second_call_req = mock_connect.call_args_list[1][0][0] |
| 1190 | + session_resump = second_call_req.live_connect_config.session_resumption |
| 1191 | + assert session_resump.transparent is None |
| 1192 | + |
| 1193 | + |
| 1194 | +@pytest.mark.asyncio |
| 1195 | +async def test_run_live_reconnect_sets_transparent_for_vertex(): |
| 1196 | + """Test that run_live reconnect sets transparent=True for vertex backend.""" |
| 1197 | + |
| 1198 | + real_model = Gemini( |
| 1199 | + model='projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-flash-exp' |
| 1200 | + ) |
| 1201 | + mock_connection = mock.AsyncMock() |
| 1202 | + |
| 1203 | + async def mock_receive(): |
| 1204 | + yield LlmResponse( |
| 1205 | + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( |
| 1206 | + new_handle='test_handle' |
| 1207 | + ) |
| 1208 | + ) |
| 1209 | + raise ConnectionClosed(None, None) |
| 1210 | + |
| 1211 | + mock_connection.receive = mock.Mock(side_effect=mock_receive) |
| 1212 | + |
| 1213 | + agent = Agent(name='test_agent', model=real_model) |
| 1214 | + invocation_context = await testing_utils.create_invocation_context( |
| 1215 | + agent=agent |
| 1216 | + ) |
| 1217 | + invocation_context.live_request_queue = LiveRequestQueue() |
| 1218 | + invocation_context.run_config = RunConfig() |
| 1219 | + |
| 1220 | + flow = BaseLlmFlowForTesting() |
| 1221 | + |
| 1222 | + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): |
| 1223 | + |
| 1224 | + async def mock_preprocess(ctx, req): |
| 1225 | + req.live_connect_config.session_resumption = ( |
| 1226 | + ctx.run_config.session_resumption |
| 1227 | + ) |
| 1228 | + yield Event(id=Event.new_id(), author='test') |
| 1229 | + |
| 1230 | + with mock.patch.object( |
| 1231 | + flow, '_preprocess_async', side_effect=mock_preprocess |
| 1232 | + ): |
| 1233 | + mock_connection_2 = mock.AsyncMock() |
| 1234 | + |
| 1235 | + class StopTestError(Exception): |
| 1236 | + pass |
| 1237 | + |
| 1238 | + async def mock_receive_2(): |
| 1239 | + yield LlmResponse( |
| 1240 | + content=types.Content(parts=[types.Part.from_text(text='hi')]) |
| 1241 | + ) |
| 1242 | + raise StopTestError('stop') |
| 1243 | + |
| 1244 | + mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2) |
| 1245 | + |
| 1246 | + mock_aenter = mock.AsyncMock() |
| 1247 | + mock_aenter.side_effect = [mock_connection, mock_connection_2] |
| 1248 | + |
| 1249 | + with mock.patch( |
| 1250 | + 'google.adk.models.google_llm.Gemini.connect' |
| 1251 | + ) as mock_connect: |
| 1252 | + mock_connect.return_value.__aenter__ = mock_aenter |
| 1253 | + |
| 1254 | + try: |
| 1255 | + async for _ in flow.run_live(invocation_context): |
| 1256 | + pass |
| 1257 | + except StopTestError: |
| 1258 | + pass |
| 1259 | + |
| 1260 | + assert mock_connect.call_count == 2 |
| 1261 | + second_call_req = mock_connect.call_args_list[1][0][0] |
| 1262 | + session_resump = second_call_req.live_connect_config.session_resumption |
| 1263 | + assert session_resump.transparent |
0 commit comments