3131from __future__ import annotations
3232
3333import base64
34+ import dataclasses
3435import json
3536import logging
3637from typing import Any
6364from google .genai .interactions import ModelOutputStepParam
6465from google .genai .interactions import Step
6566from google .genai .interactions import StepDelta
67+ from google .genai .interactions import StepDeltaData
6668from google .genai .interactions import StepParam
6769from google .genai .interactions import StepStart
6870from google .genai .interactions import StepStop
@@ -618,6 +620,28 @@ def _convert_interaction_step_to_parts(step: Step) -> list[types.Part]:
618620 return []
619621
620622
def _usage_metadata_from_interaction(
    interaction: Interaction,
) -> types.GenerateContentResponseUsageMetadata | None:
  """Translate an interaction's token usage into response usage metadata.

  Used by both the non-streaming converter and the streaming final-event
  branch so token counts are reported the same way on either path. Only the
  ``usage`` attribute is read, so any interaction-like object exposing it is
  accepted (``InteractionSseEventInteraction``, carried by
  ``InteractionCompletedEvent``, also exposes ``usage``).

  Args:
    interaction: Object whose ``usage`` attribute, when truthy, provides
      ``total_input_tokens`` and ``total_output_tokens``.

  Returns:
    ``None`` when ``interaction.usage`` is falsy. Otherwise usage metadata
    whose prompt and candidate counts mirror the input and output totals and
    whose total is their sum, with a missing count treated as zero.
  """
  usage = interaction.usage
  if not usage:
    return None

  input_tokens = usage.total_input_tokens
  output_tokens = usage.total_output_tokens
  return types.GenerateContentResponseUsageMetadata(
      prompt_token_count=input_tokens,
      candidates_token_count=output_tokens,
      # Either count may be None; fall back to 0 so the sum never fails.
      total_token_count=(input_tokens or 0) + (output_tokens or 0),
  )
643+
644+
621645def convert_interaction_to_llm_response (
622646 interaction : Interaction ,
623647) -> LlmResponse :
@@ -658,17 +682,7 @@ def convert_interaction_to_llm_response(
658682 if parts :
659683 content = types .Content (role = 'model' , parts = parts )
660684
661- # Convert usage metadata if available
662- usage_metadata = None
663- if interaction .usage :
664- usage_metadata = types .GenerateContentResponseUsageMetadata (
665- prompt_token_count = interaction .usage .total_input_tokens ,
666- candidates_token_count = interaction .usage .total_output_tokens ,
667- total_token_count = (
668- (interaction .usage .total_input_tokens or 0 )
669- + (interaction .usage .total_output_tokens or 0 )
670- ),
671- )
685+ usage_metadata = _usage_metadata_from_interaction (interaction )
672686
673687 # Determine finish reason based on status.
674688 # Interaction status can be: 'completed', 'requires_action', 'failed', or
@@ -689,16 +703,92 @@ def convert_interaction_to_llm_response(
689703 )
690704
691705
@dataclasses.dataclass
class _StreamState:
  """Mutable accumulator carried across the SSE events of one stream.

  Attributes:
    parts: ``types.Part`` objects in arrival order; the streaming converter
      reads them back to assemble the final ``Content``.
  """

  # default_factory gives every instance its own list, so separate streams
  # never share accumulated parts.
  parts: list[types.Part] = dataclasses.field(default_factory=list)
715+
716+
def _partial_part_response(
    part: types.Part, interaction_id: str | None
) -> LlmResponse:
  """Wrap a single part in a partial (mid-stream) ``LlmResponse``.

  Args:
    part: The one content part to carry in the response.
    interaction_id: Interaction ID to attach to the response; may be None.

  Returns:
    An ``LlmResponse`` with role ``'model'`` content holding only ``part``,
    flagged ``partial=True`` and ``turn_complete=False``.
  """
  single_part_content = types.Content(role='model', parts=[part])
  return LlmResponse(
      content=single_part_content,
      partial=True,
      turn_complete=False,
      interaction_id=interaction_id,
  )
727+
728+
def _handle_text(
    delta: StepDeltaData, state: _StreamState, interaction_id: str | None
) -> LlmResponse | None:
  """Turn a text delta into a partial response and record its part.

  Args:
    delta: Step delta whose ``text`` attribute holds the streamed chunk.
    state: Stream accumulator; the new text part is appended to
      ``state.parts``.
    interaction_id: Interaction ID to attach to the response; may be None.

  Returns:
    A partial ``LlmResponse`` carrying the text part, or ``None`` (leaving
    ``state`` untouched) when the delta's text is empty or missing.
  """
  chunk = delta.text
  if chunk:
    text_part = types.Part.from_text(text=chunk)
    state.parts.append(text_part)
    return _partial_part_response(text_part, interaction_id)
  return None
738+
739+
def _handle_media(
    delta: StepDeltaData, state: _StreamState, interaction_id: str | None
) -> LlmResponse | None:
  """Turn a media delta into a partial response and record its part.

  Covers image/audio/video/document deltas, which share the
  ``data``/``uri``/``mime_type`` attributes. Inline ``data`` takes precedence
  over ``uri`` when both are present.

  Args:
    delta: Step delta exposing ``data``, ``uri`` and ``mime_type``.
    state: Stream accumulator; the new media part is appended to
      ``state.parts``.
    interaction_id: Interaction ID to attach to the response; may be None.

  Returns:
    A partial ``LlmResponse`` carrying an inline-data part (when ``data`` is
    set) or a file-data part (when only ``uri`` is set), or ``None`` (leaving
    ``state`` untouched) when the delta has neither.
  """
  inline_payload = delta.data
  file_uri = delta.uri
  media_type = delta.mime_type

  if inline_payload:
    media_part = types.Part(
        inline_data=types.Blob(data=inline_payload, mime_type=media_type)
    )
  elif file_uri:
    media_part = types.Part(
        file_data=types.FileData(file_uri=file_uri, mime_type=media_type)
    )
  else:
    # Nothing to emit: the delta carries neither inline bytes nor a URI.
    return None

  state.parts.append(media_part)
  return _partial_part_response(media_part, interaction_id)
757+
758+
def _handle_arguments_delta(
    delta: StepDeltaData, state: _StreamState, interaction_id: str | None
) -> LlmResponse | None:
  """Append a streamed argument fragment to the in-progress function call.

  The fragment is attached to the function call held by the most recently
  accumulated part, and a partial response carrying only this fragment is
  returned.

  Args:
    delta: Step delta whose ``arguments`` attribute holds the fragment.
    state: Stream accumulator; the last entry of ``state.parts`` is expected
      to hold the function call being streamed.
    interaction_id: Interaction ID to attach to the response; may be None.

  Returns:
    A partial ``LlmResponse`` whose function-call part has the pending call's
    name and a single ``PartialArg`` for this fragment. ``None`` when there
    are no accumulated parts, the last part is not a function call, the
    fragment is ``None``, or the call has no ``partial_args`` list.
  """
  if not state.parts:
    return None

  pending_call = state.parts[-1].function_call
  if not pending_call:
    return None

  fragment = delta.arguments
  accumulated_args = pending_call.partial_args
  if fragment is None or accumulated_args is None:
    return None

  # Record the fragment on the accumulated call so it can be joined later.
  accumulated_args.append(types.PartialArg(string_value=fragment))

  # The emitted chunk gets its own PartialArg instance rather than sharing
  # the one stored on the accumulated call.
  fragment_part = types.Part(
      function_call=types.FunctionCall(
          name=pending_call.name,
          partial_args=[types.PartialArg(string_value=fragment)],
      )
  )
  return _partial_part_response(fragment_part, interaction_id)
780+
781+
692782def convert_interaction_event_to_llm_response (
693783 event : InteractionSSEEvent ,
694- aggregated_parts : list [ types . Part ] ,
784+ state : _StreamState ,
695785 interaction_id : str | None = None ,
696786) -> LlmResponse | None :
697787 """Convert an InteractionSSEEvent to an LlmResponse for streaming.
698788
699789 Args:
700790 event: The streaming event from interactions API.
701- aggregated_parts: List to accumulate parts across events.
791+ state: Accumulates parts and grounding data across streamed events.
702792 interaction_id: The interaction ID to include in responses.
703793
704794 Returns:
@@ -718,7 +808,7 @@ def convert_interaction_event_to_llm_response(
718808 partial_args = [],
719809 )
720810 part = types .Part (function_call = fc )
721- aggregated_parts .append (part )
811+ state . parts .append (part )
722812
723813 return LlmResponse (
724814 content = types .Content (role = 'model' , parts = [part ]),
@@ -729,75 +819,18 @@ def convert_interaction_event_to_llm_response(
729819
730820 elif isinstance (event , StepDelta ):
731821 delta = event .delta
822+ delta_type = delta .type
732823
733- if delta .type == 'text' :
734- text = delta .text
735- if text :
736- part = types .Part .from_text (text = text )
737- aggregated_parts .append (part )
738- return LlmResponse (
739- content = types .Content (role = 'model' , parts = [part ]),
740- partial = True ,
741- turn_complete = False ,
742- interaction_id = interaction_id ,
743- )
744-
745- elif delta .type == 'image' :
746- data = delta .data
747- uri = delta .uri
748- mime_type = delta .mime_type
749- if data or uri :
750- if data :
751- part = types .Part (
752- inline_data = types .Blob (
753- data = data ,
754- mime_type = mime_type ,
755- )
756- )
757- else :
758- part = types .Part (
759- file_data = types .FileData (
760- file_uri = uri ,
761- mime_type = mime_type ,
762- )
763- )
764- aggregated_parts .append (part )
765- return LlmResponse (
766- content = types .Content (role = 'model' , parts = [part ]),
767- partial = True ,
768- turn_complete = False ,
769- interaction_id = interaction_id ,
770- )
771-
772- elif delta .type == 'arguments_delta' :
773- if aggregated_parts :
774- last_part = aggregated_parts [- 1 ]
775- if last_part .function_call :
776- delta_args = delta .arguments
777- if (
778- delta_args is not None
779- and last_part .function_call .partial_args is not None
780- ):
781- last_part .function_call .partial_args .append (
782- types .PartialArg (string_value = delta_args )
783- )
784-
785- chunk_part = types .Part (
786- function_call = types .FunctionCall (
787- name = last_part .function_call .name ,
788- partial_args = [types .PartialArg (string_value = delta_args )],
789- )
790- )
791- return LlmResponse (
792- content = types .Content (role = 'model' , parts = [chunk_part ]),
793- partial = True ,
794- turn_complete = False ,
795- interaction_id = interaction_id ,
796- )
824+ if delta_type == 'text' :
825+ return _handle_text (delta , state , interaction_id )
826+ elif delta_type == 'image' :
827+ return _handle_media (delta , state , interaction_id )
828+ elif delta_type == 'arguments_delta' :
829+ return _handle_arguments_delta (delta , state , interaction_id )
797830
798831 elif isinstance (event , StepStop ):
799- if aggregated_parts and aggregated_parts [- 1 ].function_call :
800- fc = aggregated_parts [- 1 ].function_call
832+ if state . parts and state . parts [- 1 ].function_call :
833+ fc = state . parts [- 1 ].function_call
801834 if fc .partial_args is not None :
802835 arg_str = '' .join (pa .string_value or '' for pa in fc .partial_args )
803836
@@ -828,9 +861,9 @@ def convert_interaction_event_to_llm_response(
828861
829862 elif isinstance (event , InteractionCompletedEvent ):
830863 # Final aggregated response
831- if aggregated_parts :
864+ if state . parts :
832865 return LlmResponse (
833- content = types .Content (role = 'model' , parts = aggregated_parts ),
866+ content = types .Content (role = 'model' , parts = state . parts ),
834867 partial = False ,
835868 turn_complete = True ,
836869 finish_reason = types .FinishReason .STOP ,
@@ -1193,14 +1226,14 @@ async def _create_interactions(
11931226 responses = await api_client .aio .interactions .create (
11941227 ** create_kwargs , stream = True
11951228 )
1196- aggregated_parts : list [ types . Part ] = []
1229+ state = _StreamState ()
11971230 async for event in responses :
11981231 logger .debug (build_interactions_event_log (event ))
11991232 interaction_id = _extract_stream_interaction_id (event )
12001233 if interaction_id :
12011234 current_interaction_id = interaction_id
12021235 llm_response = convert_interaction_event_to_llm_response (
1203- event , aggregated_parts , current_interaction_id
1236+ event , state , current_interaction_id
12041237 )
12051238 if llm_response :
12061239 yield llm_response
0 commit comments