Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 77 additions & 39 deletions src/google/adk/tools/load_artifacts_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import ast
import base64
import binascii
import json
Expand Down Expand Up @@ -46,6 +47,7 @@
from .tool_context import ToolContext

logger = logging.getLogger('google_adk.' + __name__)
_LOAD_ARTIFACTS_TEXT_MARKER = '`load_artifacts` tool returned result:'


def _normalize_mime_type(mime_type: str | None) -> str | None:
Expand Down Expand Up @@ -121,6 +123,47 @@ def _as_safe_part_for_llm(
)


def _artifact_names_from_response(response: Any) -> list[str]:
if not isinstance(response, dict):
return []

artifact_names = response.get('artifact_names', [])
if isinstance(artifact_names, str):
return [artifact_names]
if not isinstance(artifact_names, list):
return []
return [name for name in artifact_names if isinstance(name, str)]


def _artifact_names_from_text_response(text: str | None) -> list[str]:
if not text or _LOAD_ARTIFACTS_TEXT_MARKER not in text:
return []

payload = text.split(_LOAD_ARTIFACTS_TEXT_MARKER, 1)[1].strip()
try:
response = ast.literal_eval(payload)
except (SyntaxError, ValueError) as exc:
logger.debug('Could not parse load_artifacts text response: %s', exc)
return []

return _artifact_names_from_response(response)


def _requested_artifact_names(content: types.Content) -> list[str]:
artifact_names: list[str] = []
for part in content.parts or []:
function_response = part.function_response
if function_response and function_response.name == 'load_artifacts':
artifact_names.extend(
_artifact_names_from_response(function_response.response or {})
)
continue

artifact_names.extend(_artifact_names_from_text_response(part.text))

return artifact_names


class LoadArtifactsTool(BaseTool):
"""A tool that loads the artifacts and adds them to the session."""

Expand Down Expand Up @@ -210,46 +253,41 @@ async def _append_artifacts_to_llm_request(
# Attach the content of the artifacts if the model requests them.
# This only adds the content to the model request, instead of the session.
if llm_request.contents and llm_request.contents[-1].parts:
function_response = llm_request.contents[-1].parts[0].function_response
if function_response and function_response.name == 'load_artifacts':
response = function_response.response or {}
artifact_names = response.get('artifact_names', [])
for artifact_name in artifact_names:
# Try session-scoped first (default behavior)
artifact = await tool_context.load_artifact(artifact_name)

# If not found and name doesn't already have user: prefix,
# try cross-session artifacts with user: prefix
if artifact is None and not artifact_name.startswith('user:'):
prefixed_name = f'user:{artifact_name}'
artifact = await tool_context.load_artifact(prefixed_name)

if artifact is None:
logger.warning('Artifact "%s" not found, skipping', artifact_name)
continue

artifact_part = _as_safe_part_for_llm(artifact, artifact_name)
if artifact_part is not artifact:
mime_type = (
artifact.inline_data.mime_type if artifact.inline_data else None
)
logger.debug(
'Converted artifact "%s" (mime_type=%s) to text Part',
artifact_name,
mime_type,
)

llm_request.contents.append(
types.Content(
role='user',
parts=[
types.Part.from_text(
text=f'Artifact {artifact_name} is:'
),
artifact_part,
],
)
artifact_names = _requested_artifact_names(llm_request.contents[-1])
for artifact_name in artifact_names:
# Try session-scoped first (default behavior)
artifact = await tool_context.load_artifact(artifact_name)

# If not found and name doesn't already have user: prefix,
# try cross-session artifacts with user: prefix
if artifact is None and not artifact_name.startswith('user:'):
prefixed_name = f'user:{artifact_name}'
artifact = await tool_context.load_artifact(prefixed_name)

if artifact is None:
logger.warning('Artifact "%s" not found, skipping', artifact_name)
continue

artifact_part = _as_safe_part_for_llm(artifact, artifact_name)
if artifact_part is not artifact:
mime_type = (
artifact.inline_data.mime_type if artifact.inline_data else None
)
logger.debug(
'Converted artifact "%s" (mime_type=%s) to text Part',
artifact_name,
mime_type,
)

llm_request.contents.append(
types.Content(
role='user',
parts=[
types.Part.from_text(text=f'Artifact {artifact_name} is:'),
artifact_part,
],
)
)


load_artifacts_tool = LoadArtifactsTool()
69 changes: 69 additions & 0 deletions tests/unittests/tools/test_load_artifacts_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,75 @@ async def test_load_artifacts_keeps_supported_mime_types():
assert artifact_part.inline_data.mime_type == 'application/pdf'


@mark.asyncio
async def test_load_artifacts_reads_workflow_text_response():
"""Workflow context can stringify tool responses from other nodes."""
artifact_name = 'invoice.txt'
artifact = types.Part.from_text(text='invoice total: 42')

tool_context = _StubToolContext({artifact_name: artifact})
llm_request = LlmRequest(
contents=[
types.Content(
role='user',
parts=[
types.Part.from_text(text='For context:'),
types.Part.from_text(
text=(
'[workflow_node] `load_artifacts` tool returned'
" result: {'artifact_names': ['invoice.txt'],"
" 'status': 'ok'}"
)
),
],
)
]
)

await load_artifacts_tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request
)

assert llm_request.contents[-1].parts[0].text == (
f'Artifact {artifact_name} is:'
)
assert llm_request.contents[-1].parts[1].text == 'invoice total: 42'


@mark.asyncio
async def test_load_artifacts_checks_all_function_response_parts():
"""The load_artifacts response may not be the first part in a turn."""
artifact_name = 'notes.txt'
artifact = types.Part.from_text(text='important notes')

tool_context = _StubToolContext({artifact_name: artifact})
llm_request = LlmRequest(
contents=[
types.Content(
role='user',
parts=[
types.Part.from_text(text='Done.'),
types.Part(
function_response=types.FunctionResponse(
name='load_artifacts',
response={'artifact_names': [artifact_name]},
)
),
],
)
]
)

await load_artifacts_tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request
)

assert llm_request.contents[-1].parts[0].text == (
f'Artifact {artifact_name} is:'
)
assert llm_request.contents[-1].parts[1].text == 'important notes'


def test_maybe_base64_to_bytes_decodes_standard_base64():
"""Standard base64 encoded strings are decoded correctly."""
original = b'hello world'
Expand Down