diff --git a/src/google/adk/models/interactions_utils.py b/src/google/adk/models/interactions_utils.py index add8da0c54..bac328f047 100644 --- a/src/google/adk/models/interactions_utils.py +++ b/src/google/adk/models/interactions_utils.py @@ -56,6 +56,41 @@ _NEW_LINE = '\n' +def _extract_event_id_from_interaction_event( + event: 'InteractionSSEEvent', +) -> Optional[str]: + """Extract the SDK event identifier from an interactions SSE event.""" + event_id = getattr(event, 'event_id', None) + if isinstance(event_id, str): + return event_id + + legacy_event_id = getattr(event, 'id', None) + if isinstance(legacy_event_id, str): + return legacy_event_id + + return None + + +def _extract_interaction_id_from_event( + event: 'InteractionSSEEvent', +) -> Optional[str]: + """Extract the interaction chain identifier from an SSE event.""" + interaction = getattr(event, 'interaction', None) + interaction_id = getattr(interaction, 'id', None) + if isinstance(interaction_id, str): + return interaction_id + + event_interaction_id = getattr(event, 'interaction_id', None) + if isinstance(event_interaction_id, str): + return event_interaction_id + + legacy_interaction_id = getattr(event, 'id', None) + if isinstance(legacy_interaction_id, str): + return legacy_interaction_id + + return None + + def convert_part_to_interaction_content(part: types.Part) -> Optional[dict]: """Convert a types.Part to an interaction content dict. @@ -154,12 +189,12 @@ def convert_part_to_interaction_content(part: types.Part) -> Optional[dict]: elif part.thought: # part.thought is a boolean indicating this is a thought part # ThoughtContentParam expects 'signature' (base64 encoded bytes) - result: dict[str, Any] = {'type': 'thought'} + thought_content: dict[str, Any] = {'type': 'thought'} if part.thought_signature is not None: - result['signature'] = base64.b64encode(part.thought_signature).decode( - 'utf-8' - ) - return result + thought_content['signature'] = base64.b64encode( + part.thought_signature + ).decode('utf-8') + return thought_content elif part.code_execution_result is not None: is_error = part.code_execution_result.outcome in ( types.Outcome.OUTCOME_FAILED, @@ -487,6 +522,7 @@ def convert_interaction_event_to_llm_response( from .llm_response import LlmResponse event_type = getattr(event, 'event_type', None) + interaction_id = interaction_id or _extract_interaction_id_from_event(event) if event_type == 'content.delta': delta = event.delta @@ -565,9 +601,10 @@ def convert_interaction_event_to_llm_response( interaction_id=interaction_id, ) - elif event_type == 'interaction': - # Final interaction event with complete data - return convert_interaction_to_llm_response(event) + elif event_type in ('interaction.complete', 'interaction'): + # Final interaction event with complete data. + interaction = getattr(event, 'interaction', event) + return convert_interaction_to_llm_response(interaction) elif event_type == 'interaction.status_update': status = getattr(event, 'status', None) @@ -834,7 +871,7 @@ def build_interactions_event_log(event: InteractionSSEEvent) -> str: A formatted log string describing the event. """ event_type = getattr(event, 'event_type', 'unknown') - event_id = getattr(event, 'id', None) + event_id = _extract_event_id_from_interaction_event(event) details = [] @@ -1014,8 +1051,9 @@ async def generate_content_via_interactions( logger.debug(build_interactions_event_log(event)) # Extract interaction ID from event if available - if hasattr(event, 'id') and event.id: - current_interaction_id = event.id + current_interaction_id = ( + _extract_interaction_id_from_event(event) or current_interaction_id + ) llm_response = convert_interaction_event_to_llm_response( event, aggregated_parts, current_interaction_id ) diff --git a/tests/unittests/models/test_interactions_utils.py b/tests/unittests/models/test_interactions_utils.py index 93dced0f21..a6cb5dd077 100644 --- a/tests/unittests/models/test_interactions_utils.py +++ b/tests/unittests/models/test_interactions_utils.py @@ -16,11 +16,36 @@ import base64 import json +from unittest.mock import AsyncMock from unittest.mock import MagicMock from google.adk.models import interactions_utils from google.adk.models.llm_request import LlmRequest from google.genai import types +from google.genai._interactions.types import ContentDelta +from google.genai._interactions.types import ContentStop +from google.genai._interactions.types import Interaction +from google.genai._interactions.types import InteractionCompleteEvent +from google.genai._interactions.types import InteractionStartEvent +from google.genai._interactions.types import InteractionStatusUpdate +from google.genai._interactions.types.content_delta import DeltaFunctionCall +import pytest + + +class _MockAsyncIterator: + """Simple async iterator for streaming test events.""" + + def __init__(self, values): + self._iterator = iter(values) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._iterator) + except StopIteration as exc: + raise StopAsyncIteration from exc class TestConvertPartToInteractionContent: @@ -955,3 +980,100 @@ def test_unknown_event_type_returns_none(self): assert result is None assert not aggregated_parts + + def test_interaction_complete_event(self): + """Test converting an interaction.complete event.""" + interaction = Interaction( + id='int_complete', + created='2026-04-07T00:00:00Z', + updated='2026-04-07T00:00:01Z', + status='completed', + outputs=[{'type': 'text', 'text': 'Done'}], + ) + event = InteractionCompleteEvent( + event_type='interaction.complete', + interaction=interaction, + ) + + result = interactions_utils.convert_interaction_event_to_llm_response( + event, aggregated_parts=[] + ) + + assert result is not None + assert result.interaction_id == 'int_complete' + assert result.content.parts[0].text == 'Done' + assert result.turn_complete is True + + +class TestGenerateContentViaInteractions: + """Tests for generate_content_via_interactions.""" + + @pytest.mark.asyncio + async def test_stream_uses_interaction_start_id_for_function_calls(self): + """Test that streaming function calls retain the interaction chain ID.""" + interaction = Interaction( + id='int_stream_123', + created='2026-04-07T00:00:00Z', + updated='2026-04-07T00:00:01Z', + status='requires_action', + ) + stream_events = [ + InteractionStartEvent( + event_type='interaction.start', + interaction=interaction, + ), + ContentDelta( + event_type='content.delta', + index=0, + delta=DeltaFunctionCall( + type='function_call', + id='fc_123', + name='get_weather', + arguments={'city': 'Tokyo'}, + ), + ), + ContentStop(event_type='content.stop', index=0), + InteractionStatusUpdate( + event_type='interaction.status_update', + interaction_id='int_stream_123', + status='requires_action', + ), + ] + api_client = MagicMock() + api_client.aio.interactions.create = AsyncMock( + return_value=_MockAsyncIterator(stream_events) + ) + llm_request = LlmRequest( + model='gemini-2.5-flash', + contents=[ + types.Content( + role='user', + parts=[types.Part.from_text(text='Weather in Tokyo?')], + ) + ], + config=types.GenerateContentConfig(), + ) + + responses = [ + response + async for response in ( + interactions_utils.generate_content_via_interactions( + api_client=api_client, + llm_request=llm_request, + stream=True, + ) + ) + ] + + function_call_response = next( + response + for response in responses + if response.content + and response.content.parts + and response.content.parts[0].function_call + ) + + assert function_call_response.interaction_id == 'int_stream_123' + assert function_call_response.content.parts[0].function_call.name == ( + 'get_weather' + )