|
14 | 14 |
|
15 | 15 | """Tests for runner.rewind_async.""" |
16 | 16 |
|
| 17 | +from typing import Any |
| 18 | +from typing import Optional |
| 19 | +from typing import Union |
| 20 | + |
17 | 21 | from google.adk.agents.base_agent import BaseAgent |
| 22 | +from google.adk.artifacts.base_artifact_service import ensure_part |
18 | 23 | from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService |
19 | 24 | from google.adk.events.event import Event |
20 | 25 | from google.adk.events.event import EventActions |
|
24 | 29 | import pytest |
25 | 30 |
|
26 | 31 |
|
| 32 | +class _NoFileDataArtifactService(InMemoryArtifactService): |
| 33 | + """Artifact service that rejects file_data parts, like GCS/File services.""" |
| 34 | + |
| 35 | + async def save_artifact( |
| 36 | + self, |
| 37 | + *, |
| 38 | + app_name: str, |
| 39 | + user_id: str, |
| 40 | + filename: str, |
| 41 | + artifact: Union[types.Part, dict[str, Any]], |
| 42 | + session_id: Optional[str] = None, |
| 43 | + custom_metadata: Optional[dict[str, Any]] = None, |
| 44 | + ) -> int: |
| 45 | + artifact = ensure_part(artifact) |
| 46 | + if artifact.file_data: |
| 47 | + raise NotImplementedError( |
| 48 | + "Saving artifact with file_data is not supported." |
| 49 | + ) |
| 50 | + return await super().save_artifact( |
| 51 | + app_name=app_name, |
| 52 | + user_id=user_id, |
| 53 | + filename=filename, |
| 54 | + artifact=artifact, |
| 55 | + session_id=session_id, |
| 56 | + custom_metadata=custom_metadata, |
| 57 | + ) |
| 58 | + |
| 59 | + |
27 | 60 | class TestRunnerRewind: |
28 | 61 | """Tests for runner.rewind_async.""" |
29 | 62 |
|
@@ -246,3 +279,85 @@ async def test_rewind_async_not_first_invocation(self): |
246 | 279 | session_id=session_id, |
247 | 280 | filename="f2", |
248 | 281 | ) == types.Part.from_text(text="f2v0") |
| 282 | + |
| 283 | + |
| 284 | +class TestRunnerRewindNoFileData: |
| 285 | + """Tests that rewind works with artifact services that reject file_data.""" |
| 286 | + |
| 287 | + @pytest.mark.asyncio |
| 288 | + async def test_rewind_uses_load_artifact_not_file_data(self): |
| 289 | + """Rewind must not construct file_data parts for artifact restoration. |
| 290 | +
|
| 291 | + GCS and File artifact services reject file_data parts. The runner |
| 292 | + should use load_artifact to get inline_data instead. |
| 293 | + """ |
| 294 | + root_agent = BaseAgent(name="test_agent") |
| 295 | + session_service = InMemorySessionService() |
| 296 | + artifact_service = _NoFileDataArtifactService() |
| 297 | + runner = Runner( |
| 298 | + app_name="test_app", |
| 299 | + agent=root_agent, |
| 300 | + session_service=session_service, |
| 301 | + artifact_service=artifact_service, |
| 302 | + ) |
| 303 | + user_id = "test_user" |
| 304 | + session_id = "test_session" |
| 305 | + |
| 306 | + session = await runner.session_service.create_session( |
| 307 | + app_name=runner.app_name, user_id=user_id, session_id=session_id |
| 308 | + ) |
| 309 | + |
| 310 | + # invocation1: create artifact f1 v0 |
| 311 | + await runner.artifact_service.save_artifact( |
| 312 | + app_name=runner.app_name, |
| 313 | + user_id=user_id, |
| 314 | + session_id=session_id, |
| 315 | + filename="f1", |
| 316 | + artifact=types.Part.from_text(text="f1v0"), |
| 317 | + ) |
| 318 | + event1 = Event( |
| 319 | + invocation_id="invocation1", |
| 320 | + author="agent", |
| 321 | + content=types.Content(parts=[types.Part.from_text(text="e1")]), |
| 322 | + actions=EventActions( |
| 323 | + state_delta={"k1": "v1"}, artifact_delta={"f1": 0} |
| 324 | + ), |
| 325 | + ) |
| 326 | + await runner.session_service.append_event(session=session, event=event1) |
| 327 | + |
| 328 | + # invocation2: update artifact f1 to v1 |
| 329 | + await runner.artifact_service.save_artifact( |
| 330 | + app_name=runner.app_name, |
| 331 | + user_id=user_id, |
| 332 | + session_id=session_id, |
| 333 | + filename="f1", |
| 334 | + artifact=types.Part.from_text(text="f1v1"), |
| 335 | + ) |
| 336 | + event2 = Event( |
| 337 | + invocation_id="invocation2", |
| 338 | + author="agent", |
| 339 | + content=types.Content(parts=[types.Part.from_text(text="e2")]), |
| 340 | + actions=EventActions(artifact_delta={"f1": 1}), |
| 341 | + ) |
| 342 | + await runner.session_service.append_event(session=session, event=event2) |
| 343 | + |
| 344 | + session = await runner.session_service.get_session( |
| 345 | + app_name=runner.app_name, user_id=user_id, session_id=session_id |
| 346 | + ) |
| 347 | + |
| 348 | + # Rewind before invocation2 — this would raise NotImplementedError |
| 349 | + # with the old code that constructed file_data parts. |
| 350 | + await runner.rewind_async( |
| 351 | + user_id=user_id, |
| 352 | + session_id=session_id, |
| 353 | + rewind_before_invocation_id="invocation2", |
| 354 | + ) |
| 355 | + |
| 356 | + # f1 should be restored to v0 content |
| 357 | + restored = await runner.artifact_service.load_artifact( |
| 358 | + app_name=runner.app_name, |
| 359 | + user_id=user_id, |
| 360 | + session_id=session_id, |
| 361 | + filename="f1", |
| 362 | + ) |
| 363 | + assert restored == types.Part.from_text(text="f1v0") |
0 commit comments