Skip to content

Commit 086b57e

Browse files
authored
Merge branch 'main' into fix-thought-signature-pruning
2 parents a4a9d37 + 6a50b8d commit 086b57e

27 files changed

Lines changed: 992 additions & 219 deletions

src/google/adk/memory/vertex_ai_memory_bank_service.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from google.genai import types
2727
from typing_extensions import override
2828

29-
from ..utils._google_client_headers import get_tracking_headers
3029
from ..utils.vertex_ai_utils import get_express_mode_api_key
3130
from .base_memory_service import BaseMemoryService
3231
from .base_memory_service import SearchMemoryResponse
@@ -617,17 +616,9 @@ def _get_api_client(self) -> vertexai.AsyncClient:
617616
"""
618617
import vertexai
619618

620-
http_options = types.HttpOptions(headers=get_tracking_headers())
621619
if self._express_mode_api_key:
622-
return vertexai.Client(
623-
http_options=http_options,
624-
api_key=self._express_mode_api_key,
625-
).aio
626-
return vertexai.Client(
627-
project=self._project,
628-
location=self._location,
629-
http_options=http_options,
630-
).aio
620+
return vertexai.Client(api_key=self._express_mode_api_key).aio
621+
return vertexai.Client(project=self._project, location=self._location).aio
631622

632623

633624
def _log_ingest_task_error(task: asyncio.Task) -> None:

src/google/adk/models/gemini_context_cache_manager.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -326,11 +326,21 @@ async def _create_new_cache_with_contents(
326326
)
327327
return None
328328

329-
# Check client-side to avoid unnecessary API round-trips.
330-
if llm_request.cacheable_contents_token_count < _GEMINI_MIN_CACHE_TOKENS:
329+
# `cacheable_contents_token_count` is the token count of the whole previous
330+
# prompt (system instruction + tools + every content). The cache, however,
331+
# only stores the prefix `contents[:cache_contents_count]` plus the system
332+
# instruction and tools (see `_create_gemini_cache`). On a long conversation
333+
# the full-prompt count can clear Gemini's minimum while the cached prefix
334+
# is far smaller, which makes `caches.create` fail with 400
335+
# INVALID_ARGUMENT.
336+
# Gate on the estimated prefix size so we never send a sub-minimum payload.
337+
cacheable_prefix_tokens = self._estimate_cacheable_prefix_tokens(
338+
llm_request, cache_contents_count
339+
)
340+
if cacheable_prefix_tokens < _GEMINI_MIN_CACHE_TOKENS:
331341
logger.info(
332-
"Request below Gemini minimum cache size (%d < %d tokens)",
333-
llm_request.cacheable_contents_token_count,
342+
"Cacheable prefix below Gemini minimum cache size (%d < %d tokens)",
343+
cacheable_prefix_tokens,
334344
_GEMINI_MIN_CACHE_TOKENS,
335345
)
336346
return None
@@ -342,13 +352,20 @@ async def _create_new_cache_with_contents(
342352
logger.warning("Failed to create cache: %s", e)
343353
return None
344354

345-
def _estimate_request_tokens(self, llm_request: LlmRequest) -> int:
346-
"""Estimate token count for the request.
355+
def _estimate_request_tokens(
356+
self,
357+
llm_request: LlmRequest,
358+
cache_contents_count: Optional[int] = None,
359+
) -> int:
360+
"""Estimate token count for the request (or its cacheable prefix).
347361
348362
This is a rough estimation based on content text length.
349363
350364
Args:
351365
llm_request: Request to estimate tokens for
366+
cache_contents_count: When provided, only the first
367+
``cache_contents_count`` contents are counted (the prefix that gets
368+
cached); the system instruction and tools are always included.
352369
353370
Returns:
354371
Estimated token count
@@ -366,15 +383,54 @@ def _estimate_request_tokens(self, llm_request: LlmRequest) -> int:
366383
tool_str = json.dumps(tool.model_dump())
367384
total_chars += len(tool_str)
368385

369-
# Contents
370-
for content in llm_request.contents:
386+
# Contents (optionally limited to the cacheable prefix)
387+
contents = llm_request.contents
388+
if cache_contents_count is not None:
389+
contents = contents[:cache_contents_count]
390+
for content in contents:
371391
for part in content.parts:
372392
if part.text:
373393
total_chars += len(part.text)
374394

375395
# Rough estimate: 4 characters per token
376396
return total_chars // 4
377397

398+
def _estimate_cacheable_prefix_tokens(
    self, llm_request: LlmRequest, cache_contents_count: int
) -> int:
  """Approximate the token count of the prefix that ends up in the cache.

  ``llm_request.cacheable_contents_token_count`` is treated as the accurate
  count for the whole request. Because only ``contents[:cache_contents_count]``
  (together with the system instruction and tools) is cached, that accurate
  count is multiplied by the prefix's share of the request, where the share is
  the ratio of two ``_estimate_request_tokens`` results (prefix vs. full
  request). The share is capped at 1.0, so a prefix covering the whole request
  yields the accurate count unchanged.

  Args:
    llm_request: Request whose cacheable prefix is being sized
    cache_contents_count: Number of leading contents that get cached

  Returns:
    Estimated token count of the cacheable prefix. 0 when the request carries
    no (or a zero) ``cacheable_contents_token_count``.
  """
  accurate_total = llm_request.cacheable_contents_token_count
  if not accurate_total:
    return 0

  whole_request_estimate = self._estimate_request_tokens(llm_request)
  if whole_request_estimate <= 0:
    # Nothing to derive a ratio from (e.g. only non-text parts), so use the
    # accurate full count instead of wrongly reporting a tiny prefix.
    return accurate_total

  prefix_only_estimate = self._estimate_request_tokens(
      llm_request, cache_contents_count
  )
  share = prefix_only_estimate / whole_request_estimate
  if share > 1.0:
    share = 1.0
  return int(accurate_total * share)
433+
378434
async def _create_gemini_cache(
379435
self, llm_request: LlmRequest, cache_contents_count: int
380436
) -> CacheMetadata:

src/google/adk/models/interactions_utils.py

Lines changed: 122 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -706,13 +706,23 @@ def convert_interaction_to_llm_response(
706706

707707
@dataclasses.dataclass
class _StreamState:
  """Mutable accumulator for data gathered across streamed SSE events.

  Streamed ``types.Part``s are kept in arrival order so the final ``Content``
  can be assembled from them. Grounding data (google_search queries, citation
  chunks and supports, and the search entry point) is tracked in separate
  fields because it maps to ``grounding_metadata``, which is a top-level
  ``LlmResponse`` field rather than a part, and has to be reattached to the
  final, persisted event.
  """

  # Streamed parts, in the order they arrived.
  parts: list[types.Part] = dataclasses.field(default_factory=list)
  # Search queries collected from google_search call deltas.
  web_search_queries: list[str] = dataclasses.field(default_factory=list)
  # One chunk per url_citation annotation seen so far.
  grounding_chunks: list[types.GroundingChunk] = dataclasses.field(
      default_factory=list
  )
  # Text segments paired with indices into ``grounding_chunks``.
  grounding_supports: list[types.GroundingSupport] = dataclasses.field(
      default_factory=list
  )
  # Most recent rendered search-suggestions entry point, if any was streamed.
  search_entry_point: types.SearchEntryPoint | None = None
716726

717727

718728
def _partial_part_response(
@@ -727,6 +737,18 @@ def _partial_part_response(
727737
)
728738

729739

740+
def _partial_grounding_response(
    grounding_metadata: types.GroundingMetadata, interaction_id: str | None
) -> LlmResponse:
  """Wrap incremental grounding data in a partial streaming ``LlmResponse``.

  Args:
    grounding_metadata: Grounding data to attach to the response
    interaction_id: Interaction id to stamp on the response, if known

  Returns:
    A response flagged ``partial=True`` and ``turn_complete=False`` that
    carries only the given grounding metadata and interaction id.
  """
  partial_response = LlmResponse(
      interaction_id=interaction_id,
      grounding_metadata=grounding_metadata,
      turn_complete=False,
      partial=True,
  )
  return partial_response
750+
751+
730752
def _handle_text(
731753
delta: StepDeltaData, state: _StreamState, interaction_id: str | None
732754
) -> LlmResponse | None:
@@ -862,6 +884,69 @@ def _handle_code_execution_result(
862884
return _partial_part_response(part, interaction_id)
863885

864886

887+
def _handle_google_search_call(
    delta: StepDeltaData, state: _StreamState, interaction_id: str | None
) -> LlmResponse | None:
  """Record the queries of a google_search call delta.

  The queries are appended to ``state.web_search_queries`` and also returned
  on their own in a partial grounding response.

  Args:
    delta: Delta whose ``arguments.queries`` holds the search queries
    state: Stream accumulator that is updated in place
    interaction_id: Interaction id to stamp on the response, if known

  Returns:
    A partial response carrying just this delta's queries, or None when the
    delta has no arguments or no queries.
  """
  call_arguments = delta.arguments
  if not call_arguments:
    return None
  search_queries = call_arguments.queries
  if not search_queries:
    return None

  state.web_search_queries.extend(search_queries)
  # Copy the queries so the emitted metadata does not alias the delta's list.
  incremental = types.GroundingMetadata(
      web_search_queries=list(search_queries)
  )
  return _partial_grounding_response(incremental, interaction_id)
896+
897+
898+
def _handle_google_search_result(
    delta: StepDeltaData, state: _StreamState, interaction_id: str | None
) -> LlmResponse | None:
  """Capture the search suggestions of a google_search result delta.

  The first non-empty ``search_suggestions`` found in ``delta.result`` is
  wrapped in a ``types.SearchEntryPoint``, stored on
  ``state.search_entry_point`` (replacing any earlier value) and returned in a
  partial grounding response.

  Args:
    delta: Delta whose ``result`` entries may carry ``search_suggestions``
    state: Stream accumulator that is updated in place
    interaction_id: Interaction id to stamp on the response, if known

  Returns:
    A partial response carrying the entry point, or None when no result has
    search suggestions.
  """
  suggestions = next(
      (
          item.search_suggestions
          for item in delta.result or []
          if item.search_suggestions
      ),
      None,
  )
  if not suggestions:
    return None

  state.search_entry_point = types.SearchEntryPoint(
      rendered_content=suggestions
  )
  incremental = types.GroundingMetadata(
      search_entry_point=state.search_entry_point
  )
  return _partial_grounding_response(incremental, interaction_id)
912+
913+
914+
def _handle_text_annotation(
    delta: StepDeltaData, state: _StreamState, interaction_id: str | None
) -> LlmResponse | None:
  """Turn url_citation annotations into grounding chunks and supports.

  Every annotation whose ``type`` is ``'url_citation'`` produces one
  ``types.GroundingChunk`` (its url and title) and one
  ``types.GroundingSupport`` (its start/end indices) that points at that
  chunk. Chunk indices continue from the chunks already held in ``state``, so
  they stay valid once the new items are appended there.

  Args:
    delta: Delta whose ``annotations`` are inspected
    state: Stream accumulator that is updated in place
    interaction_id: Interaction id to stamp on the response, if known

  Returns:
    A partial response carrying only the newly created chunks and supports,
    or None when the delta contains no url_citation annotation.
  """
  citations = [
      annotation
      for annotation in delta.annotations or []
      if getattr(annotation, 'type', None) == 'url_citation'
  ]
  if not citations:
    return None

  # New chunks are appended after the existing ones, so numbering starts here.
  first_index = len(state.grounding_chunks)
  fresh_chunks: list[types.GroundingChunk] = [
      types.GroundingChunk(
          web=types.GroundingChunkWeb(uri=citation.url, title=citation.title)
      )
      for citation in citations
  ]
  fresh_supports: list[types.GroundingSupport] = [
      types.GroundingSupport(
          segment=types.Segment(
              start_index=citation.start_index,
              end_index=citation.end_index,
          ),
          grounding_chunk_indices=[first_index + offset],
      )
      for offset, citation in enumerate(citations)
  ]

  state.grounding_chunks.extend(fresh_chunks)
  state.grounding_supports.extend(fresh_supports)
  incremental = types.GroundingMetadata(
      grounding_chunks=fresh_chunks,
      grounding_supports=fresh_supports,
  )
  return _partial_grounding_response(incremental, interaction_id)
948+
949+
865950
def _handle_function_result(
866951
delta: StepDeltaData, state: _StreamState, interaction_id: str | None
867952
) -> LlmResponse | None:
@@ -875,6 +960,24 @@ def _handle_function_result(
875960
return _partial_part_response(part, interaction_id)
876961

877962

963+
def _build_grounding_metadata(
    state: _StreamState,
) -> types.GroundingMetadata | None:
  """Assemble the grounding data accumulated in ``state``.

  Args:
    state: Stream accumulator holding the collected grounding fields

  Returns:
    A ``types.GroundingMetadata`` built from the accumulated queries, chunks,
    supports and search entry point, with empty lists passed as None. Returns
    None when none of the four fields holds any data.
  """
  collected = (
      state.web_search_queries,
      state.grounding_chunks,
      state.grounding_supports,
      state.search_entry_point,
  )
  if not any(collected):
    return None

  return types.GroundingMetadata(
      web_search_queries=state.web_search_queries or None,
      grounding_chunks=state.grounding_chunks or None,
      grounding_supports=state.grounding_supports or None,
      search_entry_point=state.search_entry_point,
  )
979+
980+
878981
def convert_interaction_event_to_llm_response(
879982
event: InteractionSSEEvent,
880983
state: _StreamState,
@@ -931,6 +1034,12 @@ def convert_interaction_event_to_llm_response(
9311034
return _handle_code_execution_call(delta, state, interaction_id)
9321035
elif delta_type == 'code_execution_result':
9331036
return _handle_code_execution_result(delta, state, interaction_id)
1037+
elif delta_type == 'google_search_call':
1038+
return _handle_google_search_call(delta, state, interaction_id)
1039+
elif delta_type == 'google_search_result':
1040+
return _handle_google_search_result(delta, state, interaction_id)
1041+
elif delta_type == 'text_annotation_delta':
1042+
return _handle_text_annotation(delta, state, interaction_id)
9341043
elif delta_type == 'function_result':
9351044
return _handle_function_result(delta, state, interaction_id)
9361045
else:
@@ -968,16 +1077,23 @@ def convert_interaction_event_to_llm_response(
9681077
return None
9691078

9701079
elif isinstance(event, InteractionCompletedEvent):
971-
# Final aggregated response
972-
if state.parts:
1080+
grounding_metadata = _build_grounding_metadata(state)
1081+
if state.parts or grounding_metadata is not None:
1082+
content = (
1083+
types.Content(role='model', parts=state.parts)
1084+
if state.parts
1085+
else None
1086+
)
9731087
return LlmResponse(
974-
content=types.Content(role='model', parts=state.parts),
1088+
content=content,
1089+
grounding_metadata=grounding_metadata,
1090+
usage_metadata=_usage_metadata_from_interaction(event.interaction),
9751091
partial=False,
9761092
turn_complete=True,
9771093
finish_reason=types.FinishReason.STOP,
9781094
interaction_id=interaction_id,
9791095
)
980-
# If no streaming parts were collected, convert the final interaction directly
1096+
# No streaming parts or grounding collected: convert the final interaction.
9811097
return convert_interaction_to_llm_response(event.interaction)
9821098

9831099
elif isinstance(event, Interaction):

0 commit comments

Comments
 (0)