|
27 | 27 | from google.adk.models.llm_response import LlmResponse |
28 | 28 | from google.adk.utils.variant_utils import GoogleLLMVariant |
29 | 29 | from google.genai import types |
30 | | -from google.genai import version as genai_version |
31 | 30 | from google.genai.types import Content |
32 | 31 | from google.genai.types import Part |
33 | 32 | import pytest |
34 | 33 |
|
35 | 34 |
|
| 35 | +class MockAsyncIterator: |
| 36 | + """Mock for async iterator.""" |
| 37 | + |
| 38 | + def __init__(self, seq): |
| 39 | + self.iter = iter(seq) |
| 40 | + |
| 41 | + def __aiter__(self): |
| 42 | + return self |
| 43 | + |
| 44 | + async def __anext__(self): |
| 45 | + try: |
| 46 | + return next(self.iter) |
| 47 | + except StopIteration as exc: |
| 48 | + raise StopAsyncIteration from exc |
| 49 | + |
| 50 | + |
36 | 51 | @pytest.fixture |
37 | 52 | def generate_content_response(): |
38 | 53 | return types.GenerateContentResponse( |
@@ -215,21 +230,6 @@ async def mock_coro(): |
215 | 230 | @pytest.mark.asyncio |
216 | 231 | async def test_generate_content_async_stream(gemini_llm, llm_request): |
217 | 232 | with mock.patch.object(gemini_llm, "api_client") as mock_client: |
218 | | - # Create mock stream responses |
219 | | - class MockAsyncIterator: |
220 | | - |
221 | | - def __init__(self, seq): |
222 | | - self.iter = iter(seq) |
223 | | - |
224 | | - def __aiter__(self): |
225 | | - return self |
226 | | - |
227 | | - async def __anext__(self): |
228 | | - try: |
229 | | - return next(self.iter) |
230 | | - except StopIteration: |
231 | | - raise StopAsyncIteration |
232 | | - |
233 | 233 | mock_responses = [ |
234 | 234 | types.GenerateContentResponse( |
235 | 235 | candidates=[ |
@@ -292,21 +292,6 @@ async def test_generate_content_async_stream_preserves_thinking_and_text_parts( |
292 | 292 | gemini_llm, llm_request |
293 | 293 | ): |
294 | 294 | with mock.patch.object(gemini_llm, "api_client") as mock_client: |
295 | | - |
296 | | - class MockAsyncIterator: |
297 | | - |
298 | | - def __init__(self, seq): |
299 | | - self._iter = iter(seq) |
300 | | - |
301 | | - def __aiter__(self): |
302 | | - return self |
303 | | - |
304 | | - async def __anext__(self): |
305 | | - try: |
306 | | - return next(self._iter) |
307 | | - except StopIteration: |
308 | | - raise StopAsyncIteration |
309 | | - |
310 | 295 | response1 = types.GenerateContentResponse( |
311 | 296 | candidates=[ |
312 | 297 | types.Candidate( |
@@ -436,21 +421,6 @@ async def test_generate_content_async_stream_with_custom_headers( |
436 | 421 | llm_request.config.http_options = types.HttpOptions(headers=custom_headers) |
437 | 422 |
|
438 | 423 | with mock.patch.object(gemini_llm, "api_client") as mock_client: |
439 | | - # Create mock stream responses |
440 | | - class MockAsyncIterator: |
441 | | - |
442 | | - def __init__(self, seq): |
443 | | - self.iter = iter(seq) |
444 | | - |
445 | | - def __aiter__(self): |
446 | | - return self |
447 | | - |
448 | | - async def __anext__(self): |
449 | | - try: |
450 | | - return next(self.iter) |
451 | | - except StopIteration: |
452 | | - raise StopAsyncIteration |
453 | | - |
454 | 424 | mock_responses = [ |
455 | 425 | types.GenerateContentResponse( |
456 | 426 | candidates=[ |
@@ -488,35 +458,58 @@ async def mock_coro(): |
488 | 458 | assert len(responses) == 2 |
489 | 459 |
|
490 | 460 |
|
| 461 | +@pytest.mark.parametrize("stream", [True, False]) |
491 | 462 | @pytest.mark.asyncio |
492 | | -async def test_generate_content_async_without_custom_headers( |
493 | | - gemini_llm, llm_request, generate_content_response |
| 463 | +async def test_generate_content_async_patches_tracking_headers( |
| 464 | + stream, gemini_llm, llm_request, generate_content_response |
494 | 465 | ): |
495 | | - """Test that tracking headers are not modified when no custom headers exist.""" |
496 | | - # Ensure no http_options exist initially |
| 466 | + """Tests that tracking headers are added to the request config.""" |
| 467 | + # Set the request's config.http_options to None. |
497 | 468 | llm_request.config.http_options = None |
498 | 469 |
|
499 | 470 | with mock.patch.object(gemini_llm, "api_client") as mock_client: |
| 471 | + if stream: |
| 472 | + # Create a mock coroutine that returns the mock_responses. |
| 473 | + async def mock_coro(): |
| 474 | + return MockAsyncIterator([generate_content_response]) |
500 | 475 |
|
501 | | - async def mock_coro(): |
502 | | - return generate_content_response |
| 476 | + # Mock for streaming response. |
| 477 | + mock_client.aio.models.generate_content_stream.return_value = mock_coro() |
| 478 | + else: |
| 479 | + # Create a mock coroutine that returns the generate_content_response. |
| 480 | + async def mock_coro(): |
| 481 | + return generate_content_response |
503 | 482 |
|
504 | | - mock_client.aio.models.generate_content.return_value = mock_coro() |
| 483 | + # Mock for non-streaming response. |
| 484 | + mock_client.aio.models.generate_content.return_value = mock_coro() |
505 | 485 |
|
| 486 | + # Call the generate_content_async method. |
506 | 487 | responses = [ |
507 | 488 | resp |
508 | 489 | async for resp in gemini_llm.generate_content_async( |
509 | | - llm_request, stream=False |
| 490 | + llm_request, stream=stream |
510 | 491 | ) |
511 | 492 | ] |
512 | 493 |
|
513 | | - # Verify that the config passed to generate_content has no http_options |
514 | | - mock_client.aio.models.generate_content.assert_called_once() |
515 | | - call_args = mock_client.aio.models.generate_content.call_args |
516 | | - config_arg = call_args.kwargs["config"] |
517 | | - assert config_arg.http_options is None |
| 494 | + # Assert that the config passed to the generate_content or |
| 495 | + # generate_content_stream method contains the tracking headers. |
| 496 | + if stream: |
| 497 | + mock_client.aio.models.generate_content_stream.assert_called_once() |
| 498 | + call_args = mock_client.aio.models.generate_content_stream.call_args |
| 499 | + else: |
| 500 | + mock_client.aio.models.generate_content.assert_called_once() |
| 501 | + call_args = mock_client.aio.models.generate_content.call_args |
518 | 502 |
|
519 | | - assert len(responses) == 1 |
| 503 | + final_config = call_args.kwargs["config"] |
| 504 | + |
| 505 | + assert final_config is not None |
| 506 | + assert final_config.http_options is not None |
| 507 | + assert ( |
| 508 | + final_config.http_options.headers["x-goog-api-client"] |
| 509 | + == gemini_llm._tracking_headers["x-goog-api-client"] |
| 510 | + ) |
| 511 | + |
| 512 | + assert len(responses) == 2 if stream else 1 |
520 | 513 |
|
521 | 514 |
|
522 | 515 | def test_live_api_version_vertex_ai(gemini_llm): |
@@ -665,8 +658,7 @@ async def test_preprocess_request_handles_backend_specific_fields( |
665 | 658 | expected_inline_display_name: Optional[str], |
666 | 659 | expected_labels: Optional[str], |
667 | 660 | ): |
668 | | - """ |
669 | | - Tests that _preprocess_request correctly sanitizes fields based on the API backend. |
| 661 | + """Tests that _preprocess_request correctly sanitizes fields based on the API backend. |
670 | 662 |
|
671 | 663 | - For GEMINI_API, it should remove 'display_name' from file/inline data |
672 | 664 | and remove 'labels' from the config. |
@@ -732,21 +724,6 @@ async def test_generate_content_async_stream_aggregated_content_regardless_of_fi |
732 | 724 | ) |
733 | 725 |
|
734 | 726 | with mock.patch.object(gemini_llm, "api_client") as mock_client: |
735 | | - |
736 | | - class MockAsyncIterator: |
737 | | - |
738 | | - def __init__(self, seq): |
739 | | - self.iter = iter(seq) |
740 | | - |
741 | | - def __aiter__(self): |
742 | | - return self |
743 | | - |
744 | | - async def __anext__(self): |
745 | | - try: |
746 | | - return next(self.iter) |
747 | | - except StopIteration: |
748 | | - raise StopAsyncIteration |
749 | | - |
750 | 727 | # Test with different finish reasons |
751 | 728 | test_cases = [ |
752 | 729 | types.FinishReason.MAX_TOKENS, |
@@ -820,21 +797,6 @@ async def test_generate_content_async_stream_with_thought_and_text_error_handlin |
820 | 797 | ) |
821 | 798 |
|
822 | 799 | with mock.patch.object(gemini_llm, "api_client") as mock_client: |
823 | | - |
824 | | - class MockAsyncIterator: |
825 | | - |
826 | | - def __init__(self, seq): |
827 | | - self.iter = iter(seq) |
828 | | - |
829 | | - def __aiter__(self): |
830 | | - return self |
831 | | - |
832 | | - async def __anext__(self): |
833 | | - try: |
834 | | - return next(self.iter) |
835 | | - except StopIteration: |
836 | | - raise StopAsyncIteration |
837 | | - |
838 | 800 | mock_responses = [ |
839 | 801 | types.GenerateContentResponse( |
840 | 802 | candidates=[ |
@@ -902,21 +864,6 @@ async def test_generate_content_async_stream_error_info_none_for_stop_finish_rea |
902 | 864 | ) |
903 | 865 |
|
904 | 866 | with mock.patch.object(gemini_llm, "api_client") as mock_client: |
905 | | - |
906 | | - class MockAsyncIterator: |
907 | | - |
908 | | - def __init__(self, seq): |
909 | | - self.iter = iter(seq) |
910 | | - |
911 | | - def __aiter__(self): |
912 | | - return self |
913 | | - |
914 | | - async def __anext__(self): |
915 | | - try: |
916 | | - return next(self.iter) |
917 | | - except StopIteration: |
918 | | - raise StopAsyncIteration |
919 | | - |
920 | 867 | mock_responses = [ |
921 | 868 | types.GenerateContentResponse( |
922 | 869 | candidates=[ |
@@ -980,21 +927,6 @@ async def test_generate_content_async_stream_error_info_set_for_non_stop_finish_ |
980 | 927 | ) |
981 | 928 |
|
982 | 929 | with mock.patch.object(gemini_llm, "api_client") as mock_client: |
983 | | - |
984 | | - class MockAsyncIterator: |
985 | | - |
986 | | - def __init__(self, seq): |
987 | | - self.iter = iter(seq) |
988 | | - |
989 | | - def __aiter__(self): |
990 | | - return self |
991 | | - |
992 | | - async def __anext__(self): |
993 | | - try: |
994 | | - return next(self.iter) |
995 | | - except StopIteration: |
996 | | - raise StopAsyncIteration |
997 | | - |
998 | 930 | mock_responses = [ |
999 | 931 | types.GenerateContentResponse( |
1000 | 932 | candidates=[ |
@@ -1058,21 +990,6 @@ async def test_generate_content_async_stream_no_aggregated_content_without_text( |
1058 | 990 | ) |
1059 | 991 |
|
1060 | 992 | with mock.patch.object(gemini_llm, "api_client") as mock_client: |
1061 | | - |
1062 | | - class MockAsyncIterator: |
1063 | | - |
1064 | | - def __init__(self, seq): |
1065 | | - self.iter = iter(seq) |
1066 | | - |
1067 | | - def __aiter__(self): |
1068 | | - return self |
1069 | | - |
1070 | | - async def __anext__(self): |
1071 | | - try: |
1072 | | - return next(self.iter) |
1073 | | - except StopIteration: |
1074 | | - raise StopAsyncIteration |
1075 | | - |
1076 | 993 | # Mock response with no text content |
1077 | 994 | mock_responses = [ |
1078 | 995 | types.GenerateContentResponse( |
@@ -1127,21 +1044,6 @@ async def test_generate_content_async_stream_mixed_text_function_call_text(): |
1127 | 1044 | ) |
1128 | 1045 |
|
1129 | 1046 | with mock.patch.object(gemini_llm, "api_client") as mock_client: |
1130 | | - |
1131 | | - class MockAsyncIterator: |
1132 | | - |
1133 | | - def __init__(self, seq): |
1134 | | - self.iter = iter(seq) |
1135 | | - |
1136 | | - def __aiter__(self): |
1137 | | - return self |
1138 | | - |
1139 | | - async def __anext__(self): |
1140 | | - try: |
1141 | | - return next(self.iter) |
1142 | | - except StopIteration: |
1143 | | - raise StopAsyncIteration |
1144 | | - |
1145 | 1047 | # Create responses with pattern: text -> function_call -> text |
1146 | 1048 | mock_responses = [ |
1147 | 1049 | # First text chunk |
@@ -1247,21 +1149,6 @@ async def test_generate_content_async_stream_multiple_text_parts_in_single_respo |
1247 | 1149 | ) |
1248 | 1150 |
|
1249 | 1151 | with mock.patch.object(gemini_llm, "api_client") as mock_client: |
1250 | | - |
1251 | | - class MockAsyncIterator: |
1252 | | - |
1253 | | - def __init__(self, seq): |
1254 | | - self.iter = iter(seq) |
1255 | | - |
1256 | | - def __aiter__(self): |
1257 | | - return self |
1258 | | - |
1259 | | - async def __anext__(self): |
1260 | | - try: |
1261 | | - return next(self.iter) |
1262 | | - except StopIteration: |
1263 | | - raise StopAsyncIteration |
1264 | | - |
1265 | 1152 | # Create a response with multiple text parts |
1266 | 1153 | mock_responses = [ |
1267 | 1154 | types.GenerateContentResponse( |
@@ -1314,21 +1201,6 @@ async def test_generate_content_async_stream_complex_mixed_thought_text_function |
1314 | 1201 | ) |
1315 | 1202 |
|
1316 | 1203 | with mock.patch.object(gemini_llm, "api_client") as mock_client: |
1317 | | - |
1318 | | - class MockAsyncIterator: |
1319 | | - |
1320 | | - def __init__(self, seq): |
1321 | | - self.iter = iter(seq) |
1322 | | - |
1323 | | - def __aiter__(self): |
1324 | | - return self |
1325 | | - |
1326 | | - async def __anext__(self): |
1327 | | - try: |
1328 | | - return next(self.iter) |
1329 | | - except StopIteration: |
1330 | | - raise StopAsyncIteration |
1331 | | - |
1332 | 1204 | # Complex pattern: thought -> text -> function_call -> thought -> text |
1333 | 1205 | mock_responses = [ |
1334 | 1206 | # Thought |
@@ -1450,21 +1322,6 @@ async def test_generate_content_async_stream_two_separate_text_aggregations(): |
1450 | 1322 | ) |
1451 | 1323 |
|
1452 | 1324 | with mock.patch.object(gemini_llm, "api_client") as mock_client: |
1453 | | - |
1454 | | - class MockAsyncIterator: |
1455 | | - |
1456 | | - def __init__(self, seq): |
1457 | | - self.iter = iter(seq) |
1458 | | - |
1459 | | - def __aiter__(self): |
1460 | | - return self |
1461 | | - |
1462 | | - async def __anext__(self): |
1463 | | - try: |
1464 | | - return next(self.iter) |
1465 | | - except StopIteration: |
1466 | | - raise StopAsyncIteration |
1467 | | - |
1468 | 1325 | # Create responses: multiple text chunks -> function_call -> multiple text chunks |
1469 | 1326 | mock_responses = [ |
1470 | 1327 | # First text accumulation (multiple chunks) |
|
0 commit comments