|
16 | 16 |
|
17 | 17 | import base64 |
18 | 18 | import json |
| 19 | +from unittest.mock import AsyncMock |
19 | 20 | from unittest.mock import MagicMock |
20 | 21 |
|
21 | 22 | from google.adk.models import interactions_utils |
22 | 23 | from google.adk.models.llm_request import LlmRequest |
23 | 24 | from google.genai import types |
| 25 | +from google.genai._interactions.types import ContentDelta |
| 26 | +from google.genai._interactions.types import ContentStop |
| 27 | +from google.genai._interactions.types import Interaction |
| 28 | +from google.genai._interactions.types import InteractionCompleteEvent |
| 29 | +from google.genai._interactions.types import InteractionStartEvent |
| 30 | +from google.genai._interactions.types import InteractionStatusUpdate |
| 31 | +from google.genai._interactions.types.content_delta import DeltaFunctionCall |
| 32 | +import pytest |
| 33 | + |
| 34 | + |
| 35 | +class _MockAsyncIterator: |
| 36 | + """Simple async iterator for streaming test events.""" |
| 37 | + |
| 38 | + def __init__(self, values): |
| 39 | + self._iterator = iter(values) |
| 40 | + |
| 41 | + def __aiter__(self): |
| 42 | + return self |
| 43 | + |
| 44 | + async def __anext__(self): |
| 45 | + try: |
| 46 | + return next(self._iterator) |
| 47 | + except StopIteration as exc: |
| 48 | + raise StopAsyncIteration from exc |
24 | 49 |
|
25 | 50 |
|
26 | 51 | class TestConvertPartToInteractionContent: |
@@ -955,3 +980,100 @@ def test_unknown_event_type_returns_none(self): |
955 | 980 |
|
956 | 981 | assert result is None |
957 | 982 | assert not aggregated_parts |
| 983 | + |
| 984 | + def test_interaction_complete_event(self): |
| 985 | + """Test converting an interaction.complete event.""" |
| 986 | + interaction = Interaction( |
| 987 | + id='int_complete', |
| 988 | + created='2026-04-07T00:00:00Z', |
| 989 | + updated='2026-04-07T00:00:01Z', |
| 990 | + status='completed', |
| 991 | + outputs=[{'type': 'text', 'text': 'Done'}], |
| 992 | + ) |
| 993 | + event = InteractionCompleteEvent( |
| 994 | + event_type='interaction.complete', |
| 995 | + interaction=interaction, |
| 996 | + ) |
| 997 | + |
| 998 | + result = interactions_utils.convert_interaction_event_to_llm_response( |
| 999 | + event, aggregated_parts=[] |
| 1000 | + ) |
| 1001 | + |
| 1002 | + assert result is not None |
| 1003 | + assert result.interaction_id == 'int_complete' |
| 1004 | + assert result.content.parts[0].text == 'Done' |
| 1005 | + assert result.turn_complete is True |
| 1006 | + |
| 1007 | + |
| 1008 | +class TestGenerateContentViaInteractions: |
| 1009 | + """Tests for generate_content_via_interactions.""" |
| 1010 | + |
| 1011 | + @pytest.mark.asyncio |
| 1012 | + async def test_stream_uses_interaction_start_id_for_function_calls(self): |
| 1013 | + """Test that streaming function calls retain the interaction chain ID.""" |
| 1014 | + interaction = Interaction( |
| 1015 | + id='int_stream_123', |
| 1016 | + created='2026-04-07T00:00:00Z', |
| 1017 | + updated='2026-04-07T00:00:01Z', |
| 1018 | + status='requires_action', |
| 1019 | + ) |
| 1020 | + stream_events = [ |
| 1021 | + InteractionStartEvent( |
| 1022 | + event_type='interaction.start', |
| 1023 | + interaction=interaction, |
| 1024 | + ), |
| 1025 | + ContentDelta( |
| 1026 | + event_type='content.delta', |
| 1027 | + index=0, |
| 1028 | + delta=DeltaFunctionCall( |
| 1029 | + type='function_call', |
| 1030 | + id='fc_123', |
| 1031 | + name='get_weather', |
| 1032 | + arguments={'city': 'Tokyo'}, |
| 1033 | + ), |
| 1034 | + ), |
| 1035 | + ContentStop(event_type='content.stop', index=0), |
| 1036 | + InteractionStatusUpdate( |
| 1037 | + event_type='interaction.status_update', |
| 1038 | + interaction_id='int_stream_123', |
| 1039 | + status='requires_action', |
| 1040 | + ), |
| 1041 | + ] |
| 1042 | + api_client = MagicMock() |
| 1043 | + api_client.aio.interactions.create = AsyncMock( |
| 1044 | + return_value=_MockAsyncIterator(stream_events) |
| 1045 | + ) |
| 1046 | + llm_request = LlmRequest( |
| 1047 | + model='gemini-2.5-flash', |
| 1048 | + contents=[ |
| 1049 | + types.Content( |
| 1050 | + role='user', |
| 1051 | + parts=[types.Part.from_text(text='Weather in Tokyo?')], |
| 1052 | + ) |
| 1053 | + ], |
| 1054 | + config=types.GenerateContentConfig(), |
| 1055 | + ) |
| 1056 | + |
| 1057 | + responses = [ |
| 1058 | + response |
| 1059 | + async for response in ( |
| 1060 | + interactions_utils.generate_content_via_interactions( |
| 1061 | + api_client=api_client, |
| 1062 | + llm_request=llm_request, |
| 1063 | + stream=True, |
| 1064 | + ) |
| 1065 | + ) |
| 1066 | + ] |
| 1067 | + |
| 1068 | + function_call_response = next( |
| 1069 | + response |
| 1070 | + for response in responses |
| 1071 | + if response.content |
| 1072 | + and response.content.parts |
| 1073 | + and response.content.parts[0].function_call |
| 1074 | + ) |
| 1075 | + |
| 1076 | + assert function_call_response.interaction_id == 'int_stream_123' |
| 1077 | + assert function_call_response.content.parts[0].function_call.name == ( |
| 1078 | + 'get_weather' |
| 1079 | + ) |
0 commit comments