Skip to content

Commit 1e85073

Browse files
authored
Merge branch 'main' into fix/3145-remotea2a-artifacts
2 parents 6201702 + 6a50b8d commit 1e85073

12 files changed

Lines changed: 693 additions & 124 deletions

src/google/adk/memory/vertex_ai_memory_bank_service.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from google.genai import types
2727
from typing_extensions import override
2828

29-
from ..utils._google_client_headers import get_tracking_headers
3029
from ..utils.vertex_ai_utils import get_express_mode_api_key
3130
from .base_memory_service import BaseMemoryService
3231
from .base_memory_service import SearchMemoryResponse
@@ -617,17 +616,9 @@ def _get_api_client(self) -> vertexai.AsyncClient:
617616
"""
618617
import vertexai
619618

620-
http_options = types.HttpOptions(headers=get_tracking_headers())
621619
if self._express_mode_api_key:
622-
return vertexai.Client(
623-
http_options=http_options,
624-
api_key=self._express_mode_api_key,
625-
).aio
626-
return vertexai.Client(
627-
project=self._project,
628-
location=self._location,
629-
http_options=http_options,
630-
).aio
620+
return vertexai.Client(api_key=self._express_mode_api_key).aio
621+
return vertexai.Client(project=self._project, location=self._location).aio
631622

632623

633624
def _log_ingest_task_error(task: asyncio.Task) -> None:

src/google/adk/models/interactions_utils.py

Lines changed: 122 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -706,13 +706,23 @@ def convert_interaction_to_llm_response(
706706

707707
@dataclasses.dataclass
708708
class _StreamState:
709-
"""Accumulates streamed parts across SSE events.
709+
"""Accumulates streamed parts and grounding data across SSE events.
710710
711711
``parts`` collects ``types.Part``s in arrival order to assemble the final
712-
``Content``.
712+
``Content``. The grounding fields accumulate google_search / citation data
713+
that maps to ``grounding_metadata`` (a top-level ``LlmResponse`` field, not a
714+
part) so it can be reattached to the final, persisted event.
713715
"""
714716

715717
parts: list[types.Part] = dataclasses.field(default_factory=list)
718+
web_search_queries: list[str] = dataclasses.field(default_factory=list)
719+
grounding_chunks: list[types.GroundingChunk] = dataclasses.field(
720+
default_factory=list
721+
)
722+
grounding_supports: list[types.GroundingSupport] = dataclasses.field(
723+
default_factory=list
724+
)
725+
search_entry_point: types.SearchEntryPoint | None = None
716726

717727

718728
def _partial_part_response(
@@ -727,6 +737,18 @@ def _partial_part_response(
727737
)
728738

729739

740+
def _partial_grounding_response(
741+
grounding_metadata: types.GroundingMetadata, interaction_id: str | None
742+
) -> LlmResponse:
743+
"""Build a partial streaming LlmResponse carrying incremental grounding."""
744+
return LlmResponse(
745+
grounding_metadata=grounding_metadata,
746+
partial=True,
747+
turn_complete=False,
748+
interaction_id=interaction_id,
749+
)
750+
751+
730752
def _handle_text(
731753
delta: StepDeltaData, state: _StreamState, interaction_id: str | None
732754
) -> LlmResponse | None:
@@ -862,6 +884,69 @@ def _handle_code_execution_result(
862884
return _partial_part_response(part, interaction_id)
863885

864886

887+
def _handle_google_search_call(
888+
delta: StepDeltaData, state: _StreamState, interaction_id: str | None
889+
) -> LlmResponse | None:
890+
queries = delta.arguments.queries if delta.arguments else None
891+
if not queries:
892+
return None
893+
state.web_search_queries.extend(queries)
894+
grounding_metadata = types.GroundingMetadata(web_search_queries=list(queries))
895+
return _partial_grounding_response(grounding_metadata, interaction_id)
896+
897+
898+
def _handle_google_search_result(
899+
delta: StepDeltaData, state: _StreamState, interaction_id: str | None
900+
) -> LlmResponse | None:
901+
rendered = None
902+
for search_result in delta.result or []:
903+
if search_result.search_suggestions:
904+
rendered = search_result.search_suggestions
905+
break
906+
if not rendered:
907+
return None
908+
entry_point = types.SearchEntryPoint(rendered_content=rendered)
909+
state.search_entry_point = entry_point
910+
grounding_metadata = types.GroundingMetadata(search_entry_point=entry_point)
911+
return _partial_grounding_response(grounding_metadata, interaction_id)
912+
913+
914+
def _handle_text_annotation(
915+
delta: StepDeltaData, state: _StreamState, interaction_id: str | None
916+
) -> LlmResponse | None:
917+
new_chunks: list[types.GroundingChunk] = []
918+
new_supports: list[types.GroundingSupport] = []
919+
for annotation in delta.annotations or []:
920+
if getattr(annotation, 'type', None) != 'url_citation':
921+
continue
922+
chunk_index = len(state.grounding_chunks) + len(new_chunks)
923+
new_chunks.append(
924+
types.GroundingChunk(
925+
web=types.GroundingChunkWeb(
926+
uri=annotation.url, title=annotation.title
927+
)
928+
)
929+
)
930+
new_supports.append(
931+
types.GroundingSupport(
932+
segment=types.Segment(
933+
start_index=annotation.start_index,
934+
end_index=annotation.end_index,
935+
),
936+
grounding_chunk_indices=[chunk_index],
937+
)
938+
)
939+
if not new_chunks:
940+
return None
941+
state.grounding_chunks.extend(new_chunks)
942+
state.grounding_supports.extend(new_supports)
943+
grounding_metadata = types.GroundingMetadata(
944+
grounding_chunks=new_chunks,
945+
grounding_supports=new_supports,
946+
)
947+
return _partial_grounding_response(grounding_metadata, interaction_id)
948+
949+
865950
def _handle_function_result(
866951
delta: StepDeltaData, state: _StreamState, interaction_id: str | None
867952
) -> LlmResponse | None:
@@ -875,6 +960,24 @@ def _handle_function_result(
875960
return _partial_part_response(part, interaction_id)
876961

877962

963+
def _build_grounding_metadata(
964+
state: _StreamState,
965+
) -> types.GroundingMetadata | None:
966+
if not (
967+
state.web_search_queries
968+
or state.grounding_chunks
969+
or state.grounding_supports
970+
or state.search_entry_point
971+
):
972+
return None
973+
return types.GroundingMetadata(
974+
web_search_queries=state.web_search_queries or None,
975+
grounding_chunks=state.grounding_chunks or None,
976+
grounding_supports=state.grounding_supports or None,
977+
search_entry_point=state.search_entry_point,
978+
)
979+
980+
878981
def convert_interaction_event_to_llm_response(
879982
event: InteractionSSEEvent,
880983
state: _StreamState,
@@ -931,6 +1034,12 @@ def convert_interaction_event_to_llm_response(
9311034
return _handle_code_execution_call(delta, state, interaction_id)
9321035
elif delta_type == 'code_execution_result':
9331036
return _handle_code_execution_result(delta, state, interaction_id)
1037+
elif delta_type == 'google_search_call':
1038+
return _handle_google_search_call(delta, state, interaction_id)
1039+
elif delta_type == 'google_search_result':
1040+
return _handle_google_search_result(delta, state, interaction_id)
1041+
elif delta_type == 'text_annotation_delta':
1042+
return _handle_text_annotation(delta, state, interaction_id)
9341043
elif delta_type == 'function_result':
9351044
return _handle_function_result(delta, state, interaction_id)
9361045
else:
@@ -968,16 +1077,23 @@ def convert_interaction_event_to_llm_response(
9681077
return None
9691078

9701079
elif isinstance(event, InteractionCompletedEvent):
971-
# Final aggregated response
972-
if state.parts:
1080+
grounding_metadata = _build_grounding_metadata(state)
1081+
if state.parts or grounding_metadata is not None:
1082+
content = (
1083+
types.Content(role='model', parts=state.parts)
1084+
if state.parts
1085+
else None
1086+
)
9731087
return LlmResponse(
974-
content=types.Content(role='model', parts=state.parts),
1088+
content=content,
1089+
grounding_metadata=grounding_metadata,
1090+
usage_metadata=_usage_metadata_from_interaction(event.interaction),
9751091
partial=False,
9761092
turn_complete=True,
9771093
finish_reason=types.FinishReason.STOP,
9781094
interaction_id=interaction_id,
9791095
)
980-
# If no streaming parts were collected, convert the final interaction directly
1096+
# No streaming parts or grounding collected: convert the final interaction.
9811097
return convert_interaction_to_llm_response(event.interaction)
9821098

9831099
elif isinstance(event, Interaction):

src/google/adk/models/lite_llm.py

Lines changed: 89 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,30 @@ def _get_provider_from_model(model: str) -> str:
330330
return ""
331331

332332

333+
# Providers that can route to Anthropic. bedrock and vertex_ai are multi-model
334+
# platforms, so _is_anthropic_route also checks the model name for them.
335+
_ANTHROPIC_PROVIDERS = frozenset({"anthropic", "bedrock", "vertex_ai"})
336+
337+
338+
def _is_anthropic_provider(provider: str) -> bool:
339+
"""Returns True if the provider can route to an Anthropic model endpoint."""
340+
return provider.lower() in _ANTHROPIC_PROVIDERS if provider else False
341+
342+
343+
def _is_anthropic_route(provider: str, model: str) -> bool:
344+
"""Returns True only when requests actually reach an Anthropic Claude model.
345+
346+
bedrock and vertex_ai also host non-Anthropic models (Llama, Gemini), so for
347+
those platforms the model name must identify a Claude model too. Formatting
348+
thinking blocks for a non-Claude model triggers API validation (400) errors.
349+
"""
350+
if not _is_anthropic_provider(provider):
351+
return False
352+
if provider.lower() in ("bedrock", "vertex_ai"):
353+
return _is_anthropic_model(model)
354+
return True
355+
356+
333357
def _infer_mime_type_from_uri(uri: str) -> Optional[str]:
334358
"""Attempts to infer MIME type from a URI's path extension.
335359
@@ -491,42 +515,48 @@ def _iter_reasoning_texts(reasoning_value: Any) -> Iterable[str]:
491515

492516

493517
def _is_thinking_blocks_format(reasoning_value: Any) -> bool:
494-
"""Returns True if reasoning_value is thinking_blocks format.
518+
"""Returns True if reasoning_value is Anthropic thinking_blocks format.
495519
496-
Anthropic blocks carry a 'signature'; Gemini blocks carry 'thinking'/'type'
497-
without one. Match either so Gemini thought text is not dropped.
520+
Anthropic thinking_blocks is a list of dicts, each with 'type', 'thinking',
521+
and 'signature' keys.
498522
"""
499523
if not isinstance(reasoning_value, list) or not reasoning_value:
500524
return False
501525
first = reasoning_value[0]
502-
return isinstance(first, dict) and (
503-
"thinking" in first or "signature" in first
504-
)
526+
return isinstance(first, dict) and "signature" in first
505527

506528

507529
def _convert_reasoning_value_to_parts(reasoning_value: Any) -> List[types.Part]:
508530
"""Converts provider reasoning payloads into Gemini thought parts.
509531
510-
Handles Anthropic thinking_blocks (list of dicts with type/thinking/signature)
511-
by preserving the signature on each part's thought_signature field. This is
512-
required for Anthropic to maintain thinking across tool call boundaries.
532+
Handles two formats:
533+
- Anthropic thinking_blocks with 'thinking' and optional 'signature' fields.
534+
- A plain string or nested structure (OpenAI/Azure/Ollama) via
535+
_iter_reasoning_texts.
513536
"""
514-
if _is_thinking_blocks_format(reasoning_value):
537+
if isinstance(reasoning_value, list):
515538
parts: List[types.Part] = []
516539
for block in reasoning_value:
517-
if not isinstance(block, dict):
518-
continue
519-
block_type = block.get("type", "")
520-
if block_type == "redacted":
521-
continue
522-
thinking_text = block.get("thinking", "")
523-
signature = block.get("signature", "")
524-
if not thinking_text and not signature:
525-
continue
526-
part = types.Part(text=thinking_text, thought=True)
527-
if signature:
528-
part.thought_signature = signature.encode("utf-8")
529-
parts.append(part)
540+
if isinstance(block, dict):
541+
block_type = block.get("type", "")
542+
if block_type == "redacted":
543+
continue
544+
if block_type == "thinking":
545+
thinking_text = block.get("thinking", "")
546+
if thinking_text:
547+
part = types.Part(text=thinking_text, thought=True)
548+
signature = block.get("signature")
549+
if signature:
550+
decoded_signature = _decode_thought_signature(signature)
551+
part.thought_signature = decoded_signature or str(
552+
signature
553+
).encode("utf-8")
554+
parts.append(part)
555+
continue
556+
# Fall back to text extraction for non-thinking-block items.
557+
for text in _iter_reasoning_texts(block):
558+
if text:
559+
parts.append(types.Part(text=text, thought=True))
530560
return parts
531561
return [
532562
types.Part(text=text, thought=True)
@@ -538,16 +568,16 @@ def _convert_reasoning_value_to_parts(reasoning_value: Any) -> List[types.Part]:
538568
def _extract_reasoning_value(message: Message | Delta | None) -> Any:
539569
"""Fetches the reasoning payload from a LiteLLM message.
540570
541-
Checks for 'thinking_blocks' (Anthropic structured format with signatures),
542-
'reasoning_content' (LiteLLM standard, used by Azure/Foundry, Ollama via
543-
LiteLLM) and 'reasoning' (used by LM Studio, vLLM).
544-
Prioritizes 'thinking_blocks' when present (Anthropic models), then
545-
'reasoning_content', then 'reasoning'.
571+
Checks for 'thinking_blocks' (Anthropic thinking with signatures),
572+
'reasoning_content' (LiteLLM standard, used by Azure/Foundry,
573+
Ollama via LiteLLM), and 'reasoning' (used by LM Studio, vLLM).
574+
Prioritizes 'thinking_blocks' when the key is present, as they contain
575+
the signature required for Anthropic's extended thinking API.
546576
"""
547577
if message is None:
548578
return None
549-
# Anthropic models return thinking_blocks with type/thinking/signature fields.
550-
# This must be preserved to maintain thinking across tool call boundaries.
579+
# Prefer thinking_blocks (Anthropic) — they carry per-block signatures
580+
# needed for multi-turn conversations with extended thinking.
551581
thinking_blocks = message.get("thinking_blocks")
552582
if thinking_blocks is not None:
553583
return thinking_blocks
@@ -999,7 +1029,7 @@ async def _content_to_message_param(
9991029
if part.text and part.thought_signature:
10001030
sig = part.thought_signature
10011031
if isinstance(sig, bytes):
1002-
sig = sig.decode("utf-8")
1032+
sig = base64.b64encode(sig).decode("utf-8")
10031033
thinking_blocks.append({
10041034
"type": "thinking",
10051035
"thinking": part.text,
@@ -1026,6 +1056,34 @@ async def _content_to_message_param(
10261056
):
10271057
reasoning_texts.append(_decode_inline_text_data(part.inline_data.data))
10281058

1059+
# Anthropic routes require thinking blocks to be embedded directly in the
1060+
# message content list. LiteLLM's prompt template for Anthropic drops the
1061+
# top-level reasoning_content field, so thinking blocks disappear from
1062+
# multi-turn histories and the model stops producing them after the first
1063+
# turn. Signatures are required by the Anthropic API for thinking blocks in
1064+
# multi-turn conversations. On multi-model platforms (bedrock, vertex_ai)
1065+
# this must only apply to actual Claude models, not Gemini/Llama/etc.
1066+
if reasoning_parts and _is_anthropic_route(provider, model):
1067+
content_list = []
1068+
for part in reasoning_parts:
1069+
if part.text:
1070+
block = {"type": "thinking", "thinking": part.text}
1071+
if part.thought_signature:
1072+
sig = part.thought_signature
1073+
if isinstance(sig, bytes):
1074+
sig = base64.b64encode(sig).decode("utf-8")
1075+
block["signature"] = sig
1076+
content_list.append(block)
1077+
if isinstance(final_content, list):
1078+
content_list.extend(final_content)
1079+
elif final_content:
1080+
content_list.append({"type": "text", "text": final_content})
1081+
return ChatCompletionAssistantMessage(
1082+
role=role,
1083+
content=content_list or None,
1084+
tool_calls=tool_calls or None,
1085+
)
1086+
10291087
# Preserve reasoning deltas exactly as received. Injecting separators
10301088
# between fragments can corrupt provider-streamed thinking text.
10311089
reasoning_content = "".join(text for text in reasoning_texts if text)

0 commit comments

Comments
 (0)