|
17 | 17 | from unittest import mock |
18 | 18 | from unittest.mock import AsyncMock |
19 | 19 |
|
| 20 | +from google.adk.agents.live_request_queue import LiveRequestQueue |
20 | 21 | from google.adk.agents.llm_agent import Agent |
21 | 22 | from google.adk.agents.run_config import RunConfig |
22 | 23 | from google.adk.events.event import Event |
23 | 24 | from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback |
24 | 25 | from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow |
25 | | -from google.adk.models.google_llm import Gemini, GoogleLLMVariant |
| 26 | +from google.adk.models.google_llm import Gemini |
| 27 | +from google.adk.models.google_llm import GoogleLLMVariant |
26 | 28 | from google.adk.models.llm_request import LlmRequest |
27 | 29 | from google.adk.models.llm_response import LlmResponse |
28 | 30 | from google.adk.plugins.base_plugin import BasePlugin |
29 | 31 | from google.adk.tools.base_toolset import BaseToolset |
30 | 32 | from google.adk.tools.google_search_tool import GoogleSearchTool |
31 | 33 | from google.genai import types |
32 | 34 | import pytest |
| 35 | +from websockets.exceptions import ConnectionClosed |
33 | 36 |
|
34 | 37 | from ... import testing_utils |
35 | 38 |
|
@@ -490,8 +493,6 @@ async def call(self, **kwargs): |
490 | 493 | @pytest.mark.asyncio |
491 | 494 | async def test_run_live_reconnects_on_connection_closed(): |
492 | 495 | """Test that run_live reconnects when ConnectionClosed occurs.""" |
493 | | - from google.adk.agents.live_request_queue import LiveRequestQueue |
494 | | - from websockets.exceptions import ConnectionClosed |
495 | 496 |
|
496 | 497 | real_model = Gemini() |
497 | 498 | mock_connection = mock.AsyncMock() |
@@ -558,7 +559,6 @@ async def mock_receive_2(): |
558 | 559 | @pytest.mark.asyncio |
559 | 560 | async def test_run_live_reconnects_on_api_error(): |
560 | 561 | """Test that run_live reconnects when APIError occurs.""" |
561 | | - from google.adk.agents.live_request_queue import LiveRequestQueue |
562 | 562 | from google.genai.errors import APIError |
563 | 563 |
|
564 | 564 | real_model = Gemini() |
@@ -626,7 +626,6 @@ async def mock_receive_2(): |
626 | 626 | @pytest.mark.asyncio |
627 | 627 | async def test_run_live_skips_send_history_on_resumption(): |
628 | 628 | """Test that run_live skips send_history when resuming a session.""" |
629 | | - from google.adk.agents.live_request_queue import LiveRequestQueue |
630 | 629 |
|
631 | 630 | real_model = Gemini() |
632 | 631 | mock_connection = mock.AsyncMock() |
@@ -684,7 +683,6 @@ async def mock_receive(): |
684 | 683 | @pytest.mark.asyncio |
685 | 684 | async def test_live_session_resumption_go_away(): |
686 | 685 | """Test that go_away triggers reconnection.""" |
687 | | - from google.adk.agents.live_request_queue import LiveRequestQueue |
688 | 686 |
|
689 | 687 | real_model = Gemini() |
690 | 688 | mock_connection = mock.AsyncMock() |
@@ -743,8 +741,6 @@ async def mock_receive_2(): |
743 | 741 | @pytest.mark.asyncio |
744 | 742 | async def test_run_live_no_reconnect_without_handle(): |
745 | 743 | """Test that run_live does not reconnect when handle is missing.""" |
746 | | - from google.adk.agents.live_request_queue import LiveRequestQueue |
747 | | - from websockets.exceptions import ConnectionClosed |
748 | 744 |
|
749 | 745 | real_model = Gemini() |
750 | 746 | mock_connection = mock.AsyncMock() |
@@ -786,8 +782,6 @@ async def mock_receive(): |
786 | 782 | @pytest.mark.asyncio |
787 | 783 | async def test_run_live_reconnect_limit(): |
788 | 784 | """Test that run_live stops reconnecting after 5 attempts.""" |
789 | | - from google.adk.agents.live_request_queue import LiveRequestQueue |
790 | | - from websockets.exceptions import ConnectionClosed |
791 | 785 |
|
792 | 786 | real_model = Gemini() |
793 | 787 |
|
@@ -843,9 +837,7 @@ async def mock_receive(): |
843 | 837 | @pytest.mark.asyncio |
844 | 838 | async def test_run_live_reconnect_reset_attempt(): |
845 | 839 | """Test that attempt counter is reset on successful communication.""" |
846 | | - from google.adk.agents.live_request_queue import LiveRequestQueue |
847 | 840 | from google.adk.flows.llm_flows.base_llm_flow import DEFAULT_MAX_RECONNECT_ATTEMPTS |
848 | | - from websockets.exceptions import ConnectionClosed |
849 | 841 |
|
850 | 842 | real_model = Gemini() |
851 | 843 |
|
@@ -987,7 +979,6 @@ async def mock_receive(): |
987 | 979 | @pytest.mark.asyncio |
988 | 980 | async def test_run_live_clears_resumption_handle_on_transfer(): |
989 | 981 | """Test that run_live clears session resumption handles when transferring to another agent.""" |
990 | | - from google.adk.agents.live_request_queue import LiveRequestQueue |
991 | 982 |
|
992 | 983 | agent = Agent(name='test_agent') |
993 | 984 | invocation_context = await testing_utils.create_invocation_context( |
@@ -1184,21 +1175,27 @@ async def mock_receive_2(): |
1184 | 1175 | mock_aenter = mock.AsyncMock() |
1185 | 1176 | mock_aenter.side_effect = [mock_connection, mock_connection_2] |
1186 | 1177 |
|
1187 | | - with mock.patch( |
1188 | | - 'google.adk.models.google_llm.Gemini.connect' |
1189 | | - ) as mock_connect: |
1190 | | - mock_connect.return_value.__aenter__ = mock_aenter |
| 1178 | + with mock.patch.object( |
| 1179 | + Gemini, '_api_backend', new_callable=mock.PropertyMock |
| 1180 | + ) as mock_backend: |
| 1181 | + mock_backend.return_value = GoogleLLMVariant.GEMINI_API |
| 1182 | + with mock.patch( |
| 1183 | + 'google.adk.models.google_llm.Gemini.connect' |
| 1184 | + ) as mock_connect: |
| 1185 | + mock_connect.return_value.__aenter__ = mock_aenter |
1191 | 1186 |
|
1192 | | - try: |
1193 | | - async for _ in flow.run_live(invocation_context): |
| 1187 | + try: |
| 1188 | + async for _ in flow.run_live(invocation_context): |
| 1189 | + pass |
| 1190 | + except StopTestError: |
1194 | 1191 | pass |
1195 | | - except StopTestError: |
1196 | | - pass |
1197 | 1192 |
|
1198 | | - assert mock_connect.call_count == 2 |
1199 | | - second_call_req = mock_connect.call_args_list[1][0][0] |
1200 | | - session_resump = second_call_req.live_connect_config.session_resumption |
1201 | | - assert session_resump.transparent is None |
| 1193 | + assert mock_connect.call_count == 2 |
| 1194 | + second_call_req = mock_connect.call_args_list[1][0][0] |
| 1195 | + session_resump = ( |
| 1196 | + second_call_req.live_connect_config.session_resumption |
| 1197 | + ) |
| 1198 | + assert session_resump.transparent is None |
1202 | 1199 |
|
1203 | 1200 |
|
1204 | 1201 | @pytest.mark.asyncio |
@@ -1275,7 +1272,7 @@ async def mock_receive_2(): |
1275 | 1272 |
|
1276 | 1273 | @pytest.mark.asyncio |
1277 | 1274 | @pytest.mark.parametrize( |
1278 | | - "api_backend,should_have_history_config", |
| 1275 | + 'api_backend,should_have_history_config', |
1279 | 1276 | [ |
1280 | 1277 | (GoogleLLMVariant.GEMINI_API, True), |
1281 | 1278 | (GoogleLLMVariant.VERTEX_AI, False), |
@@ -1309,8 +1306,12 @@ async def mock_receive(): |
1309 | 1306 | flow = BaseLlmFlowForTesting() |
1310 | 1307 |
|
1311 | 1308 | with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): |
| 1309 | + |
1312 | 1310 | async def mock_preprocess(ctx, req): |
1313 | | - req.contents = [types.Content(parts=[types.Part.from_text(text='history')])] |
| 1311 | + req.model = 'gemini-3.1-flash-live-preview' |
| 1312 | + req.contents = [ |
| 1313 | + types.Content(parts=[types.Part.from_text(text='history')]) |
| 1314 | + ] |
1314 | 1315 | yield Event(id=Event.new_id(), author='test') |
1315 | 1316 |
|
1316 | 1317 | with mock.patch.object( |
|
0 commit comments