Skip to content

Commit 4360ba4

Browse files
feat: add create_http_options to ContextCacheConfig for cache creation timeout
Adds a `create_http_options` field to `ContextCacheConfig` that is passed through to `CreateCachedContentConfig` when creating a cache. This allows users to set a timeout (or other HTTP options) on the CachedContent.create() call, which can take 30-40 seconds on Vertex AI. When the timeout is exceeded, cache creation fails gracefully and the request proceeds without caching. Replaces the previous `async_creation` approach, which required global in-memory state that didn't scale across instances. Fixes #4703
1 parent 452a9fa commit 4360ba4

4 files changed

Lines changed: 100 additions & 250 deletions

File tree

src/google/adk/agents/context_cache_config.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
from __future__ import annotations
1616

17+
from typing import Optional
18+
19+
from google.genai import types
1720
from pydantic import BaseModel
1821
from pydantic import ConfigDict
1922
from pydantic import Field
@@ -38,10 +41,12 @@ class ContextCacheConfig(BaseModel):
3841
cache_intervals: Maximum number of invocations to reuse the same cache before refreshing it
3942
ttl_seconds: Time-to-live for cache in seconds
4043
min_tokens: Minimum tokens required to enable caching
44+
create_http_options: HTTP options for cache creation API calls
4145
"""
4246

4347
model_config = ConfigDict(
4448
extra="forbid",
49+
arbitrary_types_allowed=True,
4550
)
4651

4752
cache_intervals: int = Field(
@@ -72,13 +77,15 @@ class ContextCacheConfig(BaseModel):
7277
),
7378
)
7479

75-
async_creation: bool = Field(
76-
default=False,
80+
create_http_options: Optional[types.HttpOptions] = Field(
81+
default=None,
7782
description=(
78-
"When True, cache creation is performed in the background instead of"
79-
" blocking the current request. The current request proceeds uncached"
80-
" and the cache is available for the next request. This eliminates"
81-
" latency spikes from slow CachedContent.create() API calls."
83+
"HTTP options for cache creation API calls. Use this to set a"
84+
" timeout on CachedContent.create() calls (e.g."
85+
" types.HttpOptions(timeout=10000) for a 10-second timeout in"
86+
" milliseconds). When the cache creation call exceeds the timeout,"
87+
" it fails and the request proceeds without caching. None uses the"
88+
" client's default HTTP options."
8289
),
8390
)
8491

@@ -92,5 +99,5 @@ def __str__(self) -> str:
9299
return (
93100
f"ContextCacheConfig(cache_intervals={self.cache_intervals}, "
94101
f"ttl={self.ttl_seconds}s, min_tokens={self.min_tokens}, "
95-
f"async_creation={self.async_creation})"
102+
f"create_http_options={self.create_http_options})"
96103
)

src/google/adk/models/gemini_context_cache_manager.py

Lines changed: 17 additions & 233 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from __future__ import annotations
1818

19-
import asyncio
2019
import hashlib
2120
import json
2221
import logging
@@ -33,72 +32,6 @@
3332

3433
logger = logging.getLogger("google_adk." + __name__)
3534

36-
# Background cache task registry for async_creation mode.
37-
# Key: (model, fingerprint, contents_count)
38-
# Value: asyncio.Task that resolves to Optional[CacheMetadata]
39-
_pending_cache_tasks: dict[tuple[str, str, int], asyncio.Task] = {}
40-
41-
42-
def _cache_task_key(
43-
model: str, fingerprint: str, contents_count: int
44-
) -> tuple[str, str, int]:
45-
"""Build a registry key for a pending cache task."""
46-
return (model, fingerprint, contents_count)
47-
48-
49-
def _check_pending_cache(
50-
key: tuple[str, str, int],
51-
) -> Optional[CacheMetadata]:
52-
"""Check if a background cache task completed successfully.
53-
54-
Returns the CacheMetadata if the task is done and succeeded,
55-
None if the task is still running or failed.
56-
Cleans up the registry entry in either done case.
57-
"""
58-
task = _pending_cache_tasks.get(key)
59-
if task is None:
60-
return None
61-
62-
if not task.done():
63-
return None
64-
65-
# Task is done - remove from registry regardless of outcome
66-
del _pending_cache_tasks[key]
67-
68-
if task.cancelled():
69-
logger.warning("Background cache task was cancelled for key %s", key)
70-
return None
71-
72-
exc = task.exception()
73-
if exc is not None:
74-
logger.warning("Background cache task failed for key %s: %s", key, exc)
75-
return None
76-
77-
return task.result()
78-
79-
80-
def _cleanup_stale_tasks() -> None:
81-
"""Remove completed tasks that have been sitting unclaimed."""
82-
to_remove = []
83-
for key, task in _pending_cache_tasks.items():
84-
if not task.done():
85-
continue
86-
result = None
87-
try:
88-
if not task.cancelled():
89-
exc = task.exception()
90-
if exc is None:
91-
result = task.result()
92-
except Exception:
93-
pass
94-
if result is None or (
95-
result.expire_time is not None and time.time() >= result.expire_time
96-
):
97-
to_remove.append(key)
98-
for key in to_remove:
99-
del _pending_cache_tasks[key]
100-
101-
10235
if TYPE_CHECKING:
10336
from google.genai import Client
10437

@@ -129,55 +62,13 @@ async def handle_context_caching(
12962
the cache to the request by setting cached_content and removing cached
13063
contents from the request.
13164
132-
When async_creation is enabled in the cache config, cache creation is
133-
performed in the background instead of blocking the current request.
134-
13565
Args:
13666
llm_request: Request that may contain cache config and metadata.
13767
Modified in-place to use the cache.
13868
13969
Returns:
14070
Cache metadata to be included in response, or None if caching failed
14171
"""
142-
async_creation = (
143-
llm_request.cache_config
144-
and llm_request.cache_config.async_creation
145-
)
146-
147-
# Opportunistically clean up stale background tasks
148-
if async_creation:
149-
_cleanup_stale_tasks()
150-
151-
# Check for completed background cache creation (async_creation mode)
152-
if (
153-
async_creation
154-
and llm_request.cache_metadata
155-
and llm_request.cache_metadata.cache_name is None
156-
):
157-
fp = llm_request.cache_metadata.fingerprint
158-
cc = llm_request.cache_metadata.contents_count
159-
model = llm_request.model or ""
160-
key = _cache_task_key(model, fp, cc)
161-
bg_result = _check_pending_cache(key)
162-
if bg_result and bg_result.cache_name:
163-
if time.time() < bg_result.expire_time:
164-
logger.info(
165-
"Using background-created cache: %s",
166-
bg_result.cache_name,
167-
)
168-
self._apply_cache_to_request(
169-
llm_request,
170-
bg_result.cache_name,
171-
bg_result.contents_count,
172-
)
173-
return bg_result
174-
else:
175-
logger.info(
176-
"Background-created cache already expired: %s",
177-
bg_result.cache_name,
178-
)
179-
await self.cleanup_cache(bg_result.cache_name)
180-
18172
# Check if we have existing cache metadata and if it's valid
18273
if llm_request.cache_metadata:
18374
logger.debug(
@@ -216,37 +107,19 @@ async def handle_context_caching(
216107

217108
# If fingerprints match, create new cache (expired but same content)
218109
if current_fingerprint == old_cache_metadata.fingerprint:
219-
if async_creation:
220-
# Launch background cache creation and proceed uncached
221-
key = _cache_task_key(
222-
llm_request.model or "",
223-
current_fingerprint,
110+
logger.debug(
111+
"Fingerprints match after invalidation, creating new cache"
112+
)
113+
cache_metadata = await self._create_new_cache_with_contents(
114+
llm_request, cache_contents_count
115+
)
116+
if cache_metadata:
117+
self._apply_cache_to_request(
118+
llm_request,
119+
cache_metadata.cache_name,
224120
cache_contents_count,
225121
)
226-
self._launch_background_cache(
227-
key, llm_request, cache_contents_count
228-
)
229-
logger.debug(
230-
"Async cache creation launched, proceeding uncached"
231-
)
232-
return CacheMetadata(
233-
fingerprint=current_fingerprint,
234-
contents_count=cache_contents_count,
235-
)
236-
else:
237-
logger.debug(
238-
"Fingerprints match after invalidation, creating new cache"
239-
)
240-
cache_metadata = await self._create_new_cache_with_contents(
241-
llm_request, cache_contents_count
242-
)
243-
if cache_metadata:
244-
self._apply_cache_to_request(
245-
llm_request,
246-
cache_metadata.cache_name,
247-
cache_contents_count,
248-
)
249-
return cache_metadata
122+
return cache_metadata
250123

251124
# Fingerprints don't match - recalculate with total contents
252125
logger.debug(
@@ -257,17 +130,6 @@ async def handle_context_caching(
257130
llm_request, total_contents_count
258131
)
259132

260-
if async_creation and total_contents_count > 0:
261-
# Launch background cache creation for the new fingerprint
262-
key = _cache_task_key(
263-
llm_request.model or "",
264-
fingerprint_for_all,
265-
total_contents_count,
266-
)
267-
self._launch_background_cache(
268-
key, llm_request, total_contents_count
269-
)
270-
271133
return CacheMetadata(
272134
fingerprint=fingerprint_for_all,
273135
contents_count=total_contents_count,
@@ -287,90 +149,6 @@ async def handle_context_caching(
287149
contents_count=total_contents_count,
288150
)
289151

290-
def _launch_background_cache(
291-
self,
292-
key: tuple[str, str, int],
293-
llm_request: LlmRequest,
294-
contents_count: int,
295-
) -> None:
296-
"""Launch cache creation as a background asyncio task.
297-
298-
Creates a snapshot of the request data needed for cache creation,
299-
then fires off the creation in a background task.
300-
301-
Args:
302-
key: Registry key for the pending task
303-
llm_request: Request to create cache for (will be snapshotted)
304-
contents_count: Number of contents to cache
305-
"""
306-
if key in _pending_cache_tasks:
307-
task = _pending_cache_tasks[key]
308-
if not task.done():
309-
logger.debug(
310-
"Background cache creation already in progress for key %s",
311-
key,
312-
)
313-
return
314-
del _pending_cache_tasks[key]
315-
316-
# Snapshot the request data before it gets mutated
317-
snapshot = self._snapshot_request(llm_request, contents_count)
318-
genai_client = self.genai_client
319-
320-
async def _do_create() -> Optional[CacheMetadata]:
321-
mgr = GeminiContextCacheManager(genai_client)
322-
return await mgr._create_new_cache_with_contents(
323-
snapshot, contents_count
324-
)
325-
326-
loop = asyncio.get_running_loop()
327-
task = loop.create_task(
328-
_do_create(),
329-
name=f"bg-cache-{key[1][:8]}",
330-
)
331-
_pending_cache_tasks[key] = task
332-
logger.info("Launched background cache creation for key %s", key)
333-
334-
def _snapshot_request(
335-
self,
336-
llm_request: LlmRequest,
337-
contents_count: int,
338-
) -> LlmRequest:
339-
"""Create a minimal snapshot of the request for background cache creation.
340-
341-
Captures only the fields that _create_gemini_cache needs, so the
342-
background task is not affected by mutations to the original request.
343-
344-
Args:
345-
llm_request: Original request to snapshot
346-
contents_count: Number of contents to include
347-
348-
Returns:
349-
A new LlmRequest with just the fields needed for cache creation
350-
"""
351-
config = types.GenerateContentConfig(
352-
system_instruction=(
353-
llm_request.config.system_instruction
354-
if llm_request.config
355-
else None
356-
),
357-
tools=(
358-
llm_request.config.tools if llm_request.config else None
359-
),
360-
tool_config=(
361-
llm_request.config.tool_config if llm_request.config else None
362-
),
363-
)
364-
return LlmRequest(
365-
model=llm_request.model,
366-
contents=list(llm_request.contents[:contents_count]),
367-
config=config,
368-
cache_config=llm_request.cache_config,
369-
cacheable_contents_token_count=(
370-
llm_request.cacheable_contents_token_count
371-
),
372-
)
373-
374152
def _find_count_of_contents_to_cache(
375153
self, contents: list[types.Content]
376154
) -> int:
@@ -611,6 +389,12 @@ async def _create_gemini_cache(
611389
if llm_request.config and llm_request.config.tool_config:
612390
cache_config.tool_config = llm_request.config.tool_config
613391

392+
# Pass through HTTP options (e.g. timeout) from cache config
393+
if llm_request.cache_config.create_http_options:
394+
cache_config.http_options = (
395+
llm_request.cache_config.create_http_options
396+
)
397+
614398
span.set_attribute("cache_contents_count", cache_contents_count)
615399
span.set_attribute("model", llm_request.model)
616400
span.set_attribute("ttl_seconds", llm_request.cache_config.ttl_seconds)

0 commit comments

Comments
 (0)