Skip to content

Commit c3d50db

Browse files
GWealecopybara-github
authored andcommitted
refactor: Use artifact_service.load_artifact during rewind
During session rewind, when restoring an artifact to a previous version, the runner now uses `artifact_service.load_artifact` to fetch the artifact's content. Previously, it would construct a `types.Part` with `file_data` using the artifact URI. This change is necessary because some artifact services, such as those backed by GCS or local files, do not support saving `types.Part` objects that contain `file_data` Close #4932 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 905223401
1 parent 5ce33b9 commit c3d50db

2 files changed

Lines changed: 130 additions & 4 deletions

File tree

src/google/adk/runners.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import warnings
3030

3131
from google.adk.apps.compaction import _run_compaction_for_sliding_window
32-
from google.adk.artifacts import artifact_util
3332
from google.genai import types
3433

3534
from .agents.base_agent import BaseAgent
@@ -754,15 +753,27 @@ async def _compute_artifact_delta_for_rewind(
754753
)
755754
else:
756755
# Artifact version changed after rewind point. Restore to version at
757-
# rewind point.
758-
artifact_uri = artifact_util.get_artifact_uri(
756+
# rewind point by loading the actual data via the artifact service.
757+
artifact = await self.artifact_service.load_artifact(
759758
app_name=self.app_name,
760759
user_id=session.user_id,
761760
session_id=session.id,
762761
filename=filename,
763762
version=vt,
764763
)
765-
artifact = types.Part(file_data=types.FileData(file_uri=artifact_uri))
764+
if artifact is None:
765+
logger.warning(
766+
'Artifact %s version %d not found during rewind for'
767+
' session %s. Replacing with empty data.',
768+
filename,
769+
vt,
770+
session.id,
771+
)
772+
artifact = types.Part(
773+
inline_data=types.Blob(
774+
mime_type='application/octet-stream', data=b''
775+
)
776+
)
766777
await self.artifact_service.save_artifact(
767778
app_name=self.app_name,
768779
user_id=session.user_id,

tests/unittests/runners/test_runner_rewind.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414

1515
"""Tests for runner.rewind_async."""
1616

17+
from typing import Any
18+
from typing import Optional
19+
from typing import Union
20+
1721
from google.adk.agents.base_agent import BaseAgent
22+
from google.adk.artifacts.base_artifact_service import ensure_part
1823
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
1924
from google.adk.events.event import Event
2025
from google.adk.events.event import EventActions
@@ -24,6 +29,34 @@
2429
import pytest
2530

2631

32+
class _NoFileDataArtifactService(InMemoryArtifactService):
33+
"""Artifact service that rejects file_data parts, like GCS/File services."""
34+
35+
async def save_artifact(
36+
self,
37+
*,
38+
app_name: str,
39+
user_id: str,
40+
filename: str,
41+
artifact: Union[types.Part, dict[str, Any]],
42+
session_id: Optional[str] = None,
43+
custom_metadata: Optional[dict[str, Any]] = None,
44+
) -> int:
45+
artifact = ensure_part(artifact)
46+
if artifact.file_data:
47+
raise NotImplementedError(
48+
"Saving artifact with file_data is not supported."
49+
)
50+
return await super().save_artifact(
51+
app_name=app_name,
52+
user_id=user_id,
53+
filename=filename,
54+
artifact=artifact,
55+
session_id=session_id,
56+
custom_metadata=custom_metadata,
57+
)
58+
59+
2760
class TestRunnerRewind:
2861
"""Tests for runner.rewind_async."""
2962

@@ -246,3 +279,85 @@ async def test_rewind_async_not_first_invocation(self):
246279
session_id=session_id,
247280
filename="f2",
248281
) == types.Part.from_text(text="f2v0")
282+
283+
284+
class TestRunnerRewindNoFileData:
285+
"""Tests that rewind works with artifact services that reject file_data."""
286+
287+
@pytest.mark.asyncio
288+
async def test_rewind_uses_load_artifact_not_file_data(self):
289+
"""Rewind must not construct file_data parts for artifact restoration.
290+
291+
GCS and File artifact services reject file_data parts. The runner
292+
should use load_artifact to get inline_data instead.
293+
"""
294+
root_agent = BaseAgent(name="test_agent")
295+
session_service = InMemorySessionService()
296+
artifact_service = _NoFileDataArtifactService()
297+
runner = Runner(
298+
app_name="test_app",
299+
agent=root_agent,
300+
session_service=session_service,
301+
artifact_service=artifact_service,
302+
)
303+
user_id = "test_user"
304+
session_id = "test_session"
305+
306+
session = await runner.session_service.create_session(
307+
app_name=runner.app_name, user_id=user_id, session_id=session_id
308+
)
309+
310+
# invocation1: create artifact f1 v0
311+
await runner.artifact_service.save_artifact(
312+
app_name=runner.app_name,
313+
user_id=user_id,
314+
session_id=session_id,
315+
filename="f1",
316+
artifact=types.Part.from_text(text="f1v0"),
317+
)
318+
event1 = Event(
319+
invocation_id="invocation1",
320+
author="agent",
321+
content=types.Content(parts=[types.Part.from_text(text="e1")]),
322+
actions=EventActions(
323+
state_delta={"k1": "v1"}, artifact_delta={"f1": 0}
324+
),
325+
)
326+
await runner.session_service.append_event(session=session, event=event1)
327+
328+
# invocation2: update artifact f1 to v1
329+
await runner.artifact_service.save_artifact(
330+
app_name=runner.app_name,
331+
user_id=user_id,
332+
session_id=session_id,
333+
filename="f1",
334+
artifact=types.Part.from_text(text="f1v1"),
335+
)
336+
event2 = Event(
337+
invocation_id="invocation2",
338+
author="agent",
339+
content=types.Content(parts=[types.Part.from_text(text="e2")]),
340+
actions=EventActions(artifact_delta={"f1": 1}),
341+
)
342+
await runner.session_service.append_event(session=session, event=event2)
343+
344+
session = await runner.session_service.get_session(
345+
app_name=runner.app_name, user_id=user_id, session_id=session_id
346+
)
347+
348+
# Rewind before invocation2 — this would raise NotImplementedError
349+
# with the old code that constructed file_data parts.
350+
await runner.rewind_async(
351+
user_id=user_id,
352+
session_id=session_id,
353+
rewind_before_invocation_id="invocation2",
354+
)
355+
356+
# f1 should be restored to v0 content
357+
restored = await runner.artifact_service.load_artifact(
358+
app_name=runner.app_name,
359+
user_id=user_id,
360+
session_id=session_id,
361+
filename="f1",
362+
)
363+
assert restored == types.Part.from_text(text="f1v0")

0 commit comments

Comments
 (0)