1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- """An artifact service implementation using Amazon S3.
16-
17- The object key format depends on whether the filename has a user namespace:
18- - For files with user namespace (starting with "user:"):
19- {app_name}/{user_id}/user/{filename}/{version}
20- - For regular session-scoped files:
21- {app_name}/{user_id}/{session_id}/{filename}/{version}
22-
23- Uses aioboto3 for native async I/O and atomic versioning via
24- S3's ``IfNoneMatch`` condition to prevent race conditions.
25-
26- Install S3 support with::
27-
28- pip install google-adk-community[s3]
29- """
15+ """Artifact service implementation using Amazon S3."""
3016from __future__ import annotations
3117
3218import asyncio
4531
4632
4733class S3ArtifactService (BaseArtifactService ):
48- """An artifact service implementation using Amazon S3.
49-
50- Uses ``aioboto3`` for native async I/O instead of wrapping synchronous
51- calls with ``asyncio.to_thread``. Artifact saves are atomic: a
52- ``IfNoneMatch="*"`` condition on ``put_object`` prevents race conditions
53- when two writers try to create the same version concurrently.
54-
55- Args:
56- bucket_name: The name of the S3 bucket to use.
57- aws_configs: Extra keyword arguments forwarded to
58- ``aioboto3.Session().client("s3", ...)``. Use this to pass
59- ``region_name``, ``endpoint_url`` (for MinIO / Spaces), etc.
60- save_max_retries: Maximum retries on version conflict.
61- ``-1`` means retry indefinitely.
62- """
34+ """An S3-backed implementation of the artifact service."""
6335
6436 def __init__ (
6537 self ,
6638 bucket_name : str ,
6739 aws_configs : Optional [dict [str , Any ]] = None ,
6840 save_max_retries : int = - 1 ,
6941 ):
42+ """Initializes the S3 artifact service.
43+
44+ Args:
45+ bucket_name: The name of the S3 bucket to use.
46+ aws_configs: Extra kwargs forwarded to the aioboto3 S3 client.
47+ Use this to pass region_name, endpoint_url (for MinIO), etc.
48+ save_max_retries: Maximum retries on version conflict. -1 means
49+ retry indefinitely.
50+ """
7051 try :
7152 import aioboto3 # noqa: F401
7253 except ImportError as exc :
@@ -80,9 +61,7 @@ def __init__(
8061 self .save_max_retries = save_max_retries
8162 self ._session = None
8263
83- # ------------------------------------------------------------------ #
84- # S3 client helpers
85- # ------------------------------------------------------------------ #
64+
8665
8766 async def _get_session (self ):
8867 import aioboto3
@@ -99,20 +78,16 @@ async def _client(self):
9978 ) as s3 :
10079 yield s3
10180
102- # ------------------------------------------------------------------ #
103- # Metadata serialisation
104- # ------------------------------------------------------------------ #
105-
10681 @staticmethod
10782 def _flatten_metadata (metadata : Optional [dict [str , Any ]]) -> dict [str , str ]:
108- """JSON-encode metadata values for S3 user-metadata (strings only) ."""
83+ """JSON-encode metadata values for S3 user-metadata."""
10984 if not metadata :
11085 return {}
11186 return {str (k ): json .dumps (v ) for k , v in metadata .items ()}
11287
11388 @staticmethod
11489 def _unflatten_metadata (metadata : Optional [dict [str , str ]]) -> dict [str , Any ]:
115- """Decode JSON metadata back to Python objects."""
90+ """Decode JSON metadata back to native Python objects."""
11691 results : dict [str , Any ] = {}
11792 for k , v in (metadata or {}).items ():
11893 try :
@@ -124,10 +99,6 @@ def _unflatten_metadata(metadata: Optional[dict[str, str]]) -> dict[str, Any]:
12499 results [k ] = v
125100 return results
126101
127- # ------------------------------------------------------------------ #
128- # Key helpers
129- # ------------------------------------------------------------------ #
130-
131102 @staticmethod
132103 def _file_has_user_namespace (filename : str ) -> bool :
133104 return filename .startswith ("user:" )
@@ -160,10 +131,6 @@ def _get_blob_name(
160131 f"/{ version } "
161132 )
162133
163- # ------------------------------------------------------------------ #
164- # Public API
165- # ------------------------------------------------------------------ #
166-
167134 @override
168135 async def save_artifact (
169136 self ,
@@ -175,12 +142,7 @@ async def save_artifact(
175142 session_id : Optional [str ] = None ,
176143 custom_metadata : Optional [dict [str , Any ]] = None ,
177144 ) -> int :
178- """Save an artifact with atomic versioning via ``IfNoneMatch``.
179-
180- If two concurrent callers race to create the same version, S3
181- will reject the second ``put_object`` with a ``PreconditionFailed``
182- error and this method will transparently retry.
183- """
145+ """Save an artifact with atomic versioning via IfNoneMatch."""
184146 from botocore .exceptions import ClientError
185147
186148 if self .save_max_retries < 0 :
@@ -200,7 +162,6 @@ async def save_artifact(
200162 app_name , user_id , session_id , filename , version
201163 )
202164
203- # Prepare data and content type
204165 if artifact .inline_data :
205166 body = artifact .inline_data .data
206167 content_type = (
@@ -257,7 +218,7 @@ async def load_artifact(
257218 session_id : Optional [str ] = None ,
258219 version : Optional [int ] = None ,
259220 ) -> Optional [types .Part ]:
260- """Load a specific version (or latest) of an artifact from S3 ."""
221+ """Load a specific version of an artifact, or the latest ."""
261222 from botocore .exceptions import ClientError
262223
263224 if version is None :
@@ -299,7 +260,7 @@ async def load_artifact(
299260 async def list_artifact_keys (
300261 self , * , app_name : str , user_id : str , session_id : Optional [str ] = None
301262 ) -> list [str ]:
302- """List all artifact keys for a user, optionally scoped to a session."""
263+ """List all artifact keys for a user, optionally filtered by session."""
303264 keys : set [str ] = set ()
304265 prefixes = [
305266 f"{ app_name } /{ user_id } /{ session_id } /" if session_id else None ,
@@ -313,11 +274,9 @@ async def list_artifact_keys(
313274 ):
314275 for obj in page .get ("Contents" , []):
315276 relative = obj ["Key" ][len (prefix ):]
316- # relative is "{filename}/{version}" — strip version part
317277 parts = relative .rsplit ("/" , 1 )
318278 if len (parts ) >= 2 :
319279 raw_filename = parts [0 ]
320- # Re-add "user:" prefix for user-scoped artifacts
321280 if prefix .endswith ("/user/" ):
322281 keys .add (f"user:{ raw_filename } " )
323282 else :
@@ -333,7 +292,7 @@ async def delete_artifact(
333292 filename : str ,
334293 session_id : Optional [str ] = None ,
335294 ) -> None :
336- """Delete all versions of an artifact using S3 batch delete ."""
295+ """Delete all versions of an artifact."""
337296 versions = await self .list_versions (
338297 app_name = app_name ,
339298 user_id = user_id ,
@@ -352,7 +311,6 @@ async def delete_artifact(
352311 for v in versions
353312 ]
354313 async with self ._client () as s3 :
355- # S3 batch delete supports up to 1000 keys per request
356314 for i in range (0 , len (keys_to_delete ), 1000 ):
357315 batch = keys_to_delete [i : i + 1000 ]
358316 await s3 .delete_objects (
@@ -368,7 +326,7 @@ async def list_versions(
368326 filename : str ,
369327 session_id : Optional [str ] = None ,
370328 ) -> list [int ]:
371- """List all available version numbers for an artifact."""
329+ """List all available versions of an artifact."""
372330 prefix = (
373331 self ._get_blob_prefix (app_name , user_id , session_id , filename ) + "/"
374332 )
@@ -395,7 +353,7 @@ async def list_artifact_versions(
395353 filename : str ,
396354 session_id : Optional [str ] = None ,
397355 ) -> list [ArtifactVersion ]:
398- """List all versions with metadata, using parallel head_object calls ."""
356+ """List all versions with metadata."""
399357 prefix = (
400358 self ._get_blob_prefix (app_name , user_id , session_id , filename ) + "/"
401359 )
@@ -409,7 +367,6 @@ async def list_artifact_versions(
409367 if not page_objects :
410368 continue
411369
412- # Parallelise head_object calls for each page
413370 head_tasks = [
414371 s3 .head_object (Bucket = self .bucket_name , Key = obj ["Key" ])
415372 for obj in page_objects
@@ -445,7 +402,7 @@ async def get_artifact_version(
445402 session_id : Optional [str ] = None ,
446403 version : Optional [int ] = None ,
447404 ) -> Optional [ArtifactVersion ]:
448- """Retrieve metadata for a specific version ( or the latest) ."""
405+ """Retrieve metadata for a specific version, or the latest."""
449406 from botocore .exceptions import ClientError
450407
451408 if version is None :
0 commit comments