Skip to content

Commit 6ee0362

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Use raw_event field in vertex ai session service for append and list events
PiperOrigin-RevId: 892504394
1 parent 8929907 commit 6ee0362

File tree

2 files changed

+178
-67
lines changed

2 files changed

+178
-67
lines changed

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 23 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import annotations
1515

1616
import asyncio
17+
import copy
1718
import datetime
1819
import json
1920
import logging
@@ -266,72 +267,11 @@ async def append_event(self, session: Session, event: Event) -> Event:
266267

267268
reasoning_engine_id = self._get_reasoning_engine_id(session.app_name)
268269

269-
config = {}
270-
if event.content:
271-
config['content'] = event.content.model_dump(
272-
exclude_none=True, mode='json'
273-
)
274-
if event.actions:
275-
config['actions'] = {
276-
'skip_summarization': event.actions.skip_summarization,
277-
'state_delta': event.actions.state_delta,
278-
'artifact_delta': event.actions.artifact_delta,
279-
'transfer_agent': event.actions.transfer_to_agent,
280-
'escalate': event.actions.escalate,
281-
'requested_auth_configs': {
282-
k: json.loads(v.model_dump_json(exclude_none=True, by_alias=True))
283-
for k, v in event.actions.requested_auth_configs.items()
284-
},
285-
# TODO: add requested_tool_confirmations, agent_state once
286-
# they are available in the API.
287-
# Note: compaction is stored via event_metadata.custom_metadata.
288-
}
289-
if event.error_code:
290-
config['error_code'] = event.error_code
291-
if event.error_message:
292-
config['error_message'] = event.error_message
293-
294-
metadata_dict = {
295-
'partial': event.partial,
296-
'turn_complete': event.turn_complete,
297-
'interrupted': event.interrupted,
298-
'branch': event.branch,
299-
'custom_metadata': event.custom_metadata,
300-
'long_running_tool_ids': (
301-
list(event.long_running_tool_ids)
302-
if event.long_running_tool_ids
303-
else None
304-
),
305-
}
306-
if event.grounding_metadata:
307-
metadata_dict['grounding_metadata'] = event.grounding_metadata.model_dump(
308-
exclude_none=True, mode='json'
309-
)
310-
# Store compaction data in custom_metadata since the Vertex AI service
311-
# does not yet support the compaction field.
312-
# TODO: Stop writing to custom_metadata once the Vertex AI service
313-
# supports the compaction field natively in EventActions.
314-
if event.actions and event.actions.compaction:
315-
compaction_dict = event.actions.compaction.model_dump(
316-
exclude_none=True, mode='json'
317-
)
318-
_set_internal_custom_metadata(
319-
metadata_dict,
320-
key=_COMPACTION_CUSTOM_METADATA_KEY,
321-
value=compaction_dict,
322-
)
323-
# Store usage_metadata in custom_metadata since the Vertex AI service
324-
# does not persist it in EventMetadata.
325-
if event.usage_metadata:
326-
usage_dict = event.usage_metadata.model_dump(
327-
exclude_none=True, mode='json'
328-
)
329-
_set_internal_custom_metadata(
330-
metadata_dict,
331-
key=_USAGE_METADATA_CUSTOM_METADATA_KEY,
332-
value=usage_dict,
333-
)
334-
config['event_metadata'] = metadata_dict
270+
raw_event_dict = event.model_dump(
271+
exclude_none=True,
272+
mode='json',
273+
by_alias=True,
274+
)
335275

336276
async with self._get_api_client() as api_client:
337277
await api_client.agent_engines.sessions.events.append(
@@ -341,7 +281,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
341281
timestamp=datetime.datetime.fromtimestamp(
342282
event.timestamp, tz=datetime.timezone.utc
343283
),
344-
config=config,
284+
config={'raw_event': raw_event_dict},
345285
)
346286
return event
347287

@@ -390,6 +330,22 @@ def _get_api_client(self) -> vertexai.AsyncClient:
390330

391331
def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
392332
"""Converts an API event object to an Event object."""
333+
# Read event data from raw_event first before falling back to top level
334+
# fields.
335+
raw_event_dict = getattr(
336+
api_event_obj, 'raw_event', getattr(api_event_obj, 'rawEvent', None)
337+
)
338+
if raw_event_dict:
339+
event_dict = copy.deepcopy(raw_event_dict)
340+
timestamp_obj = getattr(api_event_obj, 'timestamp', None)
341+
event_dict.update({
342+
'id': api_event_obj.name.split('/')[-1],
343+
'invocation_id': getattr(api_event_obj, 'invocation_id', None),
344+
'author': getattr(api_event_obj, 'author', None),
345+
'timestamp': timestamp_obj.timestamp() if timestamp_obj else None,
346+
})
347+
return Event.model_validate(event_dict)
348+
393349
actions = getattr(api_event_obj, 'actions', None)
394350
event_metadata = getattr(api_event_obj, 'event_metadata', None)
395351
if event_metadata:

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from google.adk.events.event import Event
2929
from google.adk.events.event_actions import EventActions
3030
from google.adk.events.event_actions import EventCompaction
31+
from google.adk.models.cache_metadata import CacheMetadata
3132
from google.adk.sessions.base_session_service import GetSessionConfig
3233
from google.adk.sessions.session import Session
3334
from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService
@@ -162,6 +163,96 @@ def _generate_mock_events_for_session_5(num_events):
162163
MANY_EVENTS_COUNT = 200
163164
MOCK_EVENTS_JSON_5 = _generate_mock_events_for_session_5(MANY_EVENTS_COUNT)
164165

166+
MOCK_EVENT_WITH_OVERRIDE_JSON = [{
167+
'name': (
168+
'projects/test-project/locations/test-location/'
169+
'reasoningEngines/123/sessions/override/events/1'
170+
),
171+
'invocationId': 'override_invoke',
172+
'author': 'user_with_override',
173+
'timestamp': '2024-12-12T12:12:12.123456Z',
174+
'content': {
175+
'parts': [
176+
{'text': 'top_level_content'},
177+
],
178+
},
179+
'actions': {
180+
'transferToAgent': 'top_level_agent',
181+
},
182+
'eventMetadata': {
183+
'partial': True,
184+
'turnComplete': False,
185+
'interrupted': False,
186+
'branch': 'top_level_branch',
187+
},
188+
'errorCode': '111',
189+
'errorMessage': 'top_level_error',
190+
'rawEvent': {
191+
'invocationId': 'wrong_invocation_id',
192+
'author': 'wrong_author',
193+
'content': {
194+
'parts': [
195+
{'text': 'raw_event_content'},
196+
],
197+
},
198+
'actions': {
199+
'transferToAgent': 'raw_event_agent',
200+
},
201+
'partial': False,
202+
'turnComplete': True,
203+
'interrupted': True,
204+
'branch': 'raw_event_branch',
205+
'errorCode': '222',
206+
'errorMessage': 'raw_event_error',
207+
},
208+
}]
209+
210+
MOCK_EVENT_WITH_OVERRIDE_JSON_2 = [{
211+
'name': (
212+
'projects/test-project/locations/test-location/'
213+
'reasoningEngines/123/sessions/override/events/1'
214+
),
215+
'invocationId': 'override_invoke',
216+
'author': 'user_with_override',
217+
'content': {},
218+
'actions': {},
219+
'timestamp': '2024-12-12T12:12:12.123456Z',
220+
'rawEvent': {
221+
'invocationId': 'wrong_invocation_id',
222+
'author': 'wrong_author',
223+
'content': {
224+
'parts': [
225+
{'text': 'raw_event_content'},
226+
],
227+
},
228+
'actions': {
229+
'skipSummarization': None,
230+
'stateDelta': {},
231+
'artifactDelta': {},
232+
'transferToAgent': 'raw_event_agent',
233+
'escalate': None,
234+
'requestedAuthConfigs': {},
235+
},
236+
'errorCode': '222',
237+
'errorMessage': 'raw_event_error',
238+
'partial': False,
239+
'turnComplete': True,
240+
'interrupted': True,
241+
'branch': 'raw_event_branch',
242+
'customMetadata': None,
243+
'longRunningToolIds': None,
244+
},
245+
}]
246+
247+
MOCK_SESSION_WITH_OVERRIDE_JSON = {
248+
'name': (
249+
'projects/test-project/locations/test-location/'
250+
'reasoningEngines/123/sessions/override'
251+
),
252+
'update_time': '2024-12-12T12:12:12.123456Z',
253+
'user_id': 'user_with_override',
254+
}
255+
165256
MOCK_SESSION = Session(
166257
app_name='123',
167258
user_id='user',
@@ -249,6 +340,8 @@ def _convert_to_object(data):
249340
'artifact_delta',
250341
'custom_metadata',
251342
'requested_auth_configs',
343+
'rawEvent',
344+
'raw_event',
252345
]:
253346
kwargs[key] = value
254347
else:
@@ -680,6 +773,38 @@ async def test_get_session_keeps_events_newer_than_update_time(
680773
)
681774

682775

776+
@pytest.mark.asyncio
777+
@pytest.mark.usefixtures('mock_get_api_client')
778+
@pytest.mark.parametrize(
779+
'mock_event_json',
780+
[MOCK_EVENT_WITH_OVERRIDE_JSON, MOCK_EVENT_WITH_OVERRIDE_JSON_2],
781+
)
782+
async def test_get_session_from_raw_event(
783+
mock_api_client_instance: MockAsyncClient,
784+
mock_event_json,
785+
) -> None:
786+
mock_api_client_instance.session_dict['6'] = MOCK_SESSION_WITH_OVERRIDE_JSON
787+
mock_api_client_instance.event_dict['6'] = (
788+
copy.deepcopy(mock_event_json),
789+
None,
790+
)
791+
session_service = mock_vertex_ai_session_service()
792+
session = await session_service.get_session(
793+
app_name='123', user_id='user_with_override', session_id='6'
794+
)
795+
assert session is not None
796+
assert len(session.events) == 1
797+
event = session.events[0]
798+
assert event.content.parts[0].text == 'raw_event_content'
799+
assert event.actions.transfer_to_agent == 'raw_event_agent'
800+
assert not event.partial
801+
assert event.turn_complete
802+
assert event.interrupted
803+
assert event.branch == 'raw_event_branch'
804+
assert event.error_code == '222'
805+
assert event.error_message == 'raw_event_error'
806+
807+
683808
@pytest.mark.asyncio
684809
@pytest.mark.usefixtures('mock_get_api_client')
685810
async def test_get_session_with_many_events(mock_api_client_instance):
@@ -816,6 +941,36 @@ async def test_append_event():
816941
branch='test_branch',
817942
custom_metadata={'custom': 'data'},
818943
long_running_tool_ids={'tool2'},
944+
input_transcription=genai_types.Transcription(
945+
text='test_input_transcription'
946+
),
947+
output_transcription=genai_types.Transcription(
948+
text='test_output_transcription'
949+
),
950+
model_version='test_model_version',
951+
avg_logprobs=0.5,
952+
logprobs_result=genai_types.LogprobsResult(
953+
chosen_candidates=[
954+
genai_types.LogprobsResultCandidate(
955+
log_probability=0.5,
956+
token='test_token',
957+
token_id=0,
958+
)
959+
]
960+
),
961+
cache_metadata=CacheMetadata(
962+
cache_name='test_cache_name',
963+
fingerprint='test_fingerprint',
964+
contents_count=1,
965+
),
966+
citation_metadata=genai_types.CitationMetadata(
967+
citations=[
968+
genai_types.Citation(
969+
uri='http://test.com',
970+
title='test_title',
971+
)
972+
]
973+
),
819974
)
820975

821976
await session_service.append_event(session_before_append, event_to_append)

0 commit comments

Comments
 (0)