Skip to content
13 changes: 12 additions & 1 deletion src/google/adk/agents/context_cache_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ class ContextCacheConfig(BaseModel):
),
)

# Opt-in switch for non-blocking cache creation; off by default, so cache
# creation stays on the request path unless this is explicitly enabled.
# Trade-off (per the description below): the request that triggers creation
# is served uncached.
async_creation: bool = Field(
default=False,
description=(
"When True, cache creation is performed in the background instead of"
" blocking the current request. The current request proceeds uncached"
" and the cache is available for the next request. This eliminates"
" latency spikes from slow CachedContent.create() API calls."
),
)

@property
def ttl_string(self) -> str:
"""Get TTL as string format for cache creation."""
Expand All @@ -81,5 +91,6 @@ def __str__(self) -> str:
"""String representation for logging."""
return (
f"ContextCacheConfig(cache_intervals={self.cache_intervals}, "
f"ttl={self.ttl_seconds}s, min_tokens={self.min_tokens})"
f"ttl={self.ttl_seconds}s, min_tokens={self.min_tokens}, "
f"async_creation={self.async_creation})"
)
245 changes: 235 additions & 10 deletions src/google/adk/models/gemini_context_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import asyncio
import hashlib
import json
import logging
Expand All @@ -32,6 +33,72 @@

logger = logging.getLogger("google_adk." + __name__)

# Background cache task registry for async_creation mode.
# Key: (model, fingerprint, contents_count), as built by _cache_task_key().
# Value: asyncio.Task that resolves to Optional[CacheMetadata].
# Lifecycle: _launch_background_cache() inserts an entry; a finished entry is
# removed either by _check_pending_cache() when a request claims it, or by
# _cleanup_stale_tasks() when the result is missing or already expired.
# NOTE(review): this dict is module-level, hence shared by every manager
# instance in the process, and has no lock — presumably it is only touched
# from event-loop code; confirm.
_pending_cache_tasks: dict[tuple[str, str, int], asyncio.Task] = {}


def _cache_task_key(
model: str, fingerprint: str, contents_count: int
) -> tuple[str, str, int]:
"""Build a registry key for a pending cache task."""
return (model, fingerprint, contents_count)


def _check_pending_cache(
    key: tuple[str, str, int],
) -> Optional[CacheMetadata]:
  """Claim the result of a finished background cache task, if there is one.

  Args:
    key: Registry key of the task to look up.

  Returns:
    The task's result when it finished without being cancelled and without
    raising; None when no task is registered under the key, the task is
    still running, it was cancelled, or it raised. A finished task is
    removed from the registry whatever its outcome; an unfinished one is
    left in place.
  """
  pending = _pending_cache_tasks.get(key)
  # Unknown key or still running: nothing to claim, registry untouched.
  if pending is None or not pending.done():
    return None

  # A finished task is claimed exactly once, regardless of how it ended.
  _pending_cache_tasks.pop(key)

  if pending.cancelled():
    logger.warning("Background cache task was cancelled for key %s", key)
    return None

  failure = pending.exception()
  if failure is None:
    return pending.result()

  logger.warning("Background cache task failed for key %s: %s", key, failure)
  return None


def _cleanup_stale_tasks() -> None:
  """Drop finished background tasks whose result can no longer be used.

  A finished task is removed from the registry when it was cancelled, when
  it raised, when it resolved to None, or when the cache it produced has
  already expired. Tasks that are still running, and finished tasks holding
  an unexpired result, are left in place so a later request can claim them.

  Cancelled and failed tasks are logged with the same warnings that
  _check_pending_cache emits, so a background failure is never discarded
  silently.
  """
  # One clock read so every task is judged against the same instant.
  now = time.time()
  # Collect first, delete afterwards: the dict must not change size while it
  # is being iterated.
  stale_keys = []
  for key, task in _pending_cache_tasks.items():
    if not task.done():
      continue

    if task.cancelled():
      logger.warning("Background cache task was cancelled for key %s", key)
      stale_keys.append(key)
      continue

    # Safe to call without a try/except: the task is done and not cancelled,
    # so exception() returns the stored exception (or None) without raising.
    exc = task.exception()
    if exc is not None:
      logger.warning("Background cache task failed for key %s: %s", key, exc)
      stale_keys.append(key)
      continue

    result = task.result()
    if result is None:
      stale_keys.append(key)
    elif result.expire_time is not None and now >= result.expire_time:
      stale_keys.append(key)

  for key in stale_keys:
    del _pending_cache_tasks[key]
Comment thread
abhinavmaddineni marked this conversation as resolved.
Outdated


if TYPE_CHECKING:
from google.genai import Client

Expand Down Expand Up @@ -62,13 +129,55 @@ async def handle_context_caching(
the cache to the request by setting cached_content and removing cached
contents from the request.

When async_creation is enabled in the cache config, cache creation is
performed in the background instead of blocking the current request.

Args:
llm_request: Request that may contain cache config and metadata.
Modified in-place to use the cache.

Returns:
Cache metadata to be included in response, or None if caching failed
"""
async_creation = (
llm_request.cache_config
and llm_request.cache_config.async_creation
)

# Opportunistically clean up stale background tasks
if async_creation:
_cleanup_stale_tasks()

# Check for completed background cache creation (async_creation mode)
if (
async_creation
and llm_request.cache_metadata
and llm_request.cache_metadata.cache_name is None
):
fp = llm_request.cache_metadata.fingerprint
cc = llm_request.cache_metadata.contents_count
model = llm_request.model or ""
key = _cache_task_key(model, fp, cc)
bg_result = _check_pending_cache(key)
if bg_result and bg_result.cache_name:
if time.time() < bg_result.expire_time:
logger.info(
"Using background-created cache: %s",
bg_result.cache_name,
)
self._apply_cache_to_request(
llm_request,
bg_result.cache_name,
bg_result.contents_count,
)
return bg_result
else:
logger.info(
"Background-created cache already expired: %s",
bg_result.cache_name,
)
await self.cleanup_cache(bg_result.cache_name)

# Check if we have existing cache metadata and if it's valid
if llm_request.cache_metadata:
logger.debug(
Expand Down Expand Up @@ -107,17 +216,37 @@ async def handle_context_caching(

# If fingerprints match, create new cache (expired but same content)
if current_fingerprint == old_cache_metadata.fingerprint:
logger.debug(
"Fingerprints match after invalidation, creating new cache"
)
cache_metadata = await self._create_new_cache_with_contents(
llm_request, cache_contents_count
)
if cache_metadata:
self._apply_cache_to_request(
llm_request, cache_metadata.cache_name, cache_contents_count
if async_creation:
# Launch background cache creation and proceed uncached
key = _cache_task_key(
llm_request.model or "",
current_fingerprint,
cache_contents_count,
)
self._launch_background_cache(
key, llm_request, cache_contents_count
)
logger.debug(
"Async cache creation launched, proceeding uncached"
)
return cache_metadata
return CacheMetadata(
fingerprint=current_fingerprint,
contents_count=cache_contents_count,
)
else:
logger.debug(
"Fingerprints match after invalidation, creating new cache"
)
cache_metadata = await self._create_new_cache_with_contents(
llm_request, cache_contents_count
)
if cache_metadata:
self._apply_cache_to_request(
llm_request,
cache_metadata.cache_name,
cache_contents_count,
)
return cache_metadata

# Fingerprints don't match - recalculate with total contents
logger.debug(
Expand All @@ -127,6 +256,18 @@ async def handle_context_caching(
fingerprint_for_all = self._generate_cache_fingerprint(
llm_request, total_contents_count
)

if async_creation and total_contents_count > 0:
# Launch background cache creation for the new fingerprint
key = _cache_task_key(
llm_request.model or "",
fingerprint_for_all,
total_contents_count,
)
self._launch_background_cache(
key, llm_request, total_contents_count
)

return CacheMetadata(
fingerprint=fingerprint_for_all,
contents_count=total_contents_count,
Expand All @@ -146,6 +287,90 @@ async def handle_context_caching(
contents_count=total_contents_count,
)

def _launch_background_cache(
    self,
    key: tuple[str, str, int],
    llm_request: LlmRequest,
    contents_count: int,
) -> None:
  """Start cache creation for the request as a background asyncio task.

  If a task for the same key is still running, nothing new is started. A
  finished task under the same key is discarded and replaced. The new task
  is stored in the module-level registry under the key.

  Args:
    key: Registry key for the pending task.
    llm_request: Request to create the cache for; a snapshot is taken
      before the task is scheduled.
    contents_count: Number of contents to cache.
  """
  existing = _pending_cache_tasks.get(key)
  if existing is not None:
    if not existing.done():
      logger.debug(
          "Background cache creation already in progress for key %s",
          key,
      )
      return
    # Finished entry left over from an earlier launch: replace it.
    _pending_cache_tasks.pop(key)

  # Freeze the inputs now; the caller keeps using the live request while
  # the task runs.
  frozen_request = self._snapshot_request(llm_request, contents_count)
  client = self.genai_client

  async def _do_create() -> Optional[CacheMetadata]:
    # The task runs on a fresh manager that wraps the same client.
    worker = GeminiContextCacheManager(client)
    return await worker._create_new_cache_with_contents(
        frozen_request, contents_count
    )

  # key[1] is the fingerprint; its first 8 characters label the task.
  _pending_cache_tasks[key] = asyncio.get_running_loop().create_task(
      _do_create(),
      name=f"bg-cache-{key[1][:8]}",
  )
  logger.info("Launched background cache creation for key %s", key)

def _snapshot_request(
    self,
    llm_request: LlmRequest,
    contents_count: int,
) -> LlmRequest:
  """Build a reduced copy of the request for background cache creation.

  Only the model, the first contents_count contents, the system
  instruction, tools, tool config, cache config and cacheable token count
  are carried over. The contents list is a new list, but its elements and
  the copied config values are the same objects the original request holds.

  Args:
    llm_request: Original request to snapshot.
    contents_count: Number of leading contents to include.

  Returns:
    A new LlmRequest holding just the fields listed above.
  """
  source_config = llm_request.config
  if source_config:
    system_instruction = source_config.system_instruction
    tools = source_config.tools
    tool_config = source_config.tool_config
  else:
    # No config on the request: every copied setting is None.
    system_instruction = tools = tool_config = None

  snapshot_config = types.GenerateContentConfig(
      system_instruction=system_instruction,
      tools=tools,
      tool_config=tool_config,
  )
  cached_prefix = list(llm_request.contents[:contents_count])
  token_count = llm_request.cacheable_contents_token_count
  return LlmRequest(
      model=llm_request.model,
      contents=cached_prefix,
      config=snapshot_config,
      cache_config=llm_request.cache_config,
      cacheable_contents_token_count=token_count,
  )

def _find_count_of_contents_to_cache(
self, contents: list[types.Content]
) -> int:
Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/agents/test_context_cache_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,15 @@ def test_str_representation(self):
)

expected = (
"ContextCacheConfig(cache_intervals=15, ttl=3600s, min_tokens=1024)"
"ContextCacheConfig(cache_intervals=15, ttl=3600s, min_tokens=1024, async_creation=False)"
)
assert str(config) == expected

def test_str_representation_defaults(self):
  """Test string representation with default values."""
  config = ContextCacheConfig()

  # A default config must also report async_creation, which defaults to
  # False. The earlier expectation without that field was a dead store,
  # overwritten before the assert, and has been dropped.
  expected = (
      "ContextCacheConfig(cache_intervals=10, ttl=1800s, min_tokens=0, "
      "async_creation=False)"
  )
  assert str(config) == expected

def test_pydantic_model_validation(self):
Expand Down