Skip to content

Commit 2b1c932

Browse files
haranrk authored and copybara-github committed
refactor(interactions): extract streaming step-delta handlers and stream state
Pull the inline text/image/arguments_delta branches of the Interactions streaming converter into per-type `_handle_*` helpers, introduce a `_StreamState` accumulator in place of the bare `aggregated_parts` list, and share usage-metadata construction via `_usage_metadata_from_interaction`. Pure refactor with no behavior change.

Co-authored-by: Haran Rajkumar <haranrk@google.com>
PiperOrigin-RevId: 938289986
1 parent b983fcf commit 2b1c932

2 files changed

Lines changed: 143 additions & 108 deletions

File tree

src/google/adk/models/interactions_utils.py

Lines changed: 117 additions & 84 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
from __future__ import annotations
3232

3333
import base64
34+
import dataclasses
3435
import json
3536
import logging
3637
from typing import Any
@@ -63,6 +64,7 @@
6364
from google.genai.interactions import ModelOutputStepParam
6465
from google.genai.interactions import Step
6566
from google.genai.interactions import StepDelta
67+
from google.genai.interactions import StepDeltaData
6668
from google.genai.interactions import StepParam
6769
from google.genai.interactions import StepStart
6870
from google.genai.interactions import StepStop
@@ -618,6 +620,28 @@ def _convert_interaction_step_to_parts(step: Step) -> list[types.Part]:
618620
return []
619621

620622

623+
def _usage_metadata_from_interaction(
    interaction: Interaction,
) -> types.GenerateContentResponseUsageMetadata | None:
  """Translate an interaction's ``usage`` into GenAI usage metadata.

  Used by the non-streaming converter so token counts are built in a single
  place. Only the ``usage`` attribute of ``interaction`` is read, so any
  object exposing it works; the original author's note states that the
  interaction carried by ``InteractionCompletedEvent``
  (``InteractionSseEventInteraction``) also exposes ``usage`` and is meant to
  be accepted here as well.

  Args:
    interaction: The interaction whose ``usage`` is inspected.

  Returns:
    ``None`` when ``interaction.usage`` is falsy. Otherwise a
    ``GenerateContentResponseUsageMetadata`` whose prompt and candidate counts
    mirror ``total_input_tokens`` and ``total_output_tokens``, and whose total
    is their sum with a falsy count treated as 0.
  """
  usage = interaction.usage
  if not usage:
    return None

  input_tokens = usage.total_input_tokens
  output_tokens = usage.total_output_tokens
  return types.GenerateContentResponseUsageMetadata(
      prompt_token_count=input_tokens,
      candidates_token_count=output_tokens,
      # Either count may be missing; fall back to 0 so the sum never fails.
      total_token_count=(input_tokens or 0) + (output_tokens or 0),
  )
643+
644+
621645
def convert_interaction_to_llm_response(
622646
interaction: Interaction,
623647
) -> LlmResponse:
@@ -658,17 +682,7 @@ def convert_interaction_to_llm_response(
658682
if parts:
659683
content = types.Content(role='model', parts=parts)
660684

661-
# Convert usage metadata if available
662-
usage_metadata = None
663-
if interaction.usage:
664-
usage_metadata = types.GenerateContentResponseUsageMetadata(
665-
prompt_token_count=interaction.usage.total_input_tokens,
666-
candidates_token_count=interaction.usage.total_output_tokens,
667-
total_token_count=(
668-
(interaction.usage.total_input_tokens or 0)
669-
+ (interaction.usage.total_output_tokens or 0)
670-
),
671-
)
685+
usage_metadata = _usage_metadata_from_interaction(interaction)
672686

673687
# Determine finish reason based on status.
674688
# Interaction status can be: 'completed', 'requires_action', 'failed', or
@@ -689,16 +703,92 @@ def convert_interaction_to_llm_response(
689703
)
690704

691705

706+
@dataclasses.dataclass
class _StreamState:
  """Mutable accumulator threaded through the events of one stream.

  The streaming converter appends each ``types.Part`` it produces to
  ``parts`` in arrival order; the accumulated list is what the final
  ``Content`` is assembled from.
  """

  # default_factory gives every instance its own list, so two streams never
  # share accumulated parts.
  parts: list[types.Part] = dataclasses.field(default_factory=list)
715+
716+
717+
def _partial_part_response(
    part: types.Part, interaction_id: str | None
) -> LlmResponse:
  """Wrap one part in a partial (non-final) streaming ``LlmResponse``.

  Args:
    part: The single content part to emit.
    interaction_id: Interaction ID to attach to the response, if known.

  Returns:
    An ``LlmResponse`` with ``partial=True`` and ``turn_complete=False`` whose
    content is a ``'model'``-role ``Content`` holding only ``part``.
  """
  single_part_content = types.Content(role='model', parts=[part])
  return LlmResponse(
      content=single_part_content,
      partial=True,
      turn_complete=False,
      interaction_id=interaction_id,
  )
727+
728+
729+
def _handle_text(
    delta: StepDeltaData, state: _StreamState, interaction_id: str | None
) -> LlmResponse | None:
  """Handle a text delta: record it and emit it as a partial response.

  Args:
    delta: The step delta; only its ``text`` attribute is read.
    state: Stream accumulator; the new text part is appended to
      ``state.parts``.
    interaction_id: Interaction ID to attach to the response, if known.

  Returns:
    A partial ``LlmResponse`` carrying the text part, or ``None`` when the
    delta's text is empty or missing (in which case ``state`` is untouched).
  """
  chunk = delta.text
  if chunk:
    text_part = types.Part.from_text(text=chunk)
    state.parts.append(text_part)
    return _partial_part_response(text_part, interaction_id)
  return None
738+
739+
740+
def _handle_media(
    delta: StepDeltaData, state: _StreamState, interaction_id: str | None
) -> LlmResponse | None:
  """Handle a media delta described by ``data`` / ``uri`` / ``mime_type``.

  The original docstring describes this as covering image, audio, video and
  document deltas, which share these three fields; the visible dispatcher
  routes only ``'image'`` deltas here.

  Args:
    delta: The step delta; its ``data``, ``uri`` and ``mime_type`` attributes
      are read.
    state: Stream accumulator; the new media part is appended to
      ``state.parts``.
    interaction_id: Interaction ID to attach to the response, if known.

  Returns:
    A partial ``LlmResponse`` carrying the media part, or ``None`` when the
    delta has neither ``data`` nor ``uri`` (in which case ``state`` is
    untouched).
  """
  payload = delta.data
  location = delta.uri
  mime = delta.mime_type

  # Inline bytes take precedence over a URI when both are present.
  if payload:
    media_part = types.Part(
        inline_data=types.Blob(data=payload, mime_type=mime)
    )
  elif location:
    media_part = types.Part(
        file_data=types.FileData(file_uri=location, mime_type=mime)
    )
  else:
    return None

  state.parts.append(media_part)
  return _partial_part_response(media_part, interaction_id)
757+
758+
759+
def _handle_arguments_delta(
    delta: StepDeltaData, state: _StreamState, interaction_id: str | None
) -> LlmResponse | None:
  """Handle a streamed fragment of function-call arguments.

  The fragment is attached to the most recently accumulated part, which must
  be a function call whose ``partial_args`` list is not ``None``.

  Args:
    delta: The step delta; only its ``arguments`` attribute is read.
    state: Stream accumulator; the fragment is appended to the
      ``partial_args`` of the function call in ``state.parts[-1]``.
    interaction_id: Interaction ID to attach to the response, if known.

  Returns:
    A partial ``LlmResponse`` whose single part is a function call carrying
    the same name and only this fragment, or ``None`` when there is no
    accumulated part, the last part is not a function call, the delta has no
    arguments, or the call's ``partial_args`` is ``None``.
  """
  pending_call = state.parts[-1].function_call if state.parts else None
  if not pending_call:
    return None

  fragment = delta.arguments
  if fragment is None or pending_call.partial_args is None:
    return None

  def _as_partial_arg() -> types.PartialArg:
    # Separate instances for the accumulator and the emitted chunk, so the
    # two never alias the same object.
    return types.PartialArg(string_value=fragment)

  pending_call.partial_args.append(_as_partial_arg())
  chunk_part = types.Part(
      function_call=types.FunctionCall(
          name=pending_call.name,
          partial_args=[_as_partial_arg()],
      )
  )
  return _partial_part_response(chunk_part, interaction_id)
780+
781+
692782
def convert_interaction_event_to_llm_response(
693783
event: InteractionSSEEvent,
694-
aggregated_parts: list[types.Part],
784+
state: _StreamState,
695785
interaction_id: str | None = None,
696786
) -> LlmResponse | None:
697787
"""Convert an InteractionSSEEvent to an LlmResponse for streaming.
698788
699789
Args:
700790
event: The streaming event from interactions API.
701-
aggregated_parts: List to accumulate parts across events.
791+
state: Accumulates parts and grounding data across streamed events.
702792
interaction_id: The interaction ID to include in responses.
703793
704794
Returns:
@@ -718,7 +808,7 @@ def convert_interaction_event_to_llm_response(
718808
partial_args=[],
719809
)
720810
part = types.Part(function_call=fc)
721-
aggregated_parts.append(part)
811+
state.parts.append(part)
722812

723813
return LlmResponse(
724814
content=types.Content(role='model', parts=[part]),
@@ -729,75 +819,18 @@ def convert_interaction_event_to_llm_response(
729819

730820
elif isinstance(event, StepDelta):
731821
delta = event.delta
822+
delta_type = delta.type
732823

733-
if delta.type == 'text':
734-
text = delta.text
735-
if text:
736-
part = types.Part.from_text(text=text)
737-
aggregated_parts.append(part)
738-
return LlmResponse(
739-
content=types.Content(role='model', parts=[part]),
740-
partial=True,
741-
turn_complete=False,
742-
interaction_id=interaction_id,
743-
)
744-
745-
elif delta.type == 'image':
746-
data = delta.data
747-
uri = delta.uri
748-
mime_type = delta.mime_type
749-
if data or uri:
750-
if data:
751-
part = types.Part(
752-
inline_data=types.Blob(
753-
data=data,
754-
mime_type=mime_type,
755-
)
756-
)
757-
else:
758-
part = types.Part(
759-
file_data=types.FileData(
760-
file_uri=uri,
761-
mime_type=mime_type,
762-
)
763-
)
764-
aggregated_parts.append(part)
765-
return LlmResponse(
766-
content=types.Content(role='model', parts=[part]),
767-
partial=True,
768-
turn_complete=False,
769-
interaction_id=interaction_id,
770-
)
771-
772-
elif delta.type == 'arguments_delta':
773-
if aggregated_parts:
774-
last_part = aggregated_parts[-1]
775-
if last_part.function_call:
776-
delta_args = delta.arguments
777-
if (
778-
delta_args is not None
779-
and last_part.function_call.partial_args is not None
780-
):
781-
last_part.function_call.partial_args.append(
782-
types.PartialArg(string_value=delta_args)
783-
)
784-
785-
chunk_part = types.Part(
786-
function_call=types.FunctionCall(
787-
name=last_part.function_call.name,
788-
partial_args=[types.PartialArg(string_value=delta_args)],
789-
)
790-
)
791-
return LlmResponse(
792-
content=types.Content(role='model', parts=[chunk_part]),
793-
partial=True,
794-
turn_complete=False,
795-
interaction_id=interaction_id,
796-
)
824+
if delta_type == 'text':
825+
return _handle_text(delta, state, interaction_id)
826+
elif delta_type == 'image':
827+
return _handle_media(delta, state, interaction_id)
828+
elif delta_type == 'arguments_delta':
829+
return _handle_arguments_delta(delta, state, interaction_id)
797830

798831
elif isinstance(event, StepStop):
799-
if aggregated_parts and aggregated_parts[-1].function_call:
800-
fc = aggregated_parts[-1].function_call
832+
if state.parts and state.parts[-1].function_call:
833+
fc = state.parts[-1].function_call
801834
if fc.partial_args is not None:
802835
arg_str = ''.join(pa.string_value or '' for pa in fc.partial_args)
803836

@@ -828,9 +861,9 @@ def convert_interaction_event_to_llm_response(
828861

829862
elif isinstance(event, InteractionCompletedEvent):
830863
# Final aggregated response
831-
if aggregated_parts:
864+
if state.parts:
832865
return LlmResponse(
833-
content=types.Content(role='model', parts=aggregated_parts),
866+
content=types.Content(role='model', parts=state.parts),
834867
partial=False,
835868
turn_complete=True,
836869
finish_reason=types.FinishReason.STOP,
@@ -1193,14 +1226,14 @@ async def _create_interactions(
11931226
responses = await api_client.aio.interactions.create(
11941227
**create_kwargs, stream=True
11951228
)
1196-
aggregated_parts: list[types.Part] = []
1229+
state = _StreamState()
11971230
async for event in responses:
11981231
logger.debug(build_interactions_event_log(event))
11991232
interaction_id = _extract_stream_interaction_id(event)
12001233
if interaction_id:
12011234
current_interaction_id = interaction_id
12021235
llm_response = convert_interaction_event_to_llm_response(
1203-
event, aggregated_parts, current_interaction_id
1236+
event, state, current_interaction_id
12041237
)
12051238
if llm_response:
12061239
yield llm_response

0 commit comments

Comments
 (0)