|
22 | 22 | from typing import Mapping |
23 | 23 | from typing import Optional |
24 | 24 |
|
| 25 | +from google.genai import types |
25 | 26 | from typing_extensions import override |
26 | 27 |
|
| 28 | +from ...artifacts.base_artifact_service import ArtifactVersion |
27 | 29 | from ...artifacts.base_artifact_service import BaseArtifactService |
28 | 30 | from ...artifacts.file_artifact_service import FileArtifactService |
29 | 31 | from ...events.event import Event |
|
37 | 39 | logger = logging.getLogger("google_adk." + __name__) |
38 | 40 |
|
39 | 41 | _BUILT_IN_SESSION_SERVICE_KEY = "__adk_built_in_session_service__" |
| 42 | +_BUILT_IN_ARTIFACT_SERVICE_KEY = "__adk_built_in_artifact_service__" |
40 | 43 |
|
41 | 44 |
|
42 | 45 | def create_local_database_session_service( |
@@ -95,16 +98,31 @@ def create_local_session_service( |
95 | 98 |
|
96 | 99 |
|
97 | 100 | def create_local_artifact_service( |
98 | | - *, base_dir: Path | str |
| 101 | + *, |
| 102 | + base_dir: Path | str, |
| 103 | + per_agent: bool = False, |
| 104 | + app_name_to_dir: Optional[Mapping[str, str]] = None, |
99 | 105 | ) -> BaseArtifactService: |
100 | | - """Creates a file-backed artifact service rooted in `.adk/artifacts`. |
| 106 | + """Creates a file-backed artifact service that persists data in `.adk/artifacts` folders. |
101 | 107 |
|
102 | 108 | Args: |
103 | 109 | base_dir: Directory whose `.adk` folder will store artifacts. |
| 110 | + per_agent: If True, creates a PerAgentFileArtifactService that stores |
| 111 | + artifacts in each agent's `.adk/artifacts` folder. If False, creates a |
| 112 | + single FileArtifactService at base_dir/.adk/artifacts. |
| 113 | + app_name_to_dir: Optional mapping from logical app name to on-disk agent |
| 114 | + folder name. Only used when per_agent is True; defaults to identity. |
104 | 115 |
|
105 | 116 | Returns: |
106 | | - A `FileArtifactService` scoped to the derived root directory. |
| 117 | + A `BaseArtifactService` backed by the local filesystem. |
107 | 118 | """ |
| 119 | + if per_agent: |
| 120 | + logger.info("Using per-agent artifact storage rooted at %s", base_dir) |
| 121 | + return PerAgentFileArtifactService( |
| 122 | + agents_root=base_dir, |
| 123 | + app_name_to_dir=app_name_to_dir, |
| 124 | + ) |
| 125 | + |
108 | 126 | manager = DotAdkFolder(base_dir) |
109 | 127 | artifact_root = manager.artifacts_dir |
110 | 128 | artifact_root.mkdir(parents=True, exist_ok=True) |
@@ -217,3 +235,242 @@ async def get_user_state( |
217 | 235 | async def append_event(self, session: Session, event: Event) -> Event: |
218 | 236 | service = await self._get_service(session.app_name) |
219 | 237 | return await service.append_event(session, event) |
| 238 | + |
| 239 | + |
| 240 | +class PerAgentFileArtifactService(BaseArtifactService): |
| 241 | + """Routes artifact storage to per-agent `.adk/artifacts` folders.""" |
| 242 | + |
| 243 | + def __init__( |
| 244 | + self, |
| 245 | + *, |
| 246 | + agents_root: Path | str, |
| 247 | + app_name_to_dir: Optional[Mapping[str, str]] = None, |
| 248 | + ): |
| 249 | + self._agents_root = Path(agents_root).resolve() |
| 250 | + self._app_name_to_dir = dict(app_name_to_dir or {}) |
| 251 | + self._services: dict[str, BaseArtifactService] = {} |
| 252 | + self._legacy_service: Optional[BaseArtifactService] = None |
| 253 | + self._service_lock = asyncio.Lock() |
| 254 | + |
| 255 | + async def _get_service(self, app_name: str) -> BaseArtifactService: |
| 256 | + async with self._service_lock: |
| 257 | + if app_name.startswith("__"): |
| 258 | + storage_key = _BUILT_IN_ARTIFACT_SERVICE_KEY |
| 259 | + base_dir = self._agents_root |
| 260 | + else: |
| 261 | + storage_key = self._app_name_to_dir.get(app_name, app_name) |
| 262 | + folder = dot_adk_folder_for_agent( |
| 263 | + agents_root=self._agents_root, app_name=storage_key |
| 264 | + ) |
| 265 | + base_dir = folder.agent_dir |
| 266 | + |
| 267 | + service = self._services.get(storage_key) |
| 268 | + if service is not None: |
| 269 | + return service |
| 270 | + |
| 271 | + service = create_local_artifact_service(base_dir=base_dir) |
| 272 | + self._services[storage_key] = service |
| 273 | + return service |
| 274 | + |
| 275 | + async def _get_legacy_service( |
| 276 | + self, app_name: str |
| 277 | + ) -> Optional[BaseArtifactService]: |
| 278 | + """Returns a reader for the pre-per-agent shared `.adk/artifacts` root. |
| 279 | +
|
| 280 | + Returns None for built-in agents (which already use that root) and when |
| 281 | + no legacy directory exists, so reads fall back only when there is legacy |
| 282 | + data to find. Never creates the legacy directory. |
| 283 | + """ |
| 284 | + if app_name.startswith("__"): |
| 285 | + return None |
| 286 | + if self._legacy_service is not None: |
| 287 | + return self._legacy_service |
| 288 | + legacy_dir = DotAdkFolder(self._agents_root).artifacts_dir |
| 289 | + if not legacy_dir.exists(): |
| 290 | + return None |
| 291 | + async with self._service_lock: |
| 292 | + if self._legacy_service is None: |
| 293 | + self._legacy_service = FileArtifactService(root_dir=legacy_dir) |
| 294 | + return self._legacy_service |
| 295 | + |
| 296 | + @override |
| 297 | + async def save_artifact( |
| 298 | + self, |
| 299 | + *, |
| 300 | + app_name: str, |
| 301 | + user_id: str, |
| 302 | + filename: str, |
| 303 | + artifact: types.Part | dict[str, Any], |
| 304 | + session_id: Optional[str] = None, |
| 305 | + custom_metadata: Optional[dict[str, Any]] = None, |
| 306 | + ) -> int: |
| 307 | + service = await self._get_service(app_name) |
| 308 | + return await service.save_artifact( |
| 309 | + app_name=app_name, |
| 310 | + user_id=user_id, |
| 311 | + filename=filename, |
| 312 | + artifact=artifact, |
| 313 | + session_id=session_id, |
| 314 | + custom_metadata=custom_metadata, |
| 315 | + ) |
| 316 | + |
| 317 | + @override |
| 318 | + async def load_artifact( |
| 319 | + self, |
| 320 | + *, |
| 321 | + app_name: str, |
| 322 | + user_id: str, |
| 323 | + filename: str, |
| 324 | + session_id: Optional[str] = None, |
| 325 | + version: Optional[int] = None, |
| 326 | + ) -> Optional[types.Part]: |
| 327 | + service = await self._get_service(app_name) |
| 328 | + result = await service.load_artifact( |
| 329 | + app_name=app_name, |
| 330 | + user_id=user_id, |
| 331 | + filename=filename, |
| 332 | + session_id=session_id, |
| 333 | + version=version, |
| 334 | + ) |
| 335 | + if result is not None: |
| 336 | + return result |
| 337 | + legacy = await self._get_legacy_service(app_name) |
| 338 | + if legacy is None: |
| 339 | + return None |
| 340 | + return await legacy.load_artifact( |
| 341 | + app_name=app_name, |
| 342 | + user_id=user_id, |
| 343 | + filename=filename, |
| 344 | + session_id=session_id, |
| 345 | + version=version, |
| 346 | + ) |
| 347 | + |
| 348 | + @override |
| 349 | + async def list_artifact_keys( |
| 350 | + self, *, app_name: str, user_id: str, session_id: Optional[str] = None |
| 351 | + ) -> list[str]: |
| 352 | + service = await self._get_service(app_name) |
| 353 | + keys = await service.list_artifact_keys( |
| 354 | + app_name=app_name, user_id=user_id, session_id=session_id |
| 355 | + ) |
| 356 | + legacy = await self._get_legacy_service(app_name) |
| 357 | + if legacy is None: |
| 358 | + return keys |
| 359 | + legacy_keys = await legacy.list_artifact_keys( |
| 360 | + app_name=app_name, user_id=user_id, session_id=session_id |
| 361 | + ) |
| 362 | + return sorted(set(keys) | set(legacy_keys)) |
| 363 | + |
| 364 | + @override |
| 365 | + async def delete_artifact( |
| 366 | + self, |
| 367 | + *, |
| 368 | + app_name: str, |
| 369 | + user_id: str, |
| 370 | + filename: str, |
| 371 | + session_id: Optional[str] = None, |
| 372 | + ) -> None: |
| 373 | + service = await self._get_service(app_name) |
| 374 | + await service.delete_artifact( |
| 375 | + app_name=app_name, |
| 376 | + user_id=user_id, |
| 377 | + filename=filename, |
| 378 | + session_id=session_id, |
| 379 | + ) |
| 380 | + # Also delete any legacy copy so a deleted artifact can't reappear via the |
| 381 | + # read fallback. |
| 382 | + legacy = await self._get_legacy_service(app_name) |
| 383 | + if legacy is not None: |
| 384 | + await legacy.delete_artifact( |
| 385 | + app_name=app_name, |
| 386 | + user_id=user_id, |
| 387 | + filename=filename, |
| 388 | + session_id=session_id, |
| 389 | + ) |
| 390 | + |
| 391 | + @override |
| 392 | + async def list_versions( |
| 393 | + self, |
| 394 | + *, |
| 395 | + app_name: str, |
| 396 | + user_id: str, |
| 397 | + filename: str, |
| 398 | + session_id: Optional[str] = None, |
| 399 | + ) -> list[int]: |
| 400 | + service = await self._get_service(app_name) |
| 401 | + versions = await service.list_versions( |
| 402 | + app_name=app_name, |
| 403 | + user_id=user_id, |
| 404 | + filename=filename, |
| 405 | + session_id=session_id, |
| 406 | + ) |
| 407 | + if versions: |
| 408 | + return versions |
| 409 | + legacy = await self._get_legacy_service(app_name) |
| 410 | + if legacy is None: |
| 411 | + return versions |
| 412 | + return await legacy.list_versions( |
| 413 | + app_name=app_name, |
| 414 | + user_id=user_id, |
| 415 | + filename=filename, |
| 416 | + session_id=session_id, |
| 417 | + ) |
| 418 | + |
| 419 | + @override |
| 420 | + async def list_artifact_versions( |
| 421 | + self, |
| 422 | + *, |
| 423 | + app_name: str, |
| 424 | + user_id: str, |
| 425 | + filename: str, |
| 426 | + session_id: Optional[str] = None, |
| 427 | + ) -> list[ArtifactVersion]: |
| 428 | + service = await self._get_service(app_name) |
| 429 | + versions = await service.list_artifact_versions( |
| 430 | + app_name=app_name, |
| 431 | + user_id=user_id, |
| 432 | + filename=filename, |
| 433 | + session_id=session_id, |
| 434 | + ) |
| 435 | + if versions: |
| 436 | + return versions |
| 437 | + legacy = await self._get_legacy_service(app_name) |
| 438 | + if legacy is None: |
| 439 | + return versions |
| 440 | + return await legacy.list_artifact_versions( |
| 441 | + app_name=app_name, |
| 442 | + user_id=user_id, |
| 443 | + filename=filename, |
| 444 | + session_id=session_id, |
| 445 | + ) |
| 446 | + |
| 447 | + @override |
| 448 | + async def get_artifact_version( |
| 449 | + self, |
| 450 | + *, |
| 451 | + app_name: str, |
| 452 | + user_id: str, |
| 453 | + filename: str, |
| 454 | + session_id: Optional[str] = None, |
| 455 | + version: Optional[int] = None, |
| 456 | + ) -> Optional[ArtifactVersion]: |
| 457 | + service = await self._get_service(app_name) |
| 458 | + result = await service.get_artifact_version( |
| 459 | + app_name=app_name, |
| 460 | + user_id=user_id, |
| 461 | + filename=filename, |
| 462 | + session_id=session_id, |
| 463 | + version=version, |
| 464 | + ) |
| 465 | + if result is not None: |
| 466 | + return result |
| 467 | + legacy = await self._get_legacy_service(app_name) |
| 468 | + if legacy is None: |
| 469 | + return None |
| 470 | + return await legacy.get_artifact_version( |
| 471 | + app_name=app_name, |
| 472 | + user_id=user_id, |
| 473 | + filename=filename, |
| 474 | + session_id=session_id, |
| 475 | + version=version, |
| 476 | + ) |
0 commit comments