Skip to content

Commit 6a50b8d

Browse files
haranrk and copybara-github
authored and committed
feat(interactions): surface streamed grounding and final usage metadata
Convert google_search_call/result and text_annotation_delta step deltas into grounding metadata, accumulate it in `_StreamState`, and reattach the aggregated grounding plus usage metadata to the final streaming event so they persist to the session and the Runner. Co-authored-by: Haran Rajkumar <haranrk@google.com> PiperOrigin-RevId: 938739994
1 parent 10a2cc5 commit 6a50b8d

2 files changed

Lines changed: 333 additions & 6 deletions

File tree

src/google/adk/models/interactions_utils.py

Lines changed: 122 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -706,13 +706,23 @@ def convert_interaction_to_llm_response(
706706

707707
@dataclasses.dataclass
class _StreamState:
  """Mutable accumulator carried across the SSE events of one stream.

  Two kinds of data are gathered here:

  * ``parts`` holds the ``types.Part`` objects in the order they arrive; they
    are used to assemble the final ``Content``.
  * The remaining fields hold google_search / citation data. That data maps to
    ``grounding_metadata``, which is a top-level ``LlmResponse`` field rather
    than a part, so it is kept separately and reattached to the final,
    persisted event.
  """

  # Streamed content parts, in arrival order.
  parts: list[types.Part] = dataclasses.field(default_factory=list)

  # Search queries collected from google_search_call deltas.
  web_search_queries: list[str] = dataclasses.field(default_factory=list)

  # One chunk per url_citation annotation seen so far.
  grounding_chunks: list[types.GroundingChunk] = dataclasses.field(
      default_factory=list
  )

  # Text segments tied to entries of ``grounding_chunks`` by index.
  grounding_supports: list[types.GroundingSupport] = dataclasses.field(
      default_factory=list
  )

  # Entry point built from google_search_result deltas; None until one arrives.
  search_entry_point: types.SearchEntryPoint | None = None
716726

717727

718728
def _partial_part_response(
@@ -727,6 +737,18 @@ def _partial_part_response(
727737
)
728738

729739

740+
def _partial_grounding_response(
    grounding_metadata: types.GroundingMetadata, interaction_id: str | None
) -> LlmResponse:
  """Wrap incremental grounding data in a partial streaming ``LlmResponse``.

  Args:
    grounding_metadata: The grounding payload to attach to the response.
    interaction_id: Interaction identifier, forwarded unchanged.

  Returns:
    An ``LlmResponse`` flagged ``partial=True`` and ``turn_complete=False``
    that carries only ``grounding_metadata`` and ``interaction_id``.
  """
  response = LlmResponse(
      interaction_id=interaction_id,
      grounding_metadata=grounding_metadata,
      turn_complete=False,
      partial=True,
  )
  return response
750+
751+
730752
def _handle_text(
731753
delta: StepDeltaData, state: _StreamState, interaction_id: str | None
732754
) -> LlmResponse | None:
@@ -862,6 +884,69 @@ def _handle_code_execution_result(
862884
return _partial_part_response(part, interaction_id)
863885

864886

887+
def _handle_google_search_call(
    delta: StepDeltaData, state: _StreamState, interaction_id: str | None
) -> LlmResponse | None:
  """Record the queries of a google_search_call delta and stream them out.

  Args:
    delta: Step delta whose ``arguments.queries`` holds the search queries.
    state: Stream accumulator; its ``web_search_queries`` list is extended.
    interaction_id: Interaction identifier forwarded on the response.

  Returns:
    A partial ``LlmResponse`` whose grounding metadata contains only the
    queries from this delta, or ``None`` when the delta has no arguments or
    no queries.
  """
  arguments = delta.arguments
  if not arguments:
    return None

  new_queries = arguments.queries
  if not new_queries:
    return None

  state.web_search_queries.extend(new_queries)
  # The partial event gets its own copy of this delta's queries only; the
  # running total lives on ``state``.
  incremental = types.GroundingMetadata(web_search_queries=list(new_queries))
  return _partial_grounding_response(incremental, interaction_id)
896+
897+
898+
def _handle_google_search_result(
    delta: StepDeltaData, state: _StreamState, interaction_id: str | None
) -> LlmResponse | None:
  """Turn a google_search_result delta into a search entry point.

  Args:
    delta: Step delta whose ``result`` entries may carry
      ``search_suggestions``.
    state: Stream accumulator; its ``search_entry_point`` is overwritten.
    interaction_id: Interaction identifier forwarded on the response.

  Returns:
    A partial ``LlmResponse`` carrying the new search entry point, or ``None``
    when no result entry has non-empty ``search_suggestions``.
  """
  # Only the first entry with non-empty suggestions is used; later entries are
  # not inspected.
  suggestions = next(
      (
          entry.search_suggestions
          for entry in delta.result or []
          if entry.search_suggestions
      ),
      None,
  )
  if not suggestions:
    return None

  state.search_entry_point = types.SearchEntryPoint(
      rendered_content=suggestions
  )
  return _partial_grounding_response(
      types.GroundingMetadata(search_entry_point=state.search_entry_point),
      interaction_id,
  )
912+
913+
914+
def _handle_text_annotation(
    delta: StepDeltaData, state: _StreamState, interaction_id: str | None
) -> LlmResponse | None:
  """Convert url_citation annotations into grounding chunks and supports.

  Annotations whose ``type`` is not ``'url_citation'`` are ignored. Each
  remaining annotation yields one ``GroundingChunk`` (its URL and title) and
  one ``GroundingSupport`` (its start/end segment) that points at that chunk.

  Args:
    delta: Step delta whose ``annotations`` are inspected.
    state: Stream accumulator; new chunks and supports are appended to it.
    interaction_id: Interaction identifier forwarded on the response.

  Returns:
    A partial ``LlmResponse`` carrying only the chunks and supports created
    from this delta, or ``None`` when the delta has no url_citation
    annotation.
  """
  citations = [
      annotation
      for annotation in delta.annotations or []
      if getattr(annotation, 'type', None) == 'url_citation'
  ]
  if not citations:
    return None

  fresh_chunks: list[types.GroundingChunk] = [
      types.GroundingChunk(
          web=types.GroundingChunkWeb(uri=citation.url, title=citation.title)
      )
      for citation in citations
  ]
  # Chunk indices are positions in the accumulated ``state.grounding_chunks``
  # list, so numbering continues from however many chunks are already stored.
  first_index = len(state.grounding_chunks)
  fresh_supports: list[types.GroundingSupport] = [
      types.GroundingSupport(
          segment=types.Segment(
              start_index=citation.start_index,
              end_index=citation.end_index,
          ),
          grounding_chunk_indices=[position],
      )
      for position, citation in enumerate(citations, start=first_index)
  ]

  state.grounding_chunks.extend(fresh_chunks)
  state.grounding_supports.extend(fresh_supports)
  return _partial_grounding_response(
      types.GroundingMetadata(
          grounding_chunks=fresh_chunks,
          grounding_supports=fresh_supports,
      ),
      interaction_id,
  )
948+
949+
865950
def _handle_function_result(
866951
delta: StepDeltaData, state: _StreamState, interaction_id: str | None
867952
) -> LlmResponse | None:
@@ -875,6 +960,24 @@ def _handle_function_result(
875960
return _partial_part_response(part, interaction_id)
876961

877962

963+
def _build_grounding_metadata(
    state: _StreamState,
) -> types.GroundingMetadata | None:
  """Assemble the aggregated grounding metadata gathered during a stream.

  Args:
    state: Stream accumulator holding the collected queries, chunks, supports
      and search entry point.

  Returns:
    A ``GroundingMetadata`` built from the accumulated fields, with empty
    lists passed as ``None``; or ``None`` when nothing was accumulated.
  """
  # Collapse empty lists to None so unset fields stay unset on the result.
  queries = state.web_search_queries or None
  chunks = state.grounding_chunks or None
  supports = state.grounding_supports or None
  entry_point = state.search_entry_point

  if not any((queries, chunks, supports, entry_point)):
    return None

  return types.GroundingMetadata(
      web_search_queries=queries,
      grounding_chunks=chunks,
      grounding_supports=supports,
      search_entry_point=entry_point,
  )
979+
980+
878981
def convert_interaction_event_to_llm_response(
879982
event: InteractionSSEEvent,
880983
state: _StreamState,
@@ -931,6 +1034,12 @@ def convert_interaction_event_to_llm_response(
9311034
return _handle_code_execution_call(delta, state, interaction_id)
9321035
elif delta_type == 'code_execution_result':
9331036
return _handle_code_execution_result(delta, state, interaction_id)
1037+
elif delta_type == 'google_search_call':
1038+
return _handle_google_search_call(delta, state, interaction_id)
1039+
elif delta_type == 'google_search_result':
1040+
return _handle_google_search_result(delta, state, interaction_id)
1041+
elif delta_type == 'text_annotation_delta':
1042+
return _handle_text_annotation(delta, state, interaction_id)
9341043
elif delta_type == 'function_result':
9351044
return _handle_function_result(delta, state, interaction_id)
9361045
else:
@@ -968,16 +1077,23 @@ def convert_interaction_event_to_llm_response(
9681077
return None
9691078

9701079
elif isinstance(event, InteractionCompletedEvent):
971-
# Final aggregated response
972-
if state.parts:
1080+
grounding_metadata = _build_grounding_metadata(state)
1081+
if state.parts or grounding_metadata is not None:
1082+
content = (
1083+
types.Content(role='model', parts=state.parts)
1084+
if state.parts
1085+
else None
1086+
)
9731087
return LlmResponse(
974-
content=types.Content(role='model', parts=state.parts),
1088+
content=content,
1089+
grounding_metadata=grounding_metadata,
1090+
usage_metadata=_usage_metadata_from_interaction(event.interaction),
9751091
partial=False,
9761092
turn_complete=True,
9771093
finish_reason=types.FinishReason.STOP,
9781094
interaction_id=interaction_id,
9791095
)
980-
# If no streaming parts were collected, convert the final interaction directly
1096+
# No streaming parts or grounding collected: convert the final interaction.
9811097
return convert_interaction_to_llm_response(event.interaction)
9821098

9831099
elif isinstance(event, Interaction):

0 commit comments

Comments
 (0)