Skip to content

Commit eae3365

Browse files
committed
fix: preserve cache_metadata and usage_metadata in VertexAiSessionService event round-trip
VertexAiSessionService was dropping cache_metadata and usage_metadata fields during Event serialization/deserialization. This caused ContextCacheRequestProcessor to never find previous cache metadata, creating a new cache on every LLM call instead of reusing existing ones. The fix adds cache_metadata and usage_metadata to the event_metadata dict during append_event (write path) and restores them in _from_api_event (read path), matching the behavior of other session service implementations. Fixes #4698
1 parent 5770cd3 commit eae3365

File tree

2 files changed

+170
-0
lines changed

2 files changed

+170
-0
lines changed

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
from . import _session_util
3434
from ..events.event import Event
3535
from ..events.event_actions import EventActions
36+
from ..events.event_actions import EventCompaction
37+
from ..models.cache_metadata import CacheMetadata
3638
from ..utils.vertex_ai_utils import get_express_mode_api_key
3739
from .base_session_service import BaseSessionService
3840
from .base_session_service import GetSessionConfig
@@ -287,6 +289,14 @@ async def append_event(self, session: Session, event: Event) -> Event:
287289
else None
288290
),
289291
}
292+
if event.usage_metadata:
293+
metadata_dict['usage_metadata'] = event.usage_metadata.model_dump(
294+
exclude_none=True, mode='json'
295+
)
296+
if event.cache_metadata:
297+
metadata_dict['cache_metadata'] = event.cache_metadata.model_dump(
298+
exclude_none=True, mode='json'
299+
)
290300
if event.grounding_metadata:
291301
metadata_dict['grounding_metadata'] = event.grounding_metadata.model_dump(
292302
exclude_none=True, mode='json'
@@ -374,6 +384,14 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
374384
getattr(event_metadata, 'grounding_metadata', None),
375385
types.GroundingMetadata,
376386
)
387+
usage_metadata = _session_util.decode_model(
388+
getattr(event_metadata, 'usage_metadata', None),
389+
types.GenerateContentResponseUsageMetadata,
390+
)
391+
cache_metadata = _session_util.decode_model(
392+
getattr(event_metadata, 'cache_metadata', None),
393+
CacheMetadata,
394+
)
377395
else:
378396
long_running_tool_ids = None
379397
partial = None
@@ -382,6 +400,8 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
382400
branch = None
383401
custom_metadata = None
384402
grounding_metadata = None
403+
usage_metadata = None
404+
cache_metadata = None
385405

386406
return Event(
387407
id=api_event_obj.name.split('/')[-1],
@@ -400,5 +420,7 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
400420
branch=branch,
401421
custom_metadata=custom_metadata,
402422
grounding_metadata=grounding_metadata,
423+
usage_metadata=usage_metadata,
424+
cache_metadata=cache_metadata,
403425
long_running_tool_ids=long_running_tool_ids,
404426
)

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from google.adk.auth.auth_tool import AuthConfig
2828
from google.adk.events.event import Event
2929
from google.adk.events.event_actions import EventActions
30+
from google.adk.events.event_actions import EventCompaction
31+
from google.adk.models.cache_metadata import CacheMetadata
3032
from google.adk.sessions.base_session_service import GetSessionConfig
3133
from google.adk.sessions.session import Session
3234
from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService
@@ -248,6 +250,8 @@ def _convert_to_object(data):
248250
'artifact_delta',
249251
'custom_metadata',
250252
'requested_auth_configs',
253+
'cache_metadata',
254+
'usage_metadata',
251255
]:
252256
kwargs[key] = value
253257
else:
@@ -826,3 +830,147 @@ async def test_append_event():
826830
assert len(retrieved_session.events) == 2
827831
event_to_append.id = retrieved_session.events[1].id
828832
assert retrieved_session.events[1] == event_to_append
833+
834+
835+
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_append_event_with_compaction():
  """A compaction-only event is restored intact after append + get."""
  service = mock_vertex_ai_session_service()
  session_key = dict(app_name='123', user_id='user', session_id='1')

  session = await service.get_session(**session_key)
  assert session is not None

  # Build an event whose only payload is the compaction action.
  summary = genai_types.Content(
      parts=[genai_types.Part(text='compacted summary')]
  )
  event = Event(
      invocation_id='compaction_invocation',
      author='model',
      timestamp=1734005534.0,
      actions=EventActions(
          compaction=EventCompaction(
              start_timestamp=1000.0,
              end_timestamp=2000.0,
              compacted_content=summary,
          )
      ),
  )
  await service.append_event(session, event)

  reloaded = await service.get_session(**session_key)
  assert reloaded is not None

  restored_event = reloaded.events[-1]
  restored = restored_event.actions.compaction
  assert restored is not None
  assert restored.start_timestamp == 1000.0
  assert restored.end_timestamp == 2000.0
  assert restored.compacted_content.parts[0].text == 'compacted summary'
  # No user metadata was supplied, so none should appear on the way back.
  assert restored_event.custom_metadata is None
875+
876+
877+
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_append_event_with_compaction_and_custom_metadata():
  """Compaction and user custom_metadata both come back after append + get."""
  service = mock_vertex_ai_session_service()
  session_key = dict(app_name='123', user_id='user', session_id='1')

  session = await service.get_session(**session_key)
  assert session is not None

  # The event carries a compaction action AND user-provided metadata.
  summary = genai_types.Content(parts=[genai_types.Part(text='summary')])
  event = Event(
      invocation_id='compaction_and_meta_invocation',
      author='model',
      timestamp=1734005535.0,
      actions=EventActions(
          compaction=EventCompaction(
              start_timestamp=100.0,
              end_timestamp=200.0,
              compacted_content=summary,
          )
      ),
      custom_metadata={'user_key': 'user_value'},
  )
  await service.append_event(session, event)

  reloaded = await service.get_session(**session_key)
  assert reloaded is not None

  restored_event = reloaded.events[-1]

  # The compaction action is rebuilt on the restored event.
  restored = restored_event.actions.compaction
  assert restored is not None
  assert restored.start_timestamp == 100.0
  assert restored.end_timestamp == 200.0

  # User metadata comes back unchanged, with no internal '_compaction' key.
  assert restored_event.custom_metadata == {'user_key': 'user_value'}
  assert '_compaction' not in (restored_event.custom_metadata or {})
917+
918+
919+
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_append_event_with_cache_and_usage_metadata():
  """cache_metadata and usage_metadata round-trip through append and get.

  Every field set on the input metadata objects is asserted on the restored
  event, so a serializer that silently drops or alters any one of them
  (including ``expire_time``) fails this test.
  """
  session_service = mock_vertex_ai_session_service()
  session = await session_service.get_session(
      app_name='123', user_id='user', session_id='1'
  )
  assert session is not None

  cache_meta = CacheMetadata(
      cache_name='projects/123/locations/us-central1/cachedContents/456',
      expire_time=9999999999.0,
      fingerprint='abc123hash',
      invocations_used=3,
      contents_count=10,
      created_at=1700000000.0,
  )
  usage_meta = genai_types.GenerateContentResponseUsageMetadata(
      prompt_token_count=100,
      candidates_token_count=50,
      total_token_count=150,
      cached_content_token_count=80,
  )
  event_to_append = Event(
      invocation_id='cache_test_invocation',
      author='model',
      timestamp=1734005536.0,
      content=genai_types.Content(
          parts=[genai_types.Part(text='cached response')]
      ),
      cache_metadata=cache_meta,
      usage_metadata=usage_meta,
  )

  await session_service.append_event(session, event_to_append)

  retrieved_session = await session_service.get_session(
      app_name='123', user_id='user', session_id='1'
  )
  assert retrieved_session is not None

  appended_event = retrieved_session.events[-1]
  # cache_metadata is preserved
  assert appended_event.cache_metadata is not None
  assert appended_event.cache_metadata.cache_name == (
      'projects/123/locations/us-central1/cachedContents/456'
  )
  # expire_time was set on the input but previously never verified.
  assert appended_event.cache_metadata.expire_time == 9999999999.0
  assert appended_event.cache_metadata.fingerprint == 'abc123hash'
  assert appended_event.cache_metadata.invocations_used == 3
  assert appended_event.cache_metadata.contents_count == 10
  assert appended_event.cache_metadata.created_at == 1700000000.0
  # usage_metadata is preserved
  assert appended_event.usage_metadata is not None
  assert appended_event.usage_metadata.prompt_token_count == 100
  assert appended_event.usage_metadata.candidates_token_count == 50
  assert appended_event.usage_metadata.total_token_count == 150
  assert appended_event.usage_metadata.cached_content_token_count == 80

0 commit comments

Comments
 (0)