|
28 | 28 | from google.adk.events.event import Event |
29 | 29 | from google.adk.events.event_actions import EventActions |
30 | 30 | from google.adk.events.event_actions import EventCompaction |
| 31 | +from google.adk.models.cache_metadata import CacheMetadata |
31 | 32 | from google.adk.sessions.base_session_service import GetSessionConfig |
32 | 33 | from google.adk.sessions.session import Session |
33 | 34 | from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService |
|
91 | 92 | 'branch': '', |
92 | 93 | 'long_running_tool_ids': ['tool1'], |
93 | 94 | }, |
| 95 | + 'raw_event': {}, |
94 | 96 | }, |
95 | 97 | ] |
96 | 98 | MOCK_EVENT_JSON_2 = [ |
@@ -162,6 +164,96 @@ def _generate_mock_events_for_session_5(num_events): |
162 | 164 | MANY_EVENTS_COUNT = 200 |
163 | 165 | MOCK_EVENTS_JSON_5 = _generate_mock_events_for_session_5(MANY_EVENTS_COUNT) |
164 | 166 |
|
| 167 | +MOCK_EVENT_WITH_OVERRIDE_JSON = [{ |
| 168 | + 'name': ( |
| 169 | + 'projects/test-project/locations/test-location/' |
| 170 | + 'reasoningEngines/123/sessions/override/events/1' |
| 171 | + ), |
| 172 | + 'invocationId': 'override_invoke', |
| 173 | + 'author': 'user_with_override', |
| 174 | + 'timestamp': '2024-12-12T12:12:12.123456Z', |
| 175 | + 'content': { |
| 176 | + 'parts': [ |
| 177 | + {'text': 'top_level_content'}, |
| 178 | + ], |
| 179 | + }, |
| 180 | + 'actions': { |
| 181 | + 'transferToAgent': 'top_level_agent', |
| 182 | + }, |
| 183 | + 'eventMetadata': { |
| 184 | + 'partial': True, |
| 185 | + 'turnComplete': False, |
| 186 | + 'interrupted': False, |
| 187 | + 'branch': 'top_level_branch', |
| 188 | + }, |
| 189 | + 'errorCode': '111', |
| 190 | + 'errorMessage': 'top_level_error', |
| 191 | + 'rawEvent': { |
| 192 | + 'invocationId': 'wrong_invocation_id', |
| 193 | + 'author': 'wrong_author', |
| 194 | + 'content': { |
| 195 | + 'parts': [ |
| 196 | + {'text': 'raw_event_content'}, |
| 197 | + ], |
| 198 | + }, |
| 199 | + 'actions': { |
| 200 | + 'transferToAgent': 'raw_event_agent', |
| 201 | + }, |
| 202 | + 'partial': False, |
| 203 | + 'turnComplete': True, |
| 204 | + 'interrupted': True, |
| 205 | + 'branch': 'raw_event_branch', |
| 206 | + 'errorCode': '222', |
| 207 | + 'errorMessage': 'raw_event_error', |
| 208 | + }, |
| 209 | +}] |
| 210 | + |
| 211 | +MOCK_EVENT_WITH_OVERRIDE_JSON_2 = [{ |
| 212 | + 'name': ( |
| 213 | + 'projects/test-project/locations/test-location/' |
| 214 | + 'reasoningEngines/123/sessions/override/events/1' |
| 215 | + ), |
| 216 | + 'invocationId': 'override_invoke', |
| 217 | + 'author': 'user_with_override', |
| 218 | + 'content': {}, |
| 219 | + 'actions': {}, |
| 220 | + 'timestamp': '2024-12-12T12:12:12.123456Z', |
| 221 | + 'rawEvent': { |
| 222 | + 'invocationId': 'wrong_invocation_id', |
| 223 | + 'author': 'wrong_author', |
| 224 | + 'content': { |
| 225 | + 'parts': [ |
| 226 | + {'text': 'raw_event_content'}, |
| 227 | + ], |
| 228 | + }, |
| 229 | + 'actions': { |
| 230 | + 'skipSummarization': None, |
| 231 | + 'stateDelta': {}, |
| 232 | + 'artifactDelta': {}, |
| 233 | + 'transferToAgent': 'raw_event_agent', |
| 234 | + 'escalate': None, |
| 235 | + 'requestedAuthConfigs': {}, |
| 236 | + }, |
| 237 | + 'errorCode': '222', |
| 238 | + 'errorMessage': 'raw_event_error', |
| 239 | + 'partial': False, |
| 240 | + 'turnComplete': True, |
| 241 | + 'interrupted': True, |
| 242 | + 'branch': 'raw_event_branch', |
| 243 | + 'customMetadata': None, |
| 244 | + 'longRunningToolIds': None, |
| 245 | + }, |
| 246 | +}] |
| 247 | + |
| 248 | +MOCK_SESSION_WITH_OVERRIDE_JSON = { |
| 249 | + 'name': ( |
| 250 | + 'projects/test-project/locations/test-location/' |
| 251 | + 'reasoningEngines/123/sessions/override' |
| 252 | + ), |
| 253 | + 'update_time': '2024-12-12T12:12:12.123456Z', |
| 254 | + 'user_id': 'user_with_override', |
| 255 | +} |
| 256 | + |
165 | 257 | MOCK_SESSION = Session( |
166 | 258 | app_name='123', |
167 | 259 | user_id='user', |
@@ -249,6 +341,8 @@ def _convert_to_object(data): |
249 | 341 | 'artifact_delta', |
250 | 342 | 'custom_metadata', |
251 | 343 | 'requested_auth_configs', |
| 344 | + 'rawEvent', |
| 345 | + 'raw_event', |
252 | 346 | ]: |
253 | 347 | kwargs[key] = value |
254 | 348 | else: |
@@ -680,6 +774,38 @@ async def test_get_session_keeps_events_newer_than_update_time( |
680 | 774 | ) |
681 | 775 |
|
682 | 776 |
|
| 777 | +@pytest.mark.asyncio |
| 778 | +@pytest.mark.usefixtures('mock_get_api_client') |
| 779 | +@pytest.mark.parametrize( |
| 780 | + 'mock_event_json', |
| 781 | + [MOCK_EVENT_WITH_OVERRIDE_JSON, MOCK_EVENT_WITH_OVERRIDE_JSON_2], |
| 782 | +) |
| 783 | +async def test_get_session_from_raw_event( |
| 784 | + mock_api_client_instance: MockAsyncClient, |
| 785 | + mock_event_json, |
| 786 | +) -> None: |
| 787 | + mock_api_client_instance.session_dict['6'] = MOCK_SESSION_WITH_OVERRIDE_JSON |
| 788 | + mock_api_client_instance.event_dict['6'] = ( |
| 789 | + copy.deepcopy(mock_event_json), |
| 790 | + None, |
| 791 | + ) |
| 792 | + session_service = mock_vertex_ai_session_service() |
| 793 | + session = await session_service.get_session( |
| 794 | + app_name='123', user_id='user_with_override', session_id='6' |
| 795 | + ) |
| 796 | + assert session is not None |
| 797 | + assert len(session.events) == 1 |
| 798 | + event = session.events[0] |
| 799 | + assert event.content.parts[0].text == 'raw_event_content' |
| 800 | + assert event.actions.transfer_to_agent == 'raw_event_agent' |
| 801 | + assert not event.partial |
| 802 | + assert event.turn_complete |
| 803 | + assert event.interrupted |
| 804 | + assert event.branch == 'raw_event_branch' |
| 805 | + assert event.error_code == '222' |
| 806 | + assert event.error_message == 'raw_event_error' |
| 807 | + |
| 808 | + |
683 | 809 | @pytest.mark.asyncio |
684 | 810 | @pytest.mark.usefixtures('mock_get_api_client') |
685 | 811 | async def test_get_session_with_many_events(mock_api_client_instance): |
@@ -830,6 +956,36 @@ async def test_append_event(): |
830 | 956 | branch='test_branch', |
831 | 957 | custom_metadata={'custom': 'data'}, |
832 | 958 | long_running_tool_ids={'tool2'}, |
| 959 | + input_transcription=genai_types.Transcription( |
| 960 | + text='test_input_transcription' |
| 961 | + ), |
| 962 | + output_transcription=genai_types.Transcription( |
| 963 | + text='test_output_transcription' |
| 964 | + ), |
| 965 | + model_version='test_model_version', |
| 966 | + avg_logprobs=0.5, |
| 967 | + logprobs_result=genai_types.LogprobsResult( |
| 968 | + chosen_candidates=[ |
| 969 | + genai_types.LogprobsResultCandidate( |
| 970 | + log_probability=0.5, |
| 971 | + token='test_token', |
| 972 | + token_id=0, |
| 973 | + ) |
| 974 | + ] |
| 975 | + ), |
| 976 | + cache_metadata=CacheMetadata( |
| 977 | + cache_name='test_cache_name', |
| 978 | + fingerprint='test_fingerprint', |
| 979 | + contents_count=1, |
| 980 | + ), |
| 981 | + citation_metadata=genai_types.CitationMetadata( |
| 982 | + citations=[ |
| 983 | + genai_types.Citation( |
| 984 | + uri='http://test.com', |
| 985 | + title='test_title', |
| 986 | + ) |
| 987 | + ] |
| 988 | + ), |
833 | 989 | ) |
834 | 990 |
|
835 | 991 | await session_service.append_event(session_before_append, event_to_append) |
|
0 commit comments