-
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
fix(subagent): enforce handoff args and provider fallback #6873
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 4 commits
f7897d1
15662ae
f7de29a
d8429ae
a6a98d7
2601a7e
1ad91d2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,6 +39,8 @@ def __init__( | |
| def default_parameters(self) -> dict: | ||
| return { | ||
| "type": "object", | ||
| "required": ["input"], | ||
| "additionalProperties": False, | ||
|
Comment on lines
+42
to
+43
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (bug_risk): Making With |
||
| "properties": { | ||
| "input": { | ||
| "type": "string", | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -343,3 +343,96 @@ async def _fake_convert_to_file_path(self): | |
| ) | ||
|
|
||
| assert image_urls == [] | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_execute_handoff_rejects_empty_input(): | ||
| async def _fake_get_current_chat_provider_id(_umo): | ||
| return "provider-id" | ||
|
|
||
| async def _fake_tool_loop_agent(**_kwargs): | ||
| return SimpleNamespace(completion_text="ok") | ||
|
|
||
| context = SimpleNamespace( | ||
| get_current_chat_provider_id=_fake_get_current_chat_provider_id, | ||
|
Comment on lines
+348
to
+357
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Strengthen empty-input test by asserting the handoff agent is never invoked To fully validate the early-rejection behavior, also assert that Suggested implementation: @pytest.mark.asyncio
async def test_execute_handoff_rejects_empty_input():
tool_loop_called = False
async def _fake_get_current_chat_provider_id(_umo):
return "provider-id"
async def _fake_tool_loop_agent(**_kwargs):
nonlocal tool_loop_called
tool_loop_called = True
return SimpleNamespace(completion_text="ok")
context = SimpleNamespace(
get_current_chat_provider_id=_fake_get_current_chat_provider_id,
tool_loop_agent=_fake_tool_loop_agent,To fully implement the strengthened test, you should also:
|
||
| tool_loop_agent=_fake_tool_loop_agent, | ||
| get_config=lambda **_kwargs: {"provider_settings": {}}, | ||
| ) | ||
| event = _DummyEvent([]) | ||
| run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) | ||
| tool = SimpleNamespace( | ||
| name="transfer_to_subagent", | ||
| provider_id=None, | ||
| agent=SimpleNamespace( | ||
| name="subagent", | ||
| tools=[], | ||
| instructions="subagent-instructions", | ||
| begin_dialogs=[], | ||
| run_hooks=None, | ||
| ), | ||
| ) | ||
|
|
||
| results = [] | ||
| async for result in FunctionToolExecutor._execute_handoff( | ||
| tool, | ||
| run_context, | ||
| image_urls_prepared=True, | ||
| input=" ", | ||
| image_urls=[], | ||
| ): | ||
| results.append(result) | ||
|
|
||
| assert len(results) == 1 | ||
| assert isinstance(results[0], mcp.types.CallToolResult) | ||
| text_content = results[0].content[0] | ||
| assert isinstance(text_content, mcp.types.TextContent) | ||
| assert "missing_or_empty_input" in text_content.text | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_execute_handoff_falls_back_to_current_provider_when_configured_missing(): | ||
| captured: dict = {} | ||
|
|
||
| class _DummyProviderManager: | ||
| async def get_provider_by_id(self, _provider_id: str): | ||
| return None | ||
|
|
||
| async def _fake_get_current_chat_provider_id(_umo): | ||
| return "fallback-provider" | ||
|
|
||
| async def _fake_tool_loop_agent(**kwargs): | ||
| captured.update(kwargs) | ||
| return SimpleNamespace(completion_text="ok") | ||
|
|
||
| context = SimpleNamespace( | ||
| provider_manager=_DummyProviderManager(), | ||
| get_current_chat_provider_id=_fake_get_current_chat_provider_id, | ||
| tool_loop_agent=_fake_tool_loop_agent, | ||
| get_config=lambda **_kwargs: {"provider_settings": {}}, | ||
| ) | ||
| event = _DummyEvent([]) | ||
| run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) | ||
| tool = SimpleNamespace( | ||
| name="transfer_to_subagent", | ||
| provider_id="missing-provider-id", | ||
| agent=SimpleNamespace( | ||
| name="subagent", | ||
| tools=[], | ||
| instructions="subagent-instructions", | ||
| begin_dialogs=[], | ||
| run_hooks=None, | ||
| ), | ||
| ) | ||
|
|
||
| results = [] | ||
| async for result in FunctionToolExecutor._execute_handoff( | ||
| tool, | ||
| run_context, | ||
| image_urls_prepared=True, | ||
| input="hello", | ||
| image_urls=[], | ||
| ): | ||
| results.append(result) | ||
|
|
||
| assert len(results) == 1 | ||
| assert captured["chat_provider_id"] == "fallback-provider" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| from types import SimpleNamespace | ||
|
|
||
| import mcp | ||
| import pytest | ||
|
|
||
| from astrbot.core.agent.agent import Agent | ||
| from astrbot.core.agent.handoff import HandoffTool | ||
| from astrbot.core.agent.run_context import ContextWrapper | ||
| from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| ("value", "expected_bool", "expect_error"), | ||
| [ | ||
| (True, True, False), | ||
| ("true", True, False), | ||
| ("1", True, False), | ||
| ("yes", True, False), | ||
| ("on", True, False), | ||
| (" TRUE ", True, False), | ||
| (False, False, False), | ||
|
Comment on lines
+24
to
+33
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Background task parsing is well-covered at the helper level; consider adding an integration-level test for invalid The current parametrized test covers Suggested implementation: import mcp
import pytest
from unittest.mock import AsyncMock
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutorfrom astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
@pytest.mark.asyncio
async def test_execute_invalid_background_task_early_error(monkeypatch):
"""Integration-level test: invalid background_task should yield a single error and not hand off."""
# Arrange: create an executor instance
executor = FunctionToolExecutor()
# If execute delegates into a separate handoff method, guard it so the test fails if it is invoked.
# Adjust the attribute name if the handoff entrypoint differs (e.g. `handoff`, `_handoff`, etc.).
if hasattr(executor, "_execute_handoff"):
monkeypatch.setattr(
executor,
"_execute_handoff",
AsyncMock(side_effect=AssertionError("_execute_handoff should not be called for invalid background_task")),
)
# Act: call execute with an invalid background_task value that should fail parsing
results = [
result async for result in executor.execute(
tool_name="some_tool",
arguments={},
background_task="not-a-bool", # intentionally invalid
)
]
# Assert: exactly one CallToolResult is yielded, it is an error, and no downstream calls are made.
assert len(results) == 1
assert isinstance(results[0], mcp.CallToolResult)
assert getattr(results[0], "is_error", True), "Expected the single CallToolResult to represent an error"
@pytest.mark.parametrize(The new integration-level test assumes:
Comment on lines
+24
to
+33
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Consider adding dedicated tests for The parameterized helper tests cover value parsing well, but
These can follow the pattern of Suggested implementation: if expect_error:
assert error is not None
assert isinstance(error, mcp.types.CallToolResult)
text_content = error.content[0]
assert isinstance(text_content, mcp.types.TextContent)
assert "invalid_background_task" in text_content.text
@pytest.mark.parametrize("background_value", [True, "true"])
def test_execute_truthy_background_uses_background_handoff(monkeypatch, background_value):
handoff_call_count = 0
background_call_count = 0
def fake_execute_handoff(*args, **kwargs):
nonlocal handoff_call_count
handoff_call_count += 1
# Return a non-error CallToolResult shape; adjust as needed for the real API.
return mcp.types.CallToolResult(
content=[mcp.types.TextContent(type="text", text="foreground handoff")],
)
def fake_execute_handoff_background(*args, **kwargs):
nonlocal background_call_count
background_call_count += 1
return mcp.types.CallToolResult(
content=[mcp.types.TextContent(type="text", text="background handoff")],
)
monkeypatch.setattr(
FunctionToolExecutor,
"_execute_handoff",
fake_execute_handoff,
)
monkeypatch.setattr(
FunctionToolExecutor,
"_execute_handoff_background",
fake_execute_handoff_background,
)
# Construct minimal arguments following the pattern of other tests in this module.
# These may need to be adjusted to match the actual FunctionToolExecutor.execute signature.
agent = Agent()
tool = HandoffTool(name="transfer_to_subagent")
context = ContextWrapper(agent=agent)
arguments = {"background_task": background_value}
result = FunctionToolExecutor.execute(
agent=agent,
tool=tool,
context=context,
arguments=arguments,
)
assert handoff_call_count == 0
assert background_call_count == 1
# Ensure no "invalid_background_task" error CallToolResult was returned.
results = result if isinstance(result, (list, tuple)) else [result]
for res in results:
if isinstance(res, mcp.types.CallToolResult) and res.content:
first = res.content[0]
if isinstance(first, mcp.types.TextContent):
assert "invalid_background_task" not in first.text
@pytest.mark.parametrize("background_value", [False, "false", "0"])
def test_execute_falsey_background_uses_foreground_handoff(monkeypatch, background_value):
handoff_call_count = 0
background_call_count = 0
def fake_execute_handoff(*args, **kwargs):
nonlocal handoff_call_count
handoff_call_count += 1
return mcp.types.CallToolResult(
content=[mcp.types.TextContent(type="text", text="foreground handoff")],
)
def fake_execute_handoff_background(*args, **kwargs):
nonlocal background_call_count
background_call_count += 1
return mcp.types.CallToolResult(
content=[mcp.types.TextContent(type="text", text="background handoff")],
)
monkeypatch.setattr(
FunctionToolExecutor,
"_execute_handoff",
fake_execute_handoff,
)
monkeypatch.setattr(
FunctionToolExecutor,
"_execute_handoff_background",
fake_execute_handoff_background,
)
agent = Agent()
tool = HandoffTool(name="transfer_to_subagent")
context = ContextWrapper(agent=agent)
arguments = {"background_task": background_value}
result = FunctionToolExecutor.execute(
agent=agent,
tool=tool,
context=context,
arguments=arguments,
)
assert handoff_call_count == 1
assert background_call_count == 0
results = result if isinstance(result, (list, tuple)) else [result]
for res in results:
if isinstance(res, mcp.types.CallToolResult) and res.content:
first = res.content[0]
if isinstance(first, mcp.types.TextContent):
assert "invalid_background_task" not in first.textThe new tests assume:
To integrate with the actual codebase, you may need to:
|
||
| ("false", False, False), | ||
| ("0", False, False), | ||
| ("no", False, False), | ||
| ("off", False, False), | ||
| ("", False, False), | ||
| (" FALSE ", False, False), | ||
| (None, False, False), | ||
| ("not-a-bool", False, True), | ||
| ("y", False, True), | ||
| ("t", False, True), | ||
| (123, False, True), | ||
| ({}, False, True), | ||
| ], | ||
| ) | ||
| def test_parse_background_task_arg(value, expected_bool, expect_error): | ||
| is_bg, error = FunctionToolExecutor._parse_background_task_arg( | ||
| "transfer_to_subagent", | ||
| value, | ||
| ) | ||
|
|
||
| assert is_bg is expected_bool | ||
| if expect_error: | ||
| assert error is not None | ||
| assert isinstance(error, mcp.types.CallToolResult) | ||
| text_content = error.content[0] | ||
| assert isinstance(text_content, mcp.types.TextContent) | ||
| assert "invalid_background_task" in text_content.text | ||
| else: | ||
| assert error is None | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_execute_invalid_background_task_early_error(monkeypatch): | ||
| call_count = {"handoff": 0, "handoff_bg": 0} | ||
|
|
||
| async def _fake_execute_handoff(cls, tool, run_context, **tool_args): | ||
| call_count["handoff"] += 1 | ||
| yield mcp.types.CallToolResult( | ||
| content=[mcp.types.TextContent(type="text", text="unexpected")] | ||
| ) | ||
|
|
||
| async def _fake_execute_handoff_bg(cls, tool, run_context, **tool_args): | ||
| call_count["handoff_bg"] += 1 | ||
| yield mcp.types.CallToolResult( | ||
| content=[mcp.types.TextContent(type="text", text="unexpected")] | ||
| ) | ||
|
|
||
| monkeypatch.setattr( | ||
| FunctionToolExecutor, | ||
| "_execute_handoff", | ||
| classmethod(_fake_execute_handoff), | ||
| ) | ||
| monkeypatch.setattr( | ||
| FunctionToolExecutor, | ||
| "_execute_handoff_background", | ||
| classmethod(_fake_execute_handoff_bg), | ||
| ) | ||
|
|
||
| tool = HandoffTool(agent=Agent(name="subagent")) | ||
| event = SimpleNamespace( | ||
| unified_msg_origin="webchat:FriendMessage:webchat!user!session", | ||
| message_obj=SimpleNamespace(message=[]), | ||
| ) | ||
| run_context = ContextWrapper( | ||
| context=SimpleNamespace(event=event, context=SimpleNamespace()) | ||
| ) | ||
|
|
||
| results = [] | ||
| async for result in FunctionToolExecutor.execute( | ||
| tool, | ||
| run_context, | ||
| input="hello", | ||
| background_task="not-a-bool", | ||
| ): | ||
| results.append(result) | ||
|
|
||
| assert len(results) == 1 | ||
| assert isinstance(results[0], mcp.types.CallToolResult) | ||
| text_content = results[0].content[0] | ||
| assert isinstance(text_content, mcp.types.TextContent) | ||
| assert "invalid_background_task" in text_content.text | ||
| assert call_count["handoff"] == 0 | ||
| assert call_count["handoff_bg"] == 0 | ||
Uh oh!
There was an error while loading. Please reload this page.