Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/google/adk/models/cache_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class CacheMetadata(BaseModel):
None when no active cache exists.
contents_count: Number of contents. When active cache exists, this is
the count of cached contents. When no active cache exists, this is
the total count of contents in the request.
the count of the cacheable content prefix used for fingerprinting.
created_at: Unix timestamp when the cache was created. None when
no active cache exists.
"""
Expand Down Expand Up @@ -87,7 +87,7 @@ class CacheMetadata(BaseModel):
ge=0,
description=(
"Number of contents (cached contents when active cache exists, "
"total contents in request when no active cache)"
"cacheable content prefix when no active cache)"
),
)

Expand Down
24 changes: 15 additions & 9 deletions src/google/adk/models/gemini_context_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,31 +135,37 @@ async def handle_context_caching(
contents_count=cache_contents_count,
)

# Fingerprints don't match - recalculate with total contents
# Fingerprints don't match - recalculate with the current cacheable
# prefix. Request-scoped user contents, such as dynamic instructions,
# should not become part of the fingerprint-only chain.
logger.debug(
"Fingerprints don't match, returning fingerprint-only metadata"
)
total_contents_count = len(llm_request.contents)
fingerprint_for_all = self._generate_cache_fingerprint(
llm_request, total_contents_count
cache_contents_count = self._find_count_of_contents_to_cache(
llm_request.contents
)
fingerprint = self._generate_cache_fingerprint(
llm_request, cache_contents_count
)
return CacheMetadata(
fingerprint=fingerprint_for_all,
contents_count=total_contents_count,
fingerprint=fingerprint,
contents_count=cache_contents_count,
)

# No existing cache metadata - return fingerprint-only metadata
# We don't create cache without previous fingerprint to match
logger.debug(
"No existing cache metadata, creating fingerprint-only metadata"
)
total_contents_count = len(llm_request.contents)
cache_contents_count = self._find_count_of_contents_to_cache(
llm_request.contents
)
fingerprint = self._generate_cache_fingerprint(
llm_request, total_contents_count
llm_request, cache_contents_count
)
return CacheMetadata(
fingerprint=fingerprint,
contents_count=total_contents_count,
contents_count=cache_contents_count,
)

def _find_count_of_contents_to_cache(
Expand Down
87 changes: 77 additions & 10 deletions tests/unittests/agents/test_gemini_context_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async def test_handle_context_caching_no_existing_cache(self):
assert result.invocations_used is None
assert result.created_at is None
assert result.fingerprint == "test_fp"
assert result.contents_count == 5 # Total contents count
assert result.contents_count == 0

# No cache should be created
self.manager.genai_client.aio.caches.create.assert_not_called()
Expand Down Expand Up @@ -233,7 +233,7 @@ async def test_handle_context_caching_invalid_cache_fingerprint_mismatch(
assert result.invocations_used is None
assert result.created_at is None
assert result.fingerprint == "new_fp"
assert result.contents_count == 5 # Total contents count
assert result.contents_count == 0
mock_cleanup.assert_called_once_with(existing_cache.cache_name)
self.manager.genai_client.aio.caches.create.assert_not_called()

Expand Down Expand Up @@ -584,7 +584,7 @@ async def test_cache_creation_with_sufficient_token_count(self):
assert result is not None
assert result.cache_name is None # Fingerprint-only state
assert result.fingerprint == "test_fp"
assert result.contents_count == 3
assert result.contents_count == 0
self.manager.genai_client.aio.caches.create.assert_not_called()

async def test_cache_creation_with_insufficient_token_count(self):
Expand Down Expand Up @@ -752,12 +752,10 @@ async def test_multi_turn_fingerprint_stable_when_below_token_threshold(
contents_counts_seen.append(result.contents_count)
metadata = result

# First turn has no metadata, so uses total (1).
# Subsequent turns preserve contents_count=1 from the prefix.
# Fingerprint stays stable because contents[:1] is always the
# same user message.
# All contents in this helper are user-role messages, so there is no
# cacheable content prefix before the final user batch.
assert len(set(fingerprints_seen)) == 1
assert contents_counts_seen == [1, 1, 1]
assert contents_counts_seen == [0, 0, 0]

async def test_contents_count_should_remain_stable_after_cache_creation_failure(
self,
Expand Down Expand Up @@ -911,7 +909,7 @@ async def test_fingerprint_only_metadata_transitions_to_active_cache(

assert result_1 is not None
assert result_1.cache_name is None
assert result_1.contents_count == 3
assert result_1.contents_count == 0

# --- Second LLM call: carry forward fingerprint-only metadata ---
# Contents grew but we still have same prefix
Expand Down Expand Up @@ -948,6 +946,75 @@ async def test_fingerprint_only_metadata_transitions_to_active_cache(
assert result_2.cache_name == (
"projects/test/locations/us-central1/cachedContents/new789"
)
assert result_2.contents_count == 3 # Preserved from prefix
assert result_2.contents_count == 0 # Preserved from prefix
assert result_2.invocations_used == 1
self.manager.genai_client.aio.caches.create.assert_called_once()

async def test_dynamic_instruction_does_not_break_initial_cache_fingerprint(
self,
):
"""Request-scoped dynamic instructions stay out of the cache prefix."""
dynamic_instruction = types.Content(
role="user", parts=[types.Part(text="Turn context: locale=en-US")]
)
user_msg = types.Content(
role="user", parts=[types.Part(text="what time is it?")]
)
model_tool_call = types.Content(
role="model",
parts=[
types.Part(
function_call=types.FunctionCall(
name="get_time", args={}
)
)
],
)
tool_response = types.Content(
role="user",
parts=[
types.Part(
function_response=types.FunctionResponse(
name="get_time", response={"time": "12:00"}
)
)
],
)

request_1 = self.create_llm_request(contents_count=0)
request_1.contents = [dynamic_instruction, user_msg]

result_1 = await self.manager.handle_context_caching(request_1)

assert result_1 is not None
assert result_1.cache_name is None
assert result_1.contents_count == 0

request_2 = self.create_llm_request(
cache_metadata=result_1, contents_count=0
)
request_2.contents = [
user_msg,
model_tool_call,
dynamic_instruction,
tool_response,
]
request_2.cacheable_contents_token_count = 4096

mock_cached_content = AsyncMock()
mock_cached_content.name = (
"projects/test/locations/us-central1/cachedContents/new789"
)
self.manager.genai_client.aio.caches.create = AsyncMock(
return_value=mock_cached_content
)

result_2 = await self.manager.handle_context_caching(request_2)

assert result_2 is not None
assert result_2.cache_name == (
"projects/test/locations/us-central1/cachedContents/new789"
)
assert result_2.contents_count == 0
assert result_2.invocations_used == 1
self.manager.genai_client.aio.caches.create.assert_called_once()