Skip to content

Commit c08fee5

Browse files
google-genai-botGWeale
authored andcommitted
feat: Use raw_event to store event data in vertex ai session service
PiperOrigin-RevId: 895379069 Change-Id: I30002d075e261e20d76b0b7ec9eaba97ee6dc656
1 parent ffa62d6 commit c08fee5

File tree

2 files changed

+216
-9
lines changed

2 files changed

+216
-9
lines changed

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import annotations
1515

1616
import asyncio
17+
import copy
1718
import datetime
1819
import logging
1920
import re
@@ -24,6 +25,7 @@
2425

2526
from google.genai import types
2627
from google.genai.errors import ClientError
28+
import pydantic
2729
from typing_extensions import override
2830

2931
if TYPE_CHECKING:
@@ -333,17 +335,41 @@ async def append_event(self, session: Session, event: Event) -> Event:
333335
value=usage_dict,
334336
)
335337
config['event_metadata'] = metadata_dict
338+
config['raw_event'] = event.model_dump(
339+
exclude_none=True,
340+
mode='json',
341+
by_alias=True,
342+
)
336343

344+
# Retry without raw_event if client side validation fails for older SDK
345+
# versions.
337346
async with self._get_api_client() as api_client:
338-
await api_client.agent_engines.sessions.events.append(
339-
name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}',
340-
author=event.author,
341-
invocation_id=event.invocation_id,
342-
timestamp=datetime.datetime.fromtimestamp(
343-
event.timestamp, tz=datetime.timezone.utc
344-
),
345-
config=config,
346-
)
347+
try:
348+
await api_client.agent_engines.sessions.events.append(
349+
name=(
350+
f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}'
351+
),
352+
author=event.author,
353+
invocation_id=event.invocation_id,
354+
timestamp=datetime.datetime.fromtimestamp(
355+
event.timestamp, tz=datetime.timezone.utc
356+
),
357+
config=config,
358+
)
359+
except pydantic.ValidationError:
360+
if 'raw_event' in config:
361+
del config['raw_event']
362+
await api_client.agent_engines.sessions.events.append(
363+
name=(
364+
f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}'
365+
),
366+
author=event.author,
367+
invocation_id=event.invocation_id,
368+
timestamp=datetime.datetime.fromtimestamp(
369+
event.timestamp, tz=datetime.timezone.utc
370+
),
371+
config=config,
372+
)
347373
return event
348374

349375
def _get_reasoning_engine_id(self, app_name: str):
@@ -389,8 +415,33 @@ def _get_api_client(self) -> vertexai.AsyncClient:
389415
).aio
390416

391417

418+
def _get_raw_event(api_event_obj: Any) -> dict[str, Any] | None:
419+
"""Extracts raw_event dict from SessionEvent object safely."""
420+
try:
421+
return api_event_obj.raw_event
422+
except AttributeError:
423+
try:
424+
return api_event_obj.rawEvent
425+
except AttributeError:
426+
return None
427+
428+
392429
def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
393430
"""Converts an API event object to an Event object."""
431+
# Read event data from raw_event first before falling back to top level
432+
# fields.
433+
raw_event_dict = _get_raw_event(api_event_obj)
434+
if raw_event_dict:
435+
event_dict = copy.deepcopy(raw_event_dict)
436+
timestamp_obj = getattr(api_event_obj, 'timestamp', None)
437+
event_dict.update({
438+
'id': api_event_obj.name.split('/')[-1],
439+
'invocation_id': getattr(api_event_obj, 'invocation_id', None),
440+
'author': getattr(api_event_obj, 'author', None),
441+
'timestamp': timestamp_obj.timestamp() if timestamp_obj else None,
442+
})
443+
return Event.model_validate(event_dict)
444+
394445
actions = getattr(api_event_obj, 'actions', None)
395446
event_metadata = getattr(api_event_obj, 'event_metadata', None)
396447
if event_metadata:

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from google.adk.events.event import Event
2929
from google.adk.events.event_actions import EventActions
3030
from google.adk.events.event_actions import EventCompaction
31+
from google.adk.models.cache_metadata import CacheMetadata
3132
from google.adk.sessions.base_session_service import GetSessionConfig
3233
from google.adk.sessions.session import Session
3334
from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService
@@ -91,6 +92,7 @@
9192
'branch': '',
9293
'long_running_tool_ids': ['tool1'],
9394
},
95+
'raw_event': {},
9496
},
9597
]
9698
MOCK_EVENT_JSON_2 = [
@@ -162,6 +164,96 @@ def _generate_mock_events_for_session_5(num_events):
162164
MANY_EVENTS_COUNT = 200
163165
MOCK_EVENTS_JSON_5 = _generate_mock_events_for_session_5(MANY_EVENTS_COUNT)
164166

167+
MOCK_EVENT_WITH_OVERRIDE_JSON = [{
168+
'name': (
169+
'projects/test-project/locations/test-location/'
170+
'reasoningEngines/123/sessions/override/events/1'
171+
),
172+
'invocationId': 'override_invoke',
173+
'author': 'user_with_override',
174+
'timestamp': '2024-12-12T12:12:12.123456Z',
175+
'content': {
176+
'parts': [
177+
{'text': 'top_level_content'},
178+
],
179+
},
180+
'actions': {
181+
'transferToAgent': 'top_level_agent',
182+
},
183+
'eventMetadata': {
184+
'partial': True,
185+
'turnComplete': False,
186+
'interrupted': False,
187+
'branch': 'top_level_branch',
188+
},
189+
'errorCode': '111',
190+
'errorMessage': 'top_level_error',
191+
'rawEvent': {
192+
'invocationId': 'wrong_invocation_id',
193+
'author': 'wrong_author',
194+
'content': {
195+
'parts': [
196+
{'text': 'raw_event_content'},
197+
],
198+
},
199+
'actions': {
200+
'transferToAgent': 'raw_event_agent',
201+
},
202+
'partial': False,
203+
'turnComplete': True,
204+
'interrupted': True,
205+
'branch': 'raw_event_branch',
206+
'errorCode': '222',
207+
'errorMessage': 'raw_event_error',
208+
},
209+
}]
210+
211+
MOCK_EVENT_WITH_OVERRIDE_JSON_2 = [{
212+
'name': (
213+
'projects/test-project/locations/test-location/'
214+
'reasoningEngines/123/sessions/override/events/1'
215+
),
216+
'invocationId': 'override_invoke',
217+
'author': 'user_with_override',
218+
'content': {},
219+
'actions': {},
220+
'timestamp': '2024-12-12T12:12:12.123456Z',
221+
'rawEvent': {
222+
'invocationId': 'wrong_invocation_id',
223+
'author': 'wrong_author',
224+
'content': {
225+
'parts': [
226+
{'text': 'raw_event_content'},
227+
],
228+
},
229+
'actions': {
230+
'skipSummarization': None,
231+
'stateDelta': {},
232+
'artifactDelta': {},
233+
'transferToAgent': 'raw_event_agent',
234+
'escalate': None,
235+
'requestedAuthConfigs': {},
236+
},
237+
'errorCode': '222',
238+
'errorMessage': 'raw_event_error',
239+
'partial': False,
240+
'turnComplete': True,
241+
'interrupted': True,
242+
'branch': 'raw_event_branch',
243+
'customMetadata': None,
244+
'longRunningToolIds': None,
245+
},
246+
}]
247+
248+
MOCK_SESSION_WITH_OVERRIDE_JSON = {
249+
'name': (
250+
'projects/test-project/locations/test-location/'
251+
'reasoningEngines/123/sessions/override'
252+
),
253+
'update_time': '2024-12-12T12:12:12.123456Z',
254+
'user_id': 'user_with_override',
255+
}
256+
165257
MOCK_SESSION = Session(
166258
app_name='123',
167259
user_id='user',
@@ -249,6 +341,8 @@ def _convert_to_object(data):
249341
'artifact_delta',
250342
'custom_metadata',
251343
'requested_auth_configs',
344+
'rawEvent',
345+
'raw_event',
252346
]:
253347
kwargs[key] = value
254348
else:
@@ -683,6 +777,38 @@ async def test_get_session_keeps_events_newer_than_update_time(
683777
)
684778

685779

780+
@pytest.mark.asyncio
781+
@pytest.mark.usefixtures('mock_get_api_client')
782+
@pytest.mark.parametrize(
783+
'mock_event_json',
784+
[MOCK_EVENT_WITH_OVERRIDE_JSON, MOCK_EVENT_WITH_OVERRIDE_JSON_2],
785+
)
786+
async def test_get_session_from_raw_event(
787+
mock_api_client_instance: MockAsyncClient,
788+
mock_event_json,
789+
) -> None:
790+
mock_api_client_instance.session_dict['6'] = MOCK_SESSION_WITH_OVERRIDE_JSON
791+
mock_api_client_instance.event_dict['6'] = (
792+
copy.deepcopy(mock_event_json),
793+
None,
794+
)
795+
session_service = mock_vertex_ai_session_service()
796+
session = await session_service.get_session(
797+
app_name='123', user_id='user_with_override', session_id='6'
798+
)
799+
assert session is not None
800+
assert len(session.events) == 1
801+
event = session.events[0]
802+
assert event.content.parts[0].text == 'raw_event_content'
803+
assert event.actions.transfer_to_agent == 'raw_event_agent'
804+
assert not event.partial
805+
assert event.turn_complete
806+
assert event.interrupted
807+
assert event.branch == 'raw_event_branch'
808+
assert event.error_code == '222'
809+
assert event.error_message == 'raw_event_error'
810+
811+
686812
@pytest.mark.asyncio
687813
@pytest.mark.usefixtures('mock_get_api_client')
688814
async def test_get_session_with_many_events(mock_api_client_instance):
@@ -844,6 +970,36 @@ async def test_append_event():
844970
branch='test_branch',
845971
custom_metadata={'custom': 'data'},
846972
long_running_tool_ids={'tool2'},
973+
input_transcription=genai_types.Transcription(
974+
text='test_input_transcription'
975+
),
976+
output_transcription=genai_types.Transcription(
977+
text='test_output_transcription'
978+
),
979+
model_version='test_model_version',
980+
avg_logprobs=0.5,
981+
logprobs_result=genai_types.LogprobsResult(
982+
chosen_candidates=[
983+
genai_types.LogprobsResultCandidate(
984+
log_probability=0.5,
985+
token='test_token',
986+
token_id=0,
987+
)
988+
]
989+
),
990+
cache_metadata=CacheMetadata(
991+
cache_name='test_cache_name',
992+
fingerprint='test_fingerprint',
993+
contents_count=1,
994+
),
995+
citation_metadata=genai_types.CitationMetadata(
996+
citations=[
997+
genai_types.Citation(
998+
uri='http://test.com',
999+
title='test_title',
1000+
)
1001+
]
1002+
),
8471003
)
8481004

8491005
await session_service.append_event(session_before_append, event_to_append)

0 commit comments

Comments
 (0)