Skip to content

Commit ddabbe7

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Use raw_event to store event data in vertex ai session service
PiperOrigin-RevId: 893180334
1 parent 8850916 commit ddabbe7

File tree

2 files changed

+216
-9
lines changed

2 files changed

+216
-9
lines changed

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import annotations
1515

1616
import asyncio
17+
import copy
1718
import datetime
1819
import json
1920
import logging
@@ -25,6 +26,7 @@
2526

2627
from google.genai import types
2728
from google.genai.errors import ClientError
29+
import pydantic
2830
from typing_extensions import override
2931

3032
if TYPE_CHECKING:
@@ -339,17 +341,41 @@ async def append_event(self, session: Session, event: Event) -> Event:
339341
value=usage_dict,
340342
)
341343
config['event_metadata'] = metadata_dict
344+
config['raw_event'] = event.model_dump(
345+
exclude_none=True,
346+
mode='json',
347+
by_alias=True,
348+
)
342349

350+
# Retry without raw_event if client side validation fails for older SDK
351+
# versions.
343352
async with self._get_api_client() as api_client:
344-
await api_client.agent_engines.sessions.events.append(
345-
name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}',
346-
author=event.author,
347-
invocation_id=event.invocation_id,
348-
timestamp=datetime.datetime.fromtimestamp(
349-
event.timestamp, tz=datetime.timezone.utc
350-
),
351-
config=config,
352-
)
353+
try:
354+
await api_client.agent_engines.sessions.events.append(
355+
name=(
356+
f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}'
357+
),
358+
author=event.author,
359+
invocation_id=event.invocation_id,
360+
timestamp=datetime.datetime.fromtimestamp(
361+
event.timestamp, tz=datetime.timezone.utc
362+
),
363+
config=config,
364+
)
365+
except pydantic.ValidationError:
366+
if 'raw_event' in config:
367+
del config['raw_event']
368+
await api_client.agent_engines.sessions.events.append(
369+
name=(
370+
f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}'
371+
),
372+
author=event.author,
373+
invocation_id=event.invocation_id,
374+
timestamp=datetime.datetime.fromtimestamp(
375+
event.timestamp, tz=datetime.timezone.utc
376+
),
377+
config=config,
378+
)
353379
return event
354380

355381
def _get_reasoning_engine_id(self, app_name: str):
@@ -395,8 +421,33 @@ def _get_api_client(self) -> vertexai.AsyncClient:
395421
).aio
396422

397423

424+
def _get_raw_event(api_event_obj: Any) -> dict[str, Any] | None:
425+
"""Extracts raw_event dict from SessionEvent object safely."""
426+
try:
427+
return api_event_obj.raw_event
428+
except AttributeError:
429+
try:
430+
return api_event_obj.rawEvent
431+
except AttributeError:
432+
return None
433+
434+
398435
def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
399436
"""Converts an API event object to an Event object."""
437+
# Read event data from raw_event first before falling back to top level
438+
# fields.
439+
raw_event_dict = _get_raw_event(api_event_obj)
440+
if raw_event_dict:
441+
event_dict = copy.deepcopy(raw_event_dict)
442+
timestamp_obj = getattr(api_event_obj, 'timestamp', None)
443+
event_dict.update({
444+
'id': api_event_obj.name.split('/')[-1],
445+
'invocation_id': getattr(api_event_obj, 'invocation_id', None),
446+
'author': getattr(api_event_obj, 'author', None),
447+
'timestamp': timestamp_obj.timestamp() if timestamp_obj else None,
448+
})
449+
return Event.model_validate(event_dict)
450+
400451
actions = getattr(api_event_obj, 'actions', None)
401452
event_metadata = getattr(api_event_obj, 'event_metadata', None)
402453
if event_metadata:

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from google.adk.events.event import Event
2929
from google.adk.events.event_actions import EventActions
3030
from google.adk.events.event_actions import EventCompaction
31+
from google.adk.models.cache_metadata import CacheMetadata
3132
from google.adk.sessions.base_session_service import GetSessionConfig
3233
from google.adk.sessions.session import Session
3334
from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService
@@ -91,6 +92,7 @@
9192
'branch': '',
9293
'long_running_tool_ids': ['tool1'],
9394
},
95+
'raw_event': {},
9496
},
9597
]
9698
MOCK_EVENT_JSON_2 = [
@@ -162,6 +164,96 @@ def _generate_mock_events_for_session_5(num_events):
162164
MANY_EVENTS_COUNT = 200
163165
MOCK_EVENTS_JSON_5 = _generate_mock_events_for_session_5(MANY_EVENTS_COUNT)
164166

167+
MOCK_EVENT_WITH_OVERRIDE_JSON = [{
168+
'name': (
169+
'projects/test-project/locations/test-location/'
170+
'reasoningEngines/123/sessions/override/events/1'
171+
),
172+
'invocationId': 'override_invoke',
173+
'author': 'user_with_override',
174+
'timestamp': '2024-12-12T12:12:12.123456Z',
175+
'content': {
176+
'parts': [
177+
{'text': 'top_level_content'},
178+
],
179+
},
180+
'actions': {
181+
'transferToAgent': 'top_level_agent',
182+
},
183+
'eventMetadata': {
184+
'partial': True,
185+
'turnComplete': False,
186+
'interrupted': False,
187+
'branch': 'top_level_branch',
188+
},
189+
'errorCode': '111',
190+
'errorMessage': 'top_level_error',
191+
'rawEvent': {
192+
'invocationId': 'wrong_invocation_id',
193+
'author': 'wrong_author',
194+
'content': {
195+
'parts': [
196+
{'text': 'raw_event_content'},
197+
],
198+
},
199+
'actions': {
200+
'transferToAgent': 'raw_event_agent',
201+
},
202+
'partial': False,
203+
'turnComplete': True,
204+
'interrupted': True,
205+
'branch': 'raw_event_branch',
206+
'errorCode': '222',
207+
'errorMessage': 'raw_event_error',
208+
},
209+
}]
210+
211+
MOCK_EVENT_WITH_OVERRIDE_JSON_2 = [{
212+
'name': (
213+
'projects/test-project/locations/test-location/'
214+
'reasoningEngines/123/sessions/override/events/1'
215+
),
216+
'invocationId': 'override_invoke',
217+
'author': 'user_with_override',
218+
'content': {},
219+
'actions': {},
220+
'timestamp': '2024-12-12T12:12:12.123456Z',
221+
'rawEvent': {
222+
'invocationId': 'wrong_invocation_id',
223+
'author': 'wrong_author',
224+
'content': {
225+
'parts': [
226+
{'text': 'raw_event_content'},
227+
],
228+
},
229+
'actions': {
230+
'skipSummarization': None,
231+
'stateDelta': {},
232+
'artifactDelta': {},
233+
'transferToAgent': 'raw_event_agent',
234+
'escalate': None,
235+
'requestedAuthConfigs': {},
236+
},
237+
'errorCode': '222',
238+
'errorMessage': 'raw_event_error',
239+
'partial': False,
240+
'turnComplete': True,
241+
'interrupted': True,
242+
'branch': 'raw_event_branch',
243+
'customMetadata': None,
244+
'longRunningToolIds': None,
245+
},
246+
}]
247+
248+
MOCK_SESSION_WITH_OVERRIDE_JSON = {
249+
'name': (
250+
'projects/test-project/locations/test-location/'
251+
'reasoningEngines/123/sessions/override'
252+
),
253+
'update_time': '2024-12-12T12:12:12.123456Z',
254+
'user_id': 'user_with_override',
255+
}
256+
165257
MOCK_SESSION = Session(
166258
app_name='123',
167259
user_id='user',
@@ -249,6 +341,8 @@ def _convert_to_object(data):
249341
'artifact_delta',
250342
'custom_metadata',
251343
'requested_auth_configs',
344+
'rawEvent',
345+
'raw_event',
252346
]:
253347
kwargs[key] = value
254348
else:
@@ -680,6 +774,38 @@ async def test_get_session_keeps_events_newer_than_update_time(
680774
)
681775

682776

777+
@pytest.mark.asyncio
778+
@pytest.mark.usefixtures('mock_get_api_client')
779+
@pytest.mark.parametrize(
780+
'mock_event_json',
781+
[MOCK_EVENT_WITH_OVERRIDE_JSON, MOCK_EVENT_WITH_OVERRIDE_JSON_2],
782+
)
783+
async def test_get_session_from_raw_event(
784+
mock_api_client_instance: MockAsyncClient,
785+
mock_event_json,
786+
) -> None:
787+
mock_api_client_instance.session_dict['6'] = MOCK_SESSION_WITH_OVERRIDE_JSON
788+
mock_api_client_instance.event_dict['6'] = (
789+
copy.deepcopy(mock_event_json),
790+
None,
791+
)
792+
session_service = mock_vertex_ai_session_service()
793+
session = await session_service.get_session(
794+
app_name='123', user_id='user_with_override', session_id='6'
795+
)
796+
assert session is not None
797+
assert len(session.events) == 1
798+
event = session.events[0]
799+
assert event.content.parts[0].text == 'raw_event_content'
800+
assert event.actions.transfer_to_agent == 'raw_event_agent'
801+
assert not event.partial
802+
assert event.turn_complete
803+
assert event.interrupted
804+
assert event.branch == 'raw_event_branch'
805+
assert event.error_code == '222'
806+
assert event.error_message == 'raw_event_error'
807+
808+
683809
@pytest.mark.asyncio
684810
@pytest.mark.usefixtures('mock_get_api_client')
685811
async def test_get_session_with_many_events(mock_api_client_instance):
@@ -830,6 +956,36 @@ async def test_append_event():
830956
branch='test_branch',
831957
custom_metadata={'custom': 'data'},
832958
long_running_tool_ids={'tool2'},
959+
input_transcription=genai_types.Transcription(
960+
text='test_input_transcription'
961+
),
962+
output_transcription=genai_types.Transcription(
963+
text='test_output_transcription'
964+
),
965+
model_version='test_model_version',
966+
avg_logprobs=0.5,
967+
logprobs_result=genai_types.LogprobsResult(
968+
chosen_candidates=[
969+
genai_types.LogprobsResultCandidate(
970+
log_probability=0.5,
971+
token='test_token',
972+
token_id=0,
973+
)
974+
]
975+
),
976+
cache_metadata=CacheMetadata(
977+
cache_name='test_cache_name',
978+
fingerprint='test_fingerprint',
979+
contents_count=1,
980+
),
981+
citation_metadata=genai_types.CitationMetadata(
982+
citations=[
983+
genai_types.Citation(
984+
uri='http://test.com',
985+
title='test_title',
986+
)
987+
]
988+
),
833989
)
834990

835991
await session_service.append_event(session_before_append, event_to_append)

0 commit comments

Comments
 (0)