|
30 | 30 | from google.adk.plugins.base_plugin import BasePlugin |
31 | 31 | from google.adk.tools.base_toolset import BaseToolset |
32 | 32 | from google.adk.tools.google_search_tool import GoogleSearchTool |
| 33 | +from google.adk.utils.variant_utils import GoogleLLMVariant |
33 | 34 | from google.genai import types |
34 | 35 | import pytest |
35 | 36 | from websockets.exceptions import ConnectionClosed |
@@ -1386,3 +1387,141 @@ async def mock_receive_2(): |
1386 | 1387 | second_call_req = mock_connect.call_args_list[1][0][0] |
1387 | 1388 | session_resump = second_call_req.live_connect_config.session_resumption |
1388 | 1389 | assert session_resump.transparent |
| 1390 | + |
| 1391 | + |
| 1392 | +@pytest.mark.asyncio |
| 1393 | +@pytest.mark.parametrize( |
| 1394 | + 'api_backend', |
| 1395 | + [ |
| 1396 | + GoogleLLMVariant.GEMINI_API, |
| 1397 | + GoogleLLMVariant.VERTEX_AI, |
| 1398 | + ], |
| 1399 | +) |
| 1400 | +async def test_run_live_history_config_set_for_all_backends(api_backend): |
| 1401 | + """Test that run_live sets history_config for all backends.""" |
| 1402 | + |
| 1403 | + real_model = Gemini(model='gemini-3.1-flash-live-preview') |
| 1404 | + mock_connection = mock.AsyncMock() |
| 1405 | + |
| 1406 | + agent = Agent(name='test_agent', model=real_model) |
| 1407 | + invocation_context = await testing_utils.create_invocation_context( |
| 1408 | + agent=agent |
| 1409 | + ) |
| 1410 | + invocation_context.live_request_queue = LiveRequestQueue() |
| 1411 | + invocation_context.run_config = RunConfig() |
| 1412 | + |
| 1413 | + flow = BaseLlmFlowForTesting() |
| 1414 | + |
| 1415 | + async def mock_preprocess(ctx, req): |
| 1416 | + req.contents = [types.Content(parts=[types.Part.from_text(text='history')])] |
| 1417 | + from google.adk.flows.llm_flows.basic import _build_basic_request |
| 1418 | + |
| 1419 | + _build_basic_request(ctx, req) |
| 1420 | + yield Event(id=Event.new_id(), author='test') |
| 1421 | + |
| 1422 | + with mock.patch.object( |
| 1423 | + flow, '_preprocess_async', side_effect=mock_preprocess |
| 1424 | + ): |
| 1425 | + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): |
| 1426 | + |
| 1427 | + class StopTestError(Exception): |
| 1428 | + pass |
| 1429 | + |
| 1430 | + async def mock_receive(): |
| 1431 | + yield LlmResponse( |
| 1432 | + content=types.Content(parts=[types.Part.from_text(text='hi')]) |
| 1433 | + ) |
| 1434 | + raise StopTestError('stop') |
| 1435 | + |
| 1436 | + mock_connection.receive = mock.Mock(side_effect=mock_receive) |
| 1437 | + |
| 1438 | + with mock.patch( |
| 1439 | + 'google.adk.models.google_llm.Gemini.connect' |
| 1440 | + ) as mock_connect: |
| 1441 | + mock_connect.return_value.__aenter__.return_value = mock_connection |
| 1442 | + |
| 1443 | + # Mock the api_backend property |
| 1444 | + with mock.patch.object( |
| 1445 | + Gemini, |
| 1446 | + '_api_backend', |
| 1447 | + new_callable=mock.PropertyMock, |
| 1448 | + return_value=api_backend, |
| 1449 | + ): |
| 1450 | + try: |
| 1451 | + async for _ in flow.run_live(invocation_context): |
| 1452 | + pass |
| 1453 | + except StopTestError: |
| 1454 | + pass |
| 1455 | + |
| 1456 | + assert mock_connect.call_count == 1 |
| 1457 | + called_req = mock_connect.call_args[0][0] |
| 1458 | + assert called_req.live_connect_config is not None |
| 1459 | + assert called_req.live_connect_config.history_config is not None |
| 1460 | + assert ( |
| 1461 | + called_req.live_connect_config.history_config.initial_history_in_client_content |
| 1462 | + is True |
| 1463 | + ) |
| 1464 | + |
| 1465 | + |
| 1466 | +@pytest.mark.asyncio |
| 1467 | +async def test_run_live_respects_explicit_initial_history_in_client_content_false(): |
| 1468 | + """Test that run_live respects explicit initial_history_in_client_content=False in RunConfig.""" |
| 1469 | + |
| 1470 | + real_model = Gemini() |
| 1471 | + mock_connection = mock.AsyncMock() |
| 1472 | + |
| 1473 | + agent = Agent(name='test_agent', model=real_model) |
| 1474 | + invocation_context = await testing_utils.create_invocation_context( |
| 1475 | + agent=agent |
| 1476 | + ) |
| 1477 | + invocation_context.live_request_queue = LiveRequestQueue() |
| 1478 | + run_config = RunConfig( |
| 1479 | + history_config=types.HistoryConfig( |
| 1480 | + initial_history_in_client_content=False |
| 1481 | + ) |
| 1482 | + ) |
| 1483 | + invocation_context.run_config = run_config |
| 1484 | + |
| 1485 | + flow = BaseLlmFlowForTesting() |
| 1486 | + |
| 1487 | + async def mock_preprocess(ctx, req): |
| 1488 | + req.contents = [types.Content(parts=[types.Part.from_text(text='history')])] |
| 1489 | + from google.adk.flows.llm_flows.basic import _build_basic_request |
| 1490 | + |
| 1491 | + _build_basic_request(ctx, req) |
| 1492 | + yield Event(id=Event.new_id(), author='test') |
| 1493 | + |
| 1494 | + with mock.patch.object( |
| 1495 | + flow, '_preprocess_async', side_effect=mock_preprocess |
| 1496 | + ): |
| 1497 | + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): |
| 1498 | + |
| 1499 | + class StopTestError(Exception): |
| 1500 | + pass |
| 1501 | + |
| 1502 | + async def mock_receive(): |
| 1503 | + yield LlmResponse( |
| 1504 | + content=types.Content(parts=[types.Part.from_text(text='hi')]) |
| 1505 | + ) |
| 1506 | + raise StopTestError('stop') |
| 1507 | + |
| 1508 | + mock_connection.receive = mock.Mock(side_effect=mock_receive) |
| 1509 | + |
| 1510 | + with mock.patch( |
| 1511 | + 'google.adk.models.google_llm.Gemini.connect' |
| 1512 | + ) as mock_connect: |
| 1513 | + mock_connect.return_value.__aenter__.return_value = mock_connection |
| 1514 | + |
| 1515 | + try: |
| 1516 | + async for _ in flow.run_live(invocation_context): |
| 1517 | + pass |
| 1518 | + except StopTestError: |
| 1519 | + pass |
| 1520 | + |
| 1521 | + assert mock_connect.call_count == 1 |
| 1522 | + call_req = mock_connect.call_args[0][0] |
| 1523 | + assert call_req.live_connect_config.history_config is not None |
| 1524 | + assert ( |
| 1525 | + call_req.live_connect_config.history_config.initial_history_in_client_content |
| 1526 | + is False |
| 1527 | + ) |
0 commit comments