Skip to content

Commit 0a24cb7

Browse files
committed
fix(models): preserve interactions SSE ids for function calls
Fixes #5169 by reading interaction IDs from the actual GenAI SDK SSE event fields and carrying them through streaming function-call responses. Also adds regression coverage for interaction.complete events and streaming interaction.start chaining.
1 parent bbad9ec commit 0a24cb7

2 files changed

Lines changed: 154 additions & 6 deletions

File tree

src/google/adk/models/interactions_utils.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,29 @@
5656
_NEW_LINE = '\n'
5757

5858

59+
def _extract_event_id_from_interaction_event(
60+
event: 'InteractionSSEEvent',
61+
) -> Optional[str]:
62+
"""Extract the SDK event identifier from an interactions SSE event."""
63+
return getattr(event, 'event_id', None) or getattr(event, 'id', None)
64+
65+
66+
def _extract_interaction_id_from_event(
67+
event: 'InteractionSSEEvent',
68+
) -> Optional[str]:
69+
"""Extract the interaction chain identifier from an SSE event."""
70+
interaction = getattr(event, 'interaction', None)
71+
if interaction and getattr(interaction, 'id', None):
72+
return interaction.id
73+
74+
interaction_id = getattr(event, 'interaction_id', None)
75+
if interaction_id:
76+
return interaction_id
77+
78+
# Fall back to legacy field names when older SDK shapes are in play.
79+
return getattr(event, 'id', None)
80+
81+
5982
def convert_part_to_interaction_content(part: types.Part) -> Optional[dict]:
6083
"""Convert a types.Part to an interaction content dict.
6184
@@ -487,6 +510,7 @@ def convert_interaction_event_to_llm_response(
487510
from .llm_response import LlmResponse
488511

489512
event_type = getattr(event, 'event_type', None)
513+
interaction_id = interaction_id or _extract_interaction_id_from_event(event)
490514

491515
if event_type == 'content.delta':
492516
delta = event.delta
@@ -565,9 +589,10 @@ def convert_interaction_event_to_llm_response(
565589
interaction_id=interaction_id,
566590
)
567591

568-
elif event_type == 'interaction':
569-
# Final interaction event with complete data
570-
return convert_interaction_to_llm_response(event)
592+
elif event_type in ('interaction.complete', 'interaction'):
593+
# Final interaction event with complete data.
594+
interaction = getattr(event, 'interaction', event)
595+
return convert_interaction_to_llm_response(interaction)
571596

572597
elif event_type == 'interaction.status_update':
573598
status = getattr(event, 'status', None)
@@ -834,7 +859,7 @@ def build_interactions_event_log(event: InteractionSSEEvent) -> str:
834859
A formatted log string describing the event.
835860
"""
836861
event_type = getattr(event, 'event_type', 'unknown')
837-
event_id = getattr(event, 'id', None)
862+
event_id = _extract_event_id_from_interaction_event(event)
838863

839864
details = []
840865

@@ -1014,8 +1039,9 @@ async def generate_content_via_interactions(
10141039
logger.debug(build_interactions_event_log(event))
10151040

10161041
# Extract interaction ID from event if available
1017-
if hasattr(event, 'id') and event.id:
1018-
current_interaction_id = event.id
1042+
current_interaction_id = (
1043+
_extract_interaction_id_from_event(event) or current_interaction_id
1044+
)
10191045
llm_response = convert_interaction_event_to_llm_response(
10201046
event, aggregated_parts, current_interaction_id
10211047
)

tests/unittests/models/test_interactions_utils.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,36 @@
1616

1717
import base64
1818
import json
19+
from unittest.mock import AsyncMock
1920
from unittest.mock import MagicMock
2021

2122
from google.adk.models import interactions_utils
2223
from google.adk.models.llm_request import LlmRequest
2324
from google.genai import types
25+
from google.genai._interactions.types import ContentDelta
26+
from google.genai._interactions.types import ContentStop
27+
from google.genai._interactions.types import Interaction
28+
from google.genai._interactions.types import InteractionCompleteEvent
29+
from google.genai._interactions.types import InteractionStartEvent
30+
from google.genai._interactions.types import InteractionStatusUpdate
31+
from google.genai._interactions.types.content_delta import DeltaFunctionCall
32+
import pytest
33+
34+
35+
class _MockAsyncIterator:
36+
"""Simple async iterator for streaming test events."""
37+
38+
def __init__(self, values):
39+
self._iterator = iter(values)
40+
41+
def __aiter__(self):
42+
return self
43+
44+
async def __anext__(self):
45+
try:
46+
return next(self._iterator)
47+
except StopIteration as exc:
48+
raise StopAsyncIteration from exc
2449

2550

2651
class TestConvertPartToInteractionContent:
@@ -955,3 +980,100 @@ def test_unknown_event_type_returns_none(self):
955980

956981
assert result is None
957982
assert not aggregated_parts
983+
984+
def test_interaction_complete_event(self):
985+
"""Test converting an interaction.complete event."""
986+
interaction = Interaction(
987+
id='int_complete',
988+
created='2026-04-07T00:00:00Z',
989+
updated='2026-04-07T00:00:01Z',
990+
status='completed',
991+
outputs=[{'type': 'text', 'text': 'Done'}],
992+
)
993+
event = InteractionCompleteEvent(
994+
event_type='interaction.complete',
995+
interaction=interaction,
996+
)
997+
998+
result = interactions_utils.convert_interaction_event_to_llm_response(
999+
event, aggregated_parts=[]
1000+
)
1001+
1002+
assert result is not None
1003+
assert result.interaction_id == 'int_complete'
1004+
assert result.content.parts[0].text == 'Done'
1005+
assert result.turn_complete is True
1006+
1007+
1008+
class TestGenerateContentViaInteractions:
1009+
"""Tests for generate_content_via_interactions."""
1010+
1011+
@pytest.mark.asyncio
1012+
async def test_stream_uses_interaction_start_id_for_function_calls(self):
1013+
"""Test that streaming function calls retain the interaction chain ID."""
1014+
interaction = Interaction(
1015+
id='int_stream_123',
1016+
created='2026-04-07T00:00:00Z',
1017+
updated='2026-04-07T00:00:01Z',
1018+
status='requires_action',
1019+
)
1020+
stream_events = [
1021+
InteractionStartEvent(
1022+
event_type='interaction.start',
1023+
interaction=interaction,
1024+
),
1025+
ContentDelta(
1026+
event_type='content.delta',
1027+
index=0,
1028+
delta=DeltaFunctionCall(
1029+
type='function_call',
1030+
id='fc_123',
1031+
name='get_weather',
1032+
arguments={'city': 'Tokyo'},
1033+
),
1034+
),
1035+
ContentStop(event_type='content.stop', index=0),
1036+
InteractionStatusUpdate(
1037+
event_type='interaction.status_update',
1038+
interaction_id='int_stream_123',
1039+
status='requires_action',
1040+
),
1041+
]
1042+
api_client = MagicMock()
1043+
api_client.aio.interactions.create = AsyncMock(
1044+
return_value=_MockAsyncIterator(stream_events)
1045+
)
1046+
llm_request = LlmRequest(
1047+
model='gemini-2.5-flash',
1048+
contents=[
1049+
types.Content(
1050+
role='user',
1051+
parts=[types.Part.from_text(text='Weather in Tokyo?')],
1052+
)
1053+
],
1054+
config=types.GenerateContentConfig(),
1055+
)
1056+
1057+
responses = [
1058+
response
1059+
async for response in (
1060+
interactions_utils.generate_content_via_interactions(
1061+
api_client=api_client,
1062+
llm_request=llm_request,
1063+
stream=True,
1064+
)
1065+
)
1066+
]
1067+
1068+
function_call_response = next(
1069+
response
1070+
for response in responses
1071+
if response.content
1072+
and response.content.parts
1073+
and response.content.parts[0].function_call
1074+
)
1075+
1076+
assert function_call_response.interaction_id == 'int_stream_123'
1077+
assert function_call_response.content.parts[0].function_call.name == (
1078+
'get_weather'
1079+
)

0 commit comments

Comments
 (0)