Skip to content

Commit ca5f1e9

Browse files
committed
fix: preserve cache_metadata and usage_metadata in VertexAiSessionService event round-trip
VertexAiSessionService was dropping cache_metadata and usage_metadata fields during Event serialization/deserialization. This caused ContextCacheRequestProcessor to never find previous cache metadata, creating a new cache on every LLM call instead of reusing existing ones. The fix adds cache_metadata and usage_metadata to the event_metadata dict during append_event (write path) and restores them in _from_api_event (read path), matching the behavior of other session service implementations. Fixes #4698
1 parent f973673 commit ca5f1e9

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from ..events.event import Event
3535
from ..events.event_actions import EventActions
3636
from ..events.event_actions import EventCompaction
37+
from ..models.cache_metadata import CacheMetadata
3738
from ..utils.vertex_ai_utils import get_express_mode_api_key
3839
from .base_session_service import BaseSessionService
3940
from .base_session_service import GetSessionConfig
@@ -303,6 +304,14 @@ async def append_event(self, session: Session, event: Event) -> Event:
303304
else None
304305
),
305306
}
307+
if event.usage_metadata:
308+
metadata_dict['usage_metadata'] = event.usage_metadata.model_dump(
309+
exclude_none=True, mode='json'
310+
)
311+
if event.cache_metadata:
312+
metadata_dict['cache_metadata'] = event.cache_metadata.model_dump(
313+
exclude_none=True, mode='json'
314+
)
306315
if event.grounding_metadata:
307316
metadata_dict['grounding_metadata'] = event.grounding_metadata.model_dump(
308317
exclude_none=True, mode='json'
@@ -427,6 +436,14 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
427436
getattr(event_metadata, 'grounding_metadata', None),
428437
types.GroundingMetadata,
429438
)
439+
usage_metadata = _session_util.decode_model(
440+
getattr(event_metadata, 'usage_metadata', None),
441+
types.GenerateContentResponseUsageMetadata,
442+
)
443+
cache_metadata = _session_util.decode_model(
444+
getattr(event_metadata, 'cache_metadata', None),
445+
CacheMetadata,
446+
)
430447
else:
431448
long_running_tool_ids = None
432449
partial = None
@@ -437,6 +454,8 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
437454
compaction_data = None
438455
usage_metadata_data = None
439456
grounding_metadata = None
457+
usage_metadata = None
458+
cache_metadata = None
440459

441460
if actions:
442461
actions_dict = actions.model_dump(exclude_none=True, mode='python')
@@ -478,6 +497,8 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
478497
branch=branch,
479498
custom_metadata=custom_metadata,
480499
grounding_metadata=grounding_metadata,
500+
usage_metadata=usage_metadata,
501+
cache_metadata=cache_metadata,
481502
long_running_tool_ids=long_running_tool_ids,
482503
usage_metadata=usage_metadata,
483504
)

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from google.adk.events.event import Event
2929
from google.adk.events.event_actions import EventActions
3030
from google.adk.events.event_actions import EventCompaction
31+
from google.adk.models.cache_metadata import CacheMetadata
3132
from google.adk.sessions.base_session_service import GetSessionConfig
3233
from google.adk.sessions.session import Session
3334
from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService
@@ -249,6 +250,8 @@ def _convert_to_object(data):
249250
'artifact_delta',
250251
'custom_metadata',
251252
'requested_auth_configs',
253+
'cache_metadata',
254+
'usage_metadata',
252255
]:
253256
kwargs[key] = value
254257
else:
@@ -1039,3 +1042,63 @@ async def test_append_event_with_usage_metadata_and_compaction():
10391042
assert appended_event.custom_metadata == {'extra': 'info'}
10401043
assert '_compaction' not in (appended_event.custom_metadata or {})
10411044
assert '_usage_metadata' not in (appended_event.custom_metadata or {})
1045+
1046+
1047+
@pytest.mark.asyncio
1048+
@pytest.mark.usefixtures('mock_get_api_client')
1049+
async def test_append_event_with_cache_and_usage_metadata():
1050+
"""cache_metadata and usage_metadata round-trip through append and get."""
1051+
session_service = mock_vertex_ai_session_service()
1052+
session = await session_service.get_session(
1053+
app_name='123', user_id='user', session_id='1'
1054+
)
1055+
assert session is not None
1056+
1057+
cache_meta = CacheMetadata(
1058+
cache_name='projects/123/locations/us-central1/cachedContents/456',
1059+
expire_time=9999999999.0,
1060+
fingerprint='abc123hash',
1061+
invocations_used=3,
1062+
contents_count=10,
1063+
created_at=1700000000.0,
1064+
)
1065+
usage_meta = genai_types.GenerateContentResponseUsageMetadata(
1066+
prompt_token_count=100,
1067+
candidates_token_count=50,
1068+
total_token_count=150,
1069+
cached_content_token_count=80,
1070+
)
1071+
event_to_append = Event(
1072+
invocation_id='cache_test_invocation',
1073+
author='model',
1074+
timestamp=1734005536.0,
1075+
content=genai_types.Content(
1076+
parts=[genai_types.Part(text='cached response')]
1077+
),
1078+
cache_metadata=cache_meta,
1079+
usage_metadata=usage_meta,
1080+
)
1081+
1082+
await session_service.append_event(session, event_to_append)
1083+
1084+
retrieved_session = await session_service.get_session(
1085+
app_name='123', user_id='user', session_id='1'
1086+
)
1087+
assert retrieved_session is not None
1088+
1089+
appended_event = retrieved_session.events[-1]
1090+
# cache_metadata is preserved
1091+
assert appended_event.cache_metadata is not None
1092+
assert appended_event.cache_metadata.cache_name == (
1093+
'projects/123/locations/us-central1/cachedContents/456'
1094+
)
1095+
assert appended_event.cache_metadata.fingerprint == 'abc123hash'
1096+
assert appended_event.cache_metadata.invocations_used == 3
1097+
assert appended_event.cache_metadata.contents_count == 10
1098+
assert appended_event.cache_metadata.created_at == 1700000000.0
1099+
# usage_metadata is preserved
1100+
assert appended_event.usage_metadata is not None
1101+
assert appended_event.usage_metadata.prompt_token_count == 100
1102+
assert appended_event.usage_metadata.candidates_token_count == 50
1103+
assert appended_event.usage_metadata.total_token_count == 150
1104+
assert appended_event.usage_metadata.cached_content_token_count == 80

0 commit comments

Comments
 (0)