Skip to content

Commit b004da5

Browse files
GWealecopybara-github
authored andcommitted
fix: Allow artifact services to accept dictionary representations of types.Part
This change introduces an `ensure_part` helper function that normalizes input to `types.Part`. This allows `save_artifact` methods in `FileArtifactService`, `GcsArtifactService`, and `InMemoryArtifactService` to accept dictionaries, including those with camelCase keys as used by Agentspace, and convert them into proper `types.Part` instances before saving Close #2886 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 878131948
1 parent 2e434ca commit b004da5

File tree

5 files changed

+175
-10
lines changed

5 files changed

+175
-10
lines changed

src/google/adk/artifacts/base_artifact_service.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,19 @@
1616
from abc import ABC
1717
from abc import abstractmethod
1818
from datetime import datetime
19+
import logging
1920
from typing import Any
2021
from typing import Optional
22+
from typing import Union
2123

2224
from google.genai import types
2325
from pydantic import alias_generators
2426
from pydantic import BaseModel
2527
from pydantic import ConfigDict
2628
from pydantic import Field
2729

30+
logger = logging.getLogger("google_adk." + __name__)
31+
2832

2933
class ArtifactVersion(BaseModel):
3034
"""Metadata describing a specific version of an artifact."""
@@ -60,6 +64,26 @@ class ArtifactVersion(BaseModel):
6064
)
6165

6266

67+
def ensure_part(artifact: Union[types.Part, dict[str, Any]]) -> types.Part:
68+
"""Normalizes an artifact to a ``types.Part`` instance.
69+
70+
External callers may provide artifacts as
71+
plain dictionaries with camelCase keys (``inlineData``) instead of properly
72+
deserialized ``types.Part`` objects. ``model_validate`` handles both
73+
camelCase and snake_case dictionaries transparently via Pydantic aliases.
74+
75+
Args:
76+
artifact: A ``types.Part`` instance or a dictionary representation.
77+
78+
Returns:
79+
A validated ``types.Part`` instance.
80+
"""
81+
if isinstance(artifact, dict):
82+
logger.debug("Normalizing artifact dict to types.Part: %s", list(artifact))
83+
return types.Part.model_validate(artifact)
84+
return artifact
85+
86+
6387
class BaseArtifactService(ABC):
6488
"""Abstract base class for artifact services."""
6589

@@ -70,7 +94,7 @@ async def save_artifact(
7094
app_name: str,
7195
user_id: str,
7296
filename: str,
73-
artifact: types.Part,
97+
artifact: Union[types.Part, dict[str, Any]],
7498
session_id: Optional[str] = None,
7599
custom_metadata: Optional[dict[str, Any]] = None,
76100
) -> int:
@@ -84,10 +108,12 @@ async def save_artifact(
84108
app_name: The app name.
85109
user_id: The user ID.
86110
filename: The filename of the artifact.
87-
artifact: The artifact to save. If the artifact consists of `file_data`,
88-
the artifact service assumes its content has been uploaded separately,
89-
and this method will associate the `file_data` with the artifact if
90-
necessary.
111+
artifact: The artifact to save. Accepts a ``types.Part`` instance or a
112+
plain dictionary (camelCase or snake_case keys) which will be
113+
normalized via ``ensure_part``. If the artifact consists of
114+
``file_data``, the artifact service assumes its content has been
115+
uploaded separately, and this method will associate the ``file_data``
116+
with the artifact if necessary.
91117
session_id: The session ID. If `None`, the artifact is user-scoped.
92118
custom_metadata: custom metadata to associate with the artifact.
93119

src/google/adk/artifacts/file_artifact_service.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import shutil
2323
from typing import Any
2424
from typing import Optional
25+
from typing import Union
2526
from urllib.parse import unquote
2627
from urllib.parse import urlparse
2728

@@ -35,6 +36,7 @@
3536
from ..errors.input_validation_error import InputValidationError
3637
from .base_artifact_service import ArtifactVersion
3738
from .base_artifact_service import BaseArtifactService
39+
from .base_artifact_service import ensure_part
3840

3941
logger = logging.getLogger("google_adk." + __name__)
4042

@@ -314,7 +316,7 @@ async def save_artifact(
314316
app_name: str,
315317
user_id: str,
316318
filename: str,
317-
artifact: types.Part,
319+
artifact: Union[types.Part, dict[str, Any]],
318320
session_id: Optional[str] = None,
319321
custom_metadata: Optional[dict[str, Any]] = None,
320322
) -> int:
@@ -339,11 +341,12 @@ def _save_artifact_sync(
339341
self,
340342
user_id: str,
341343
filename: str,
342-
artifact: types.Part,
344+
artifact: Union[types.Part, dict[str, Any]],
343345
session_id: Optional[str],
344346
custom_metadata: Optional[dict[str, Any]],
345347
) -> int:
346348
"""Saves an artifact to disk and returns its version."""
349+
artifact = ensure_part(artifact)
347350
artifact_dir = self._artifact_dir(
348351
user_id=user_id,
349352
session_id=session_id,

src/google/adk/artifacts/gcs_artifact_service.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,15 @@
2727
import logging
2828
from typing import Any
2929
from typing import Optional
30+
from typing import Union
3031

3132
from google.genai import types
3233
from typing_extensions import override
3334

3435
from ..errors.input_validation_error import InputValidationError
3536
from .base_artifact_service import ArtifactVersion
3637
from .base_artifact_service import BaseArtifactService
38+
from .base_artifact_service import ensure_part
3739

3840
logger = logging.getLogger("google_adk." + __name__)
3941

@@ -61,7 +63,7 @@ async def save_artifact(
6163
app_name: str,
6264
user_id: str,
6365
filename: str,
64-
artifact: types.Part,
66+
artifact: Union[types.Part, dict[str, Any]],
6567
session_id: Optional[str] = None,
6668
custom_metadata: Optional[dict[str, Any]] = None,
6769
) -> int:
@@ -198,9 +200,10 @@ def _save_artifact(
198200
user_id: str,
199201
session_id: Optional[str],
200202
filename: str,
201-
artifact: types.Part,
203+
artifact: Union[types.Part, dict[str, Any]],
202204
custom_metadata: Optional[dict[str, Any]] = None,
203205
) -> int:
206+
artifact = ensure_part(artifact)
204207
versions = self._list_versions(
205208
app_name=app_name,
206209
user_id=user_id,

src/google/adk/artifacts/in_memory_artifact_service.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import logging
1818
from typing import Any
1919
from typing import Optional
20+
from typing import Union
2021

2122
from google.genai import types
2223
from pydantic import BaseModel
@@ -27,6 +28,7 @@
2728
from ..errors.input_validation_error import InputValidationError
2829
from .base_artifact_service import ArtifactVersion
2930
from .base_artifact_service import BaseArtifactService
31+
from .base_artifact_service import ensure_part
3032

3133
logger = logging.getLogger("google_adk." + __name__)
3234

@@ -99,10 +101,11 @@ async def save_artifact(
99101
app_name: str,
100102
user_id: str,
101103
filename: str,
102-
artifact: types.Part,
104+
artifact: Union[types.Part, dict[str, Any]],
103105
session_id: Optional[str] = None,
104106
custom_metadata: Optional[dict[str, Any]] = None,
105107
) -> int:
108+
artifact = ensure_part(artifact)
106109
path = self._artifact_path(app_name, user_id, filename, session_id)
107110
if path not in self.artifacts:
108111
self.artifacts[path] = []

tests/unittests/artifacts/test_artifact_service.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from urllib.parse import urlparse
3030

3131
from google.adk.artifacts.base_artifact_service import ArtifactVersion
32+
from google.adk.artifacts.base_artifact_service import ensure_part
3233
from google.adk.artifacts.file_artifact_service import FileArtifactService
3334
from google.adk.artifacts.gcs_artifact_service import GcsArtifactService
3435
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
@@ -766,3 +767,132 @@ async def test_file_save_artifact_rejects_absolute_path_within_scope(tmp_path):
766767
filename=str(absolute_in_scope),
767768
artifact=part,
768769
)
770+
771+
772+
class TestEnsurePart:
773+
"""Tests for the ensure_part normalization helper."""
774+
775+
def test_returns_part_unchanged(self):
776+
"""A types.Part instance passes through without modification."""
777+
part = types.Part.from_bytes(data=b"hello", mime_type="text/plain")
778+
result = ensure_part(part)
779+
assert result is part
780+
781+
def test_converts_camel_case_dict(self):
782+
"""A camelCase dict (Agentspace format) is converted to types.Part."""
783+
raw = {"inlineData": {"mimeType": "image/png", "data": "dGVzdA=="}}
784+
result = ensure_part(raw)
785+
assert isinstance(result, types.Part)
786+
assert result.inline_data is not None
787+
assert result.inline_data.mime_type == "image/png"
788+
789+
def test_converts_snake_case_dict(self):
790+
"""A snake_case dict is converted to types.Part."""
791+
raw = {"inline_data": {"mime_type": "text/plain", "data": "aGVsbG8="}}
792+
result = ensure_part(raw)
793+
assert isinstance(result, types.Part)
794+
assert result.inline_data is not None
795+
assert result.inline_data.mime_type == "text/plain"
796+
797+
def test_converts_text_dict(self):
798+
"""A dict with 'text' key is converted to types.Part."""
799+
raw = {"text": "hello world"}
800+
result = ensure_part(raw)
801+
assert isinstance(result, types.Part)
802+
assert result.text == "hello world"
803+
804+
805+
@pytest.mark.asyncio
806+
@pytest.mark.parametrize(
807+
"service_type",
808+
[
809+
ArtifactServiceType.IN_MEMORY,
810+
ArtifactServiceType.GCS,
811+
ArtifactServiceType.FILE,
812+
],
813+
)
814+
async def test_save_artifact_with_camel_case_dict(
815+
service_type, artifact_service_factory
816+
):
817+
"""Artifact services accept camelCase dicts (Agentspace format).
818+
819+
Regression test for https://github.com/google/adk-python/issues/2886
820+
"""
821+
artifact_service = artifact_service_factory(service_type)
822+
app_name = "app0"
823+
user_id = "user0"
824+
session_id = "sess0"
825+
filename = "uploaded.png"
826+
827+
# Simulate what Agentspace sends: a plain dict with camelCase keys.
828+
raw_artifact = {
829+
"inlineData": {
830+
"mimeType": "image/png",
831+
"data": "dGVzdF9pbWFnZV9kYXRh",
832+
}
833+
}
834+
835+
version = await artifact_service.save_artifact(
836+
app_name=app_name,
837+
user_id=user_id,
838+
session_id=session_id,
839+
filename=filename,
840+
artifact=raw_artifact,
841+
)
842+
assert version == 0
843+
844+
loaded = await artifact_service.load_artifact(
845+
app_name=app_name,
846+
user_id=user_id,
847+
session_id=session_id,
848+
filename=filename,
849+
)
850+
assert loaded is not None
851+
assert loaded.inline_data is not None
852+
assert loaded.inline_data.mime_type == "image/png"
853+
854+
855+
@pytest.mark.asyncio
856+
@pytest.mark.parametrize(
857+
"service_type",
858+
[
859+
ArtifactServiceType.IN_MEMORY,
860+
ArtifactServiceType.GCS,
861+
ArtifactServiceType.FILE,
862+
],
863+
)
864+
async def test_save_artifact_with_snake_case_dict(
865+
service_type, artifact_service_factory
866+
):
867+
"""Artifact services accept snake_case dicts."""
868+
artifact_service = artifact_service_factory(service_type)
869+
app_name = "app0"
870+
user_id = "user0"
871+
session_id = "sess0"
872+
filename = "uploaded.txt"
873+
874+
raw_artifact = {
875+
"inline_data": {
876+
"mime_type": "text/plain",
877+
"data": "aGVsbG8=",
878+
}
879+
}
880+
881+
version = await artifact_service.save_artifact(
882+
app_name=app_name,
883+
user_id=user_id,
884+
session_id=session_id,
885+
filename=filename,
886+
artifact=raw_artifact,
887+
)
888+
assert version == 0
889+
890+
loaded = await artifact_service.load_artifact(
891+
app_name=app_name,
892+
user_id=user_id,
893+
session_id=session_id,
894+
filename=filename,
895+
)
896+
assert loaded is not None
897+
assert loaded.inline_data is not None
898+
assert loaded.inline_data.mime_type == "text/plain"

0 commit comments

Comments
 (0)