Skip to content

Commit a44b76d

Browse files
authored
Merge branch 'main' into fix/mcp-mtls-auth-header-case
2 parents f8eed05 + 6a50b8d commit a44b76d

27 files changed

Lines changed: 1310 additions & 220 deletions

src/google/adk/memory/vertex_ai_memory_bank_service.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from google.genai import types
2727
from typing_extensions import override
2828

29-
from ..utils._google_client_headers import get_tracking_headers
3029
from ..utils.vertex_ai_utils import get_express_mode_api_key
3130
from .base_memory_service import BaseMemoryService
3231
from .base_memory_service import SearchMemoryResponse
@@ -617,17 +616,9 @@ def _get_api_client(self) -> vertexai.AsyncClient:
617616
"""
618617
import vertexai
619618

620-
http_options = types.HttpOptions(headers=get_tracking_headers())
621619
if self._express_mode_api_key:
622-
return vertexai.Client(
623-
http_options=http_options,
624-
api_key=self._express_mode_api_key,
625-
).aio
626-
return vertexai.Client(
627-
project=self._project,
628-
location=self._location,
629-
http_options=http_options,
630-
).aio
620+
return vertexai.Client(api_key=self._express_mode_api_key).aio
621+
return vertexai.Client(project=self._project, location=self._location).aio
631622

632623

633624
def _log_ingest_task_error(task: asyncio.Task) -> None:

src/google/adk/models/gemini_context_cache_manager.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -326,11 +326,21 @@ async def _create_new_cache_with_contents(
326326
)
327327
return None
328328

329-
# Check client-side to avoid unnecessary API round-trips.
330-
if llm_request.cacheable_contents_token_count < _GEMINI_MIN_CACHE_TOKENS:
329+
# `cacheable_contents_token_count` is the token count of the whole previous
330+
# prompt (system instruction + tools + every content). The cache, however,
331+
# only stores the prefix `contents[:cache_contents_count]` plus the system
332+
# instruction and tools (see `_create_gemini_cache`). On a long conversation
333+
# the full-prompt count can clear Gemini's minimum while the cached prefix
334+
# is far smaller, which makes `caches.create` fail with 400
335+
# INVALID_ARGUMENT.
336+
# Gate on the estimated prefix size so we never send a sub-minimum payload.
337+
cacheable_prefix_tokens = self._estimate_cacheable_prefix_tokens(
338+
llm_request, cache_contents_count
339+
)
340+
if cacheable_prefix_tokens < _GEMINI_MIN_CACHE_TOKENS:
331341
logger.info(
332-
"Request below Gemini minimum cache size (%d < %d tokens)",
333-
llm_request.cacheable_contents_token_count,
342+
"Cacheable prefix below Gemini minimum cache size (%d < %d tokens)",
343+
cacheable_prefix_tokens,
334344
_GEMINI_MIN_CACHE_TOKENS,
335345
)
336346
return None
@@ -342,13 +352,20 @@ async def _create_new_cache_with_contents(
342352
logger.warning("Failed to create cache: %s", e)
343353
return None
344354

345-
def _estimate_request_tokens(self, llm_request: LlmRequest) -> int:
346-
"""Estimate token count for the request.
355+
def _estimate_request_tokens(
356+
self,
357+
llm_request: LlmRequest,
358+
cache_contents_count: Optional[int] = None,
359+
) -> int:
360+
"""Estimate token count for the request (or its cacheable prefix).
347361
348362
This is a rough estimation based on content text length.
349363
350364
Args:
351365
llm_request: Request to estimate tokens for
366+
cache_contents_count: When provided, only the first
367+
``cache_contents_count`` contents are counted (the prefix that gets
368+
cached); the system instruction and tools are always included.
352369
353370
Returns:
354371
Estimated token count
@@ -366,15 +383,54 @@ def _estimate_request_tokens(self, llm_request: LlmRequest) -> int:
366383
tool_str = json.dumps(tool.model_dump())
367384
total_chars += len(tool_str)
368385

369-
# Contents
370-
for content in llm_request.contents:
386+
# Contents (optionally limited to the cacheable prefix)
387+
contents = llm_request.contents
388+
if cache_contents_count is not None:
389+
contents = contents[:cache_contents_count]
390+
for content in contents:
371391
for part in content.parts:
372392
if part.text:
373393
total_chars += len(part.text)
374394

375395
# Rough estimate: 4 characters per token
376396
return total_chars // 4
377397

398+
def _estimate_cacheable_prefix_tokens(
    self, llm_request: LlmRequest, cache_contents_count: int
) -> int:
  """Estimate the token count of the prefix that will actually be cached.

  The only accurate token count available is
  ``cacheable_contents_token_count``, which covers the entire previous prompt.
  Since the cache stores just the prefix ``contents[:cache_contents_count]``
  (plus system instruction and tools), we scale that accurate count by the
  prefix's estimated share of the request. When the prefix already spans the
  whole request the accurate count is returned unchanged.

  The scaling is done with integer arithmetic. Going through a float ratio
  and truncating can undercount by one token (for example ``49 * (1 / 49)``
  evaluates to ``0.9999999999999999``), which would wrongly reject a prefix
  sitting exactly on the minimum-size threshold.

  Args:
    llm_request: Request to estimate the cacheable prefix tokens for
    cache_contents_count: Number of leading contents that get cached

  Returns:
    Estimated token count of the cacheable prefix; 0 when the request
    carries no accurate token count.
  """
  full_tokens = llm_request.cacheable_contents_token_count
  if not full_tokens:
    # Missing (None) or zero count: nothing to scale.
    return 0

  full_estimate = self._estimate_request_tokens(llm_request)
  if full_estimate <= 0:
    # No text to estimate from (e.g. non-text parts); fall back to the
    # accurate full count rather than incorrectly skipping the cache.
    return full_tokens

  prefix_estimate = self._estimate_request_tokens(
      llm_request, cache_contents_count
  )
  # Clamp so the prefix can never be reported as larger than the full prompt.
  capped_prefix = min(prefix_estimate, full_estimate)
  # Multiply before dividing to keep the result exact (floor of the true
  # proportional share) instead of truncating a rounded float product.
  return int(full_tokens * capped_prefix // full_estimate)
433+
378434
async def _create_gemini_cache(
379435
self, llm_request: LlmRequest, cache_contents_count: int
380436
) -> CacheMetadata:

0 commit comments

Comments
 (0)