Skip to content

Commit 123efc2

Browse files
committed
fix: harden handoff image_urls preprocessing
1 parent d6b5fd1 commit 123efc2

2 files changed

Lines changed: 193 additions & 25 deletions

File tree

astrbot/core/astr_agent_tool_exec.py

Lines changed: 86 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import inspect
33
import json
4+
import os
45
import traceback
56
import typing as T
67
import uuid
@@ -36,9 +37,87 @@
3637
from astrbot.core.provider.entites import ProviderRequest
3738
from astrbot.core.provider.register import llm_tools
3839
from astrbot.core.utils.history_saver import persist_agent_history
40+
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
3941

4042

4143
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
44+
_ALLOWED_IMAGE_EXTENSIONS = {
45+
".png",
46+
".jpg",
47+
".jpeg",
48+
".gif",
49+
".webp",
50+
".bmp",
51+
".tif",
52+
".tiff",
53+
".svg",
54+
".heic",
55+
}
56+
57+
@classmethod
58+
def _is_supported_image_ref(cls, image_ref: str) -> bool:
59+
if not image_ref:
60+
return False
61+
lowered = image_ref.lower()
62+
if lowered.startswith(("http://", "https://", "base64://")):
63+
return True
64+
file_path = image_ref[8:] if lowered.startswith("file:///") else image_ref
65+
ext = os.path.splitext(file_path)[1].lower()
66+
return ext in cls._ALLOWED_IMAGE_EXTENSIONS
67+
68+
@classmethod
69+
async def _prepare_handoff_image_urls(
70+
cls,
71+
run_context: ContextWrapper[AstrAgentContext],
72+
tool_args: dict[str, T.Any],
73+
) -> list[str]:
74+
image_urls = tool_args.get("image_urls")
75+
if image_urls is None:
76+
candidates: list[T.Any] = []
77+
elif isinstance(image_urls, str):
78+
candidates = [image_urls]
79+
else:
80+
try:
81+
candidates = list(image_urls)
82+
except (TypeError, ValueError):
83+
candidates = [image_urls]
84+
85+
normalized = normalize_and_dedupe_strings(candidates)
86+
sanitized = [item for item in normalized if cls._is_supported_image_ref(item)]
87+
dropped_count = len(normalized) - len(sanitized)
88+
if dropped_count > 0:
89+
logger.warning(
90+
"Dropped %d invalid image_urls entries in handoff tool args.",
91+
dropped_count,
92+
)
93+
94+
# Merge current event image attachments so sub-agent behavior matches main-agent flow.
95+
event = getattr(run_context.context, "event", None)
96+
message_obj = getattr(event, "message_obj", None)
97+
message = getattr(message_obj, "message", None)
98+
if message:
99+
for idx, component in enumerate(message):
100+
if not isinstance(component, Image):
101+
continue
102+
try:
103+
path = await component.convert_to_file_path()
104+
if (
105+
path
106+
and cls._is_supported_image_ref(path)
107+
and path not in sanitized
108+
):
109+
sanitized.append(path)
110+
except Exception as e:
111+
logger.error(
112+
"Failed to convert handoff image component at index %d: %s",
113+
idx,
114+
e,
115+
exc_info=True,
116+
)
117+
118+
tool_args["image_urls"] = sanitized
119+
return sanitized
120+
42121
@classmethod
43122
async def execute(cls, tool, run_context, **tool_args):
44123
"""执行函数调用。
@@ -165,29 +244,7 @@ async def _execute_handoff(
165244
**tool_args,
166245
):
167246
input_ = tool_args.get("input")
168-
image_urls = tool_args.get("image_urls")
169-
if image_urls is None:
170-
image_urls = []
171-
elif isinstance(image_urls, str):
172-
image_urls = [image_urls]
173-
else:
174-
try:
175-
image_urls = list(image_urls)
176-
except (TypeError, ValueError):
177-
image_urls = [image_urls]
178-
179-
# 获取当前事件中的图片
180-
event = run_context.context.event
181-
if event.message_obj and event.message_obj.message:
182-
for component in event.message_obj.message:
183-
if isinstance(component, Image):
184-
try:
185-
# 调用组件的 convert_to_file_path 异步方法
186-
path = await component.convert_to_file_path()
187-
if path and path not in image_urls:
188-
image_urls.append(path)
189-
except Exception as e:
190-
logger.error(f"转换图片失败: {e}")
247+
image_urls = await cls._prepare_handoff_image_urls(run_context, tool_args)
191248

192249
# Build handoff toolset from registered tools plus runtime computer tools.
193250
toolset = cls._build_handoff_toolset(run_context, tool.agent.tools)
@@ -286,8 +343,12 @@ async def _do_handoff_background(
286343
) -> None:
287344
"""Run the subagent handoff and, on completion, wake the main agent."""
288345
result_text = ""
346+
prepared_tool_args = dict(tool_args)
289347
try:
290-
async for r in cls._execute_handoff(tool, run_context, **tool_args):
348+
await cls._prepare_handoff_image_urls(run_context, prepared_tool_args)
349+
async for r in cls._execute_handoff(
350+
tool, run_context, **prepared_tool_args
351+
):
291352
if isinstance(r, mcp.types.CallToolResult):
292353
for content in r.content:
293354
if isinstance(content, mcp.types.TextContent):
@@ -304,7 +365,7 @@ async def _do_handoff_background(
304365
task_id=task_id,
305366
tool_name=tool.name,
306367
result_text=result_text,
307-
tool_args=tool_args,
368+
tool_args=prepared_tool_args,
308369
note=(
309370
event.get_extra("background_note")
310371
or f"Background task for subagent '{tool.agent.name}' finished."
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from types import SimpleNamespace
2+
3+
import mcp
4+
import pytest
5+
6+
from astrbot.core.agent.run_context import ContextWrapper
7+
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
8+
from astrbot.core.message.components import Image
9+
10+
11+
class _DummyEvent:
12+
def __init__(self, message_components: list[object] | None = None) -> None:
13+
self.unified_msg_origin = "webchat:FriendMessage:webchat!user!session"
14+
self.message_obj = SimpleNamespace(message=message_components or [])
15+
16+
def get_extra(self, _key: str):
17+
return None
18+
19+
20+
class _DummyTool:
21+
def __init__(self) -> None:
22+
self.name = "transfer_to_subagent"
23+
self.agent = SimpleNamespace(name="subagent")
24+
25+
26+
def _build_run_context(message_components: list[object] | None = None):
27+
event = _DummyEvent(message_components=message_components)
28+
ctx = SimpleNamespace(event=event, context=SimpleNamespace())
29+
return ContextWrapper(context=ctx)
30+
31+
32+
@pytest.mark.asyncio
33+
async def test_prepare_handoff_image_urls_normalizes_filters_and_appends_event_image(
34+
monkeypatch: pytest.MonkeyPatch,
35+
):
36+
async def _fake_convert_to_file_path(self):
37+
return "/tmp/event_image.png"
38+
39+
monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path)
40+
41+
run_context = _build_run_context([Image(file="file:///tmp/original.png")])
42+
tool_args = {
43+
"image_urls": (
44+
" https://example.com/a.png ",
45+
"/tmp/not_an_image.txt",
46+
"/tmp/local.webp",
47+
123,
48+
)
49+
}
50+
51+
image_urls = await FunctionToolExecutor._prepare_handoff_image_urls(
52+
run_context,
53+
tool_args,
54+
)
55+
56+
assert image_urls == [
57+
"https://example.com/a.png",
58+
"/tmp/local.webp",
59+
"/tmp/event_image.png",
60+
]
61+
assert tool_args["image_urls"] == image_urls
62+
63+
64+
@pytest.mark.asyncio
65+
async def test_do_handoff_background_reports_prepared_image_urls(
66+
monkeypatch: pytest.MonkeyPatch,
67+
):
68+
captured: dict = {}
69+
70+
async def _fake_prepare(cls, run_context, tool_args):
71+
tool_args["image_urls"] = ["prepared://image.png"]
72+
return tool_args["image_urls"]
73+
74+
async def _fake_execute_handoff(cls, tool, run_context, **tool_args):
75+
yield mcp.types.CallToolResult(
76+
content=[mcp.types.TextContent(type="text", text="ok")]
77+
)
78+
79+
async def _fake_wake(cls, run_context, **kwargs):
80+
captured.update(kwargs)
81+
82+
monkeypatch.setattr(
83+
FunctionToolExecutor,
84+
"_prepare_handoff_image_urls",
85+
classmethod(_fake_prepare),
86+
)
87+
monkeypatch.setattr(
88+
FunctionToolExecutor,
89+
"_execute_handoff",
90+
classmethod(_fake_execute_handoff),
91+
)
92+
monkeypatch.setattr(
93+
FunctionToolExecutor,
94+
"_wake_main_agent_for_background_result",
95+
classmethod(_fake_wake),
96+
)
97+
98+
run_context = _build_run_context()
99+
await FunctionToolExecutor._do_handoff_background(
100+
tool=_DummyTool(),
101+
run_context=run_context,
102+
task_id="task-id",
103+
input="hello",
104+
image_urls="https://example.com/raw.png",
105+
)
106+
107+
assert captured["tool_args"]["image_urls"] == ["prepared://image.png"]

0 commit comments

Comments
 (0)