Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions astrbot/core/agent/handoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def __init__(
def default_parameters(self) -> dict:
return {
"type": "object",
"required": ["input"],
"additionalProperties": False,
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Comment on lines +42 to +43
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Making input required and disallowing additional properties conflicts with the background_task argument handling.

With input required and additionalProperties=False, any tool_args payload that includes a background_task key now fails validation against this JSON Schema, even though FunctionToolExecutor.execute still reads it and _build_handoff_error_result shows it in the example. Either add background_task (with the correct type) to properties or relax additionalProperties to allow it; if it’s no longer supported, remove it from both the error example and the tool_args handling to keep schema and behavior aligned.

"properties": {
"input": {
"type": "string",
Expand Down
134 changes: 129 additions & 5 deletions astrbot/core/astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,115 @@


class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
def _build_handoff_error_result(
    cls,
    *,
    tool_name: str,
    error_type: str,
    fix_hint: str,
    action_hint: str,
) -> mcp.types.CallToolResult:
    """Build the tool result returned when a handoff call is rejected.

    The result carries a single text item: a headline naming ``tool_name``,
    followed by the error type, the fix and action hints, and an example
    JSON payload showing the ``input`` and ``background_task`` fields.

    Args:
        tool_name: Name of the handoff tool that rejected the call.
        error_type: Short identifier of the failure.
        fix_hint: Explanation of what was wrong with the arguments.
        action_hint: What the caller should do next.

    Returns:
        A ``CallToolResult`` whose content is one ``TextContent`` entry.
    """
    # Each fragment already ends with its own newline (except the closing
    # brace), so a plain concatenation yields the final guidance text.
    guidance_fragments = (
        "[handoff CALL FAILED - IMMEDIATE RETRY REQUIRED]\n",
        f"error_type: {error_type}\n",
        f"fix: {fix_hint}\n",
        f"action: {action_hint}\n",
        "example:\n",
        "{\n",
        ' "input": "Summarize the user request, constraints, and expected output.",\n',
        ' "background_task": false\n',
        "}",
    )
    guidance = "".join(guidance_fragments)

    message = mcp.types.TextContent(
        type="text",
        text=f"error: {tool_name} rejected invalid handoff request.\n{guidance}",
    )
    return mcp.types.CallToolResult(content=[message])

@classmethod
def _parse_background_task_arg(
    cls,
    tool_name: str,
    value: T.Any,
) -> tuple[bool, mcp.types.CallToolResult | None]:
    """Interpret the ``background_task`` tool argument as a boolean.

    Accepted values are ``None`` (treated as ``False``), real booleans, and
    strings that, once stripped and lower-cased, match one of the supported
    spellings. The empty string counts as ``False``.

    Args:
        tool_name: Name of the handoff tool, used in the error result.
        value: Raw argument value taken from the tool call.

    Returns:
        ``(flag, None)`` when the value is understood, otherwise
        ``(False, error_result)`` where ``error_result`` describes the
        ``invalid_background_task`` failure.
    """
    if value is None or isinstance(value, bool):
        # bool(None) is False; a real bool is passed through unchanged.
        return bool(value), None

    if isinstance(value, str):
        recognized = {
            **dict.fromkeys(("true", "1", "yes", "on"), True),
            **dict.fromkeys(("false", "0", "no", "off", ""), False),
        }
        token = value.strip().lower()
        if token in recognized:
            return recognized[token], None

    # Anything else (unknown strings, numbers, containers, ...) is rejected.
    fix_hint = (
        "`background_task` must be a boolean (`true` or `false`) or a string "
        "equivalent such as `\"1\"`/`\"0\"`, `\"yes\"`/`\"no\"`, or "
        "`\"on\"`/`\"off\"`."
    )
    action_hint = (
        "Retry the same handoff with `background_task` set to a boolean or one "
        "of the supported string equivalents (`\"true\"`, `\"false\"`, "
        "`\"1\"`, `\"0\"`, `\"yes\"`, `\"no\"`, `\"on\"`, `\"off\"`)."
    )
    rejection = cls._build_handoff_error_result(
        tool_name=tool_name,
        error_type="invalid_background_task",
        fix_hint=fix_hint,
        action_hint=action_hint,
    )
    return False, rejection
Comment thread
sourcery-ai[bot] marked this conversation as resolved.

@classmethod
def _normalize_handoff_input(
    cls,
    tool_name: str,
    input_value: T.Any,
) -> tuple[str | None, mcp.types.CallToolResult | None]:
    """Validate and trim the ``input`` argument of a handoff call.

    Args:
        tool_name: Name of the handoff tool, used in the error result.
        input_value: Raw ``input`` value taken from the tool call.

    Returns:
        ``(stripped_text, None)`` when ``input_value`` is a string with
        non-whitespace content, otherwise ``(None, error_result)`` with
        error type ``missing_or_empty_input``.
    """
    if isinstance(input_value, str):
        trimmed = input_value.strip()
        if trimmed:
            return trimmed, None

    # Non-string values and blank strings end up here.
    rejection = cls._build_handoff_error_result(
        tool_name=tool_name,
        error_type="missing_or_empty_input",
        fix_hint=(
            "Provide a non-empty `input` string that clearly describes the delegated task."
        ),
        action_hint=(
            "Retry now with a concise task statement in `input`."
        ),
    )
    return None, rejection

@classmethod
async def _resolve_handoff_provider_id(
    cls,
    tool: HandoffTool,
    *,
    ctx: T.Any,
    umo: str,
) -> str:
    """Choose the provider id to use for a handoff.

    The tool's own ``provider_id`` is used when it is set and either cannot
    be checked (``ctx`` exposes no ``provider_manager`` with a
    ``get_provider_by_id`` method) or the lookup returns a provider. In all
    other cases the id from ``ctx.get_current_chat_provider_id(umo)`` is
    returned; a warning is logged when a configured id was not found.

    Args:
        tool: Handoff tool whose ``provider_id`` attribute is consulted.
        ctx: Object providing ``get_current_chat_provider_id`` and,
            optionally, ``provider_manager``.
        umo: Value forwarded to ``ctx.get_current_chat_provider_id``.

    Returns:
        The provider id selected for the handoff.
    """
    pinned_id = str(getattr(tool, "provider_id", "") or "").strip()
    if not pinned_id:
        return await ctx.get_current_chat_provider_id(umo)

    manager = getattr(ctx, "provider_manager", None)
    can_verify = manager is not None and hasattr(manager, "get_provider_by_id")
    if not can_verify:
        # Without a lookup facility the configured id is trusted as-is.
        return pinned_id

    if await manager.get_provider_by_id(pinned_id) is not None:
        return pinned_id

    current_id = await ctx.get_current_chat_provider_id(umo)
    logger.warning(
        "Subagent %s configured provider `%s` not found, fallback to `%s`.",
        tool.name,
        pinned_id,
        current_id,
    )
    return current_id

@classmethod
def _collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]:
if image_urls_raw is None:
Expand Down Expand Up @@ -130,7 +239,13 @@ async def execute(cls, tool, run_context, **tool_args):

"""
if isinstance(tool, HandoffTool):
is_bg = tool_args.pop("background_task", False)
is_bg, bg_error = cls._parse_background_task_arg(
tool.name,
tool_args.pop("background_task", None),
)
if bg_error is not None:
yield bg_error
return
if is_bg:
async for r in cls._execute_handoff_background(
tool, run_context, **tool_args
Expand Down Expand Up @@ -245,7 +360,14 @@ async def _execute_handoff(
**tool_args: T.Any,
):
tool_args = dict(tool_args)
input_ = tool_args.get("input")
input_, input_error = cls._normalize_handoff_input(
tool.name,
tool_args.get("input"),
)
if input_error is not None:
yield input_error
return
tool_args["input"] = input_
if image_urls_prepared:
prepared_image_urls = tool_args.get("image_urls")
if isinstance(prepared_image_urls, list):
Expand All @@ -272,9 +394,11 @@ async def _execute_handoff(

# Use per-subagent provider override if configured; otherwise fall back
# to the current/default provider resolution.
prov_id = getattr(
tool, "provider_id", None
) or await ctx.get_current_chat_provider_id(umo)
prov_id = await cls._resolve_handoff_provider_id(
tool,
ctx=ctx,
umo=umo,
)

# prepare begin dialogs
contexts = None
Expand Down
93 changes: 93 additions & 0 deletions tests/unit/test_astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,96 @@ async def _fake_convert_to_file_path(self):
)

assert image_urls == []


@pytest.mark.asyncio
async def test_execute_handoff_rejects_empty_input():
async def _fake_get_current_chat_provider_id(_umo):
return "provider-id"

async def _fake_tool_loop_agent(**_kwargs):
return SimpleNamespace(completion_text="ok")

context = SimpleNamespace(
get_current_chat_provider_id=_fake_get_current_chat_provider_id,
Comment on lines +348 to +357
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Strengthen empty-input test by asserting the handoff agent is never invoked

To fully validate the early-rejection behavior, also assert that _fake_tool_loop_agent is never invoked (e.g., via a flag or counter). This will confirm the error result is yielded before any downstream handoff execution occurs for empty input.

Suggested implementation:

@pytest.mark.asyncio
async def test_execute_handoff_rejects_empty_input():
    tool_loop_called = False

    async def _fake_get_current_chat_provider_id(_umo):
        return "provider-id"

    async def _fake_tool_loop_agent(**_kwargs):
        nonlocal tool_loop_called
        tool_loop_called = True
        return SimpleNamespace(completion_text="ok")

    context = SimpleNamespace(
        get_current_chat_provider_id=_fake_get_current_chat_provider_id,
        tool_loop_agent=_fake_tool_loop_agent,

To fully implement the strengthened test, you should also:

  1. After the code that triggers the empty-input rejection (the async for loop over FunctionToolExecutor._execute_handoff that collects the yielded results), add:
    assert tool_loop_called is False
  2. Ensure this assertion is placed at the end of test_execute_handoff_rejects_empty_input, after all other assertions that verify the rejection behavior.

tool_loop_agent=_fake_tool_loop_agent,
get_config=lambda **_kwargs: {"provider_settings": {}},
)
event = _DummyEvent([])
run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context))
tool = SimpleNamespace(
name="transfer_to_subagent",
provider_id=None,
agent=SimpleNamespace(
name="subagent",
tools=[],
instructions="subagent-instructions",
begin_dialogs=[],
run_hooks=None,
),
)

results = []
async for result in FunctionToolExecutor._execute_handoff(
tool,
run_context,
image_urls_prepared=True,
input=" ",
image_urls=[],
):
results.append(result)

assert len(results) == 1
assert isinstance(results[0], mcp.types.CallToolResult)
text_content = results[0].content[0]
assert isinstance(text_content, mcp.types.TextContent)
assert "missing_or_empty_input" in text_content.text


@pytest.mark.asyncio
async def test_execute_handoff_falls_back_to_current_provider_when_configured_missing():
    """A configured provider id that cannot be found falls back to the current one."""
    loop_kwargs: dict = {}

    class _DummyProviderManager:
        # Every lookup misses, so the configured id is treated as unknown.
        async def get_provider_by_id(self, _provider_id: str):
            return None

    async def _fake_get_current_chat_provider_id(_umo):
        return "fallback-provider"

    async def _fake_tool_loop_agent(**kwargs):
        # Record the keyword arguments the handoff passes to the agent loop.
        loop_kwargs.update(kwargs)
        return SimpleNamespace(completion_text="ok")

    subagent = SimpleNamespace(
        name="subagent",
        tools=[],
        instructions="subagent-instructions",
        begin_dialogs=[],
        run_hooks=None,
    )
    tool = SimpleNamespace(
        name="transfer_to_subagent",
        provider_id="missing-provider-id",
        agent=subagent,
    )
    context = SimpleNamespace(
        provider_manager=_DummyProviderManager(),
        get_current_chat_provider_id=_fake_get_current_chat_provider_id,
        tool_loop_agent=_fake_tool_loop_agent,
        get_config=lambda **_kwargs: {"provider_settings": {}},
    )
    run_context = ContextWrapper(
        context=SimpleNamespace(event=_DummyEvent([]), context=context)
    )

    results = [
        result
        async for result in FunctionToolExecutor._execute_handoff(
            tool,
            run_context,
            image_urls_prepared=True,
            input="hello",
            image_urls=[],
        )
    ]

    assert len(results) == 1
    assert loop_kwargs["chat_provider_id"] == "fallback-provider"
104 changes: 104 additions & 0 deletions tests/unit/test_handoff_background_task_arg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from types import SimpleNamespace

import mcp
import pytest

from astrbot.core.agent.agent import Agent
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor


@pytest.mark.parametrize(
("value", "expected_bool", "expect_error"),
[
(True, True, False),
("true", True, False),
("1", True, False),
("yes", True, False),
("on", True, False),
(" TRUE ", True, False),
(False, False, False),
Comment on lines +24 to +33
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Background task parsing is well-covered at the helper level; consider adding an integration-level test for invalid background_task via execute to ensure the early-yield behavior is preserved.

The current parametrized test covers _parse_background_task_arg, but not how FunctionToolExecutor.execute behaves when background_task is invalid. Please add a small async test that calls execute (or _execute_handoff, if that’s the public entry) with an invalid background_task and asserts that it yields a single CallToolResult error and returns without invoking the handoff or any downstream calls.

Suggested implementation:

import mcp
import pytest
from unittest.mock import AsyncMock

from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor


@pytest.mark.asyncio
async def test_execute_invalid_background_task_early_error(monkeypatch):
    """Integration-level test: invalid background_task should yield a single error and not hand off."""
    # Arrange: create an executor instance
    executor = FunctionToolExecutor()

    # If execute delegates into a separate handoff method, guard it so the test fails if it is invoked.
    # Adjust the attribute name if the handoff entrypoint differs (e.g. `handoff`, `_handoff`, etc.).
    if hasattr(executor, "_execute_handoff"):
        monkeypatch.setattr(
            executor,
            "_execute_handoff",
            AsyncMock(side_effect=AssertionError("_execute_handoff should not be called for invalid background_task")),
        )

    # Act: call execute with an invalid background_task value that should fail parsing
    results = [
        result async for result in executor.execute(
            tool_name="some_tool",
            arguments={},
            background_task="not-a-bool",  # intentionally invalid
        )
    ]

    # Assert: exactly one CallToolResult is yielded, it is an error, and no downstream calls are made.
    assert len(results) == 1
    assert isinstance(results[0], mcp.CallToolResult)
    assert getattr(results[0], "is_error", True), "Expected the single CallToolResult to represent an error"


@pytest.mark.parametrize(

The new integration-level test assumes:

  1. FunctionToolExecutor() can be constructed without arguments. If the real initializer requires parameters, update the test's executor = FunctionToolExecutor() line to pass appropriate arguments or use a helper/factory already used elsewhere in your tests.
  2. FunctionToolExecutor.execute is an async generator that accepts tool_name, arguments, and background_task keyword arguments. If the signature differs, align the call in the test with the actual method signature.
  3. The early-yield error path does not invoke _execute_handoff. If the actual handoff entrypoint has a different name (e.g., handoff, _handoff, _execute_tool_handoff), update the hasattr / monkeypatch.setattr to reference the correct attribute, or remove that block if there is no such method.
  4. mcp.CallToolResult has an is_error attribute or similar. If the error indication is expressed differently (e.g., result.error is not None, result.type == "error", etc.), adjust the final assertion to match your concrete API.

Comment on lines +24 to +33
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Consider adding dedicated tests for _parse_background_task_arg integration via execute with valid values

The parameterized helper tests cover value parsing well, but FunctionToolExecutor.execute is only exercised for the invalid string case. Please add execute-level tests where:

  • A truthy background value (e.g. True or "true") results in exactly one _execute_handoff_background call, zero _execute_handoff calls.
  • A falsey background value (e.g. False, "false", or "0") results in exactly one _execute_handoff call, zero _execute_handoff_background calls.

These can follow the pattern of test_execute_invalid_background_task_early_error, using monkeypatching and call_count assertions, and verifying that no error CallToolResult is returned.

Suggested implementation:

    if expect_error:
        assert error is not None
        assert isinstance(error, mcp.types.CallToolResult)
        text_content = error.content[0]
        assert isinstance(text_content, mcp.types.TextContent)
        assert "invalid_background_task" in text_content.text


@pytest.mark.parametrize("background_value", [True, "true"])
def test_execute_truthy_background_uses_background_handoff(monkeypatch, background_value):
    handoff_call_count = 0
    background_call_count = 0

    def fake_execute_handoff(*args, **kwargs):
        nonlocal handoff_call_count
        handoff_call_count += 1
        # Return a non-error CallToolResult shape; adjust as needed for the real API.
        return mcp.types.CallToolResult(
            content=[mcp.types.TextContent(type="text", text="foreground handoff")],
        )

    def fake_execute_handoff_background(*args, **kwargs):
        nonlocal background_call_count
        background_call_count += 1
        return mcp.types.CallToolResult(
            content=[mcp.types.TextContent(type="text", text="background handoff")],
        )

    monkeypatch.setattr(
        FunctionToolExecutor,
        "_execute_handoff",
        fake_execute_handoff,
    )
    monkeypatch.setattr(
        FunctionToolExecutor,
        "_execute_handoff_background",
        fake_execute_handoff_background,
    )

    # Construct minimal arguments following the pattern of other tests in this module.
    # These may need to be adjusted to match the actual FunctionToolExecutor.execute signature.
    agent = Agent()
    tool = HandoffTool(name="transfer_to_subagent")
    context = ContextWrapper(agent=agent)
    arguments = {"background_task": background_value}

    result = FunctionToolExecutor.execute(
        agent=agent,
        tool=tool,
        context=context,
        arguments=arguments,
    )

    assert handoff_call_count == 0
    assert background_call_count == 1

    # Ensure no "invalid_background_task" error CallToolResult was returned.
    results = result if isinstance(result, (list, tuple)) else [result]
    for res in results:
        if isinstance(res, mcp.types.CallToolResult) and res.content:
            first = res.content[0]
            if isinstance(first, mcp.types.TextContent):
                assert "invalid_background_task" not in first.text


@pytest.mark.parametrize("background_value", [False, "false", "0"])
def test_execute_falsey_background_uses_foreground_handoff(monkeypatch, background_value):
    handoff_call_count = 0
    background_call_count = 0

    def fake_execute_handoff(*args, **kwargs):
        nonlocal handoff_call_count
        handoff_call_count += 1
        return mcp.types.CallToolResult(
            content=[mcp.types.TextContent(type="text", text="foreground handoff")],
        )

    def fake_execute_handoff_background(*args, **kwargs):
        nonlocal background_call_count
        background_call_count += 1
        return mcp.types.CallToolResult(
            content=[mcp.types.TextContent(type="text", text="background handoff")],
        )

    monkeypatch.setattr(
        FunctionToolExecutor,
        "_execute_handoff",
        fake_execute_handoff,
    )
    monkeypatch.setattr(
        FunctionToolExecutor,
        "_execute_handoff_background",
        fake_execute_handoff_background,
    )

    agent = Agent()
    tool = HandoffTool(name="transfer_to_subagent")
    context = ContextWrapper(agent=agent)
    arguments = {"background_task": background_value}

    result = FunctionToolExecutor.execute(
        agent=agent,
        tool=tool,
        context=context,
        arguments=arguments,
    )

    assert handoff_call_count == 1
    assert background_call_count == 0

    results = result if isinstance(result, (list, tuple)) else [result]
    for res in results:
        if isinstance(res, mcp.types.CallToolResult) and res.content:
            first = res.content[0]
            if isinstance(first, mcp.types.TextContent):
                assert "invalid_background_task" not in first.text

The new tests assume:

  1. FunctionToolExecutor.execute has a signature like:
    execute(agent: Agent, tool: HandoffTool, context: ContextWrapper, arguments: dict).
  2. Agent, HandoffTool, and ContextWrapper can be constructed with the minimal arguments shown.

To integrate with the actual codebase, you may need to:

  1. Adjust the construction of agent, tool, and context to match their real constructors, or switch to using existing fixtures/helpers already used in test_execute_invalid_background_task_early_error.
  2. Update the execute call to match the real parameter names and ordering.
  3. If the real non-error CallToolResult shape differs, tweak the fake implementations of _execute_handoff and _execute_handoff_background to return whatever the production code expects.
  4. If the production error detection differs from checking for "invalid_background_task" in TextContent, align the “no error” assertions with the pattern already used in test_execute_invalid_background_task_early_error.

("false", False, False),
("0", False, False),
("no", False, False),
("off", False, False),
("", False, False),
(" FALSE ", False, False),
(None, False, False),
("not-a-bool", False, True),
("y", False, True),
("t", False, True),
(123, False, True),
({}, False, True),
],
)
def test_parse_background_task_arg(value, expected_bool, expect_error):
    """Check the parsed flag and the presence or absence of an error result."""
    parsed_flag, failure = FunctionToolExecutor._parse_background_task_arg(
        "transfer_to_subagent",
        value,
    )

    assert parsed_flag is expected_bool

    if not expect_error:
        assert failure is None
        return

    # Rejected values must come back with a textual error result.
    assert failure is not None
    assert isinstance(failure, mcp.types.CallToolResult)
    first_item = failure.content[0]
    assert isinstance(first_item, mcp.types.TextContent)
    assert "invalid_background_task" in first_item.text


@pytest.mark.asyncio
async def test_execute_invalid_background_task_early_error(monkeypatch):
    """An invalid ``background_task`` yields one error and skips both handoff paths."""
    invoked: list[str] = []

    def _make_stub(label):
        # Build a classmethod replacement that records that it was reached.
        async def _stub(cls, tool, run_context, **tool_args):
            invoked.append(label)
            yield mcp.types.CallToolResult(
                content=[mcp.types.TextContent(type="text", text="unexpected")]
            )

        return classmethod(_stub)

    for attr_name, label in (
        ("_execute_handoff", "handoff"),
        ("_execute_handoff_background", "handoff_bg"),
    ):
        monkeypatch.setattr(FunctionToolExecutor, attr_name, _make_stub(label))

    event = SimpleNamespace(
        unified_msg_origin="webchat:FriendMessage:webchat!user!session",
        message_obj=SimpleNamespace(message=[]),
    )
    run_context = ContextWrapper(
        context=SimpleNamespace(event=event, context=SimpleNamespace())
    )
    tool = HandoffTool(agent=Agent(name="subagent"))

    results = [
        result
        async for result in FunctionToolExecutor.execute(
            tool,
            run_context,
            input="hello",
            background_task="not-a-bool",
        )
    ]

    assert len(results) == 1
    only_result = results[0]
    assert isinstance(only_result, mcp.types.CallToolResult)
    first_item = only_result.content[0]
    assert isinstance(first_item, mcp.types.TextContent)
    assert "invalid_background_task" in first_item.text
    # Neither the foreground nor the background handoff may have run.
    assert "handoff" not in invoked
    assert "handoff_bg" not in invoked