Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 52 additions & 4 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ async def reset(
self._abort_signal = asyncio.Event()
self._pending_follow_ups: list[FollowUpTicket] = []
self._follow_up_seq = 0
self._last_tool_name: str | None = None
self._same_tool_streak = 0

# These two are used for tool schema mode handling
# We now have two modes:
Expand Down Expand Up @@ -427,6 +429,41 @@ def _merge_follow_up_notice(self, content: str) -> str:
return content
return f"{content}{notice}"

def _track_tool_call_streak(self, tool_name: str) -> int:
if tool_name == self._last_tool_name:
self._same_tool_streak += 1
else:
self._last_tool_name = tool_name
self._same_tool_streak = 1
return self._same_tool_streak

def _build_same_tool_guidance(self, tool_name: str, streak: int) -> str:
if streak < 3:
return ""
Comment thread
Soulter marked this conversation as resolved.
Outdated

if streak >= 5:
return (
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
"\n\n[SYSTEM NOTICE] Important: you have executed the same tool "
f"`{tool_name}` {streak} times consecutively. Repetition is now very "
"high. Continue only if each call is clearly producing new information. "
"Otherwise, change strategy, adjust arguments, or explain the limitation "
"to the user."
)

if streak >= 3:
return (
"\n\n[SYSTEM NOTICE] Important: you have executed the same tool "
f"`{tool_name}` {streak} times consecutively. Unless this repetition is "
"clearly necessary, stop repeating the same action and either switch "
"tools, refine parameters, or summarize what is still missing."
)

return (
"\n\n[SYSTEM NOTICE] By the way, you have executed the same tool "
f"`{tool_name}` {streak} times consecutively. Double-check whether another "
"tool, different arguments, or a summary would move the task forward better."
)

@override
async def step(self):
"""Process a single step of the agent.
Expand Down Expand Up @@ -712,6 +749,7 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None:
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
tool_call_streak = self._track_tool_call_streak(func_tool_name)
yield _HandleFunctionToolsResult.from_message_chain(
MessageChain(
type="tool_call",
Expand Down Expand Up @@ -861,7 +899,10 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None:
if result_parts:
_append_tool_call_result(
func_tool_id,
"\n\n".join(result_parts),
"\n\n".join(result_parts)
+ self._build_same_tool_guidance(
func_tool_name, tool_call_streak
),
)

elif resp is None:
Expand All @@ -875,7 +916,10 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None:
self.stats.end_time = time.time()
_append_tool_call_result(
func_tool_id,
"The tool has no return value, or has sent the result directly to the user.",
"The tool has no return value, or has sent the result directly to the user."
+ self._build_same_tool_guidance(
func_tool_name, tool_call_streak
),
)
else:
# 不应该出现其他类型
Expand All @@ -884,7 +928,10 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None:
)
_append_tool_call_result(
func_tool_id,
"*The tool has returned an unsupported type. Please tell the user to check the definition and implementation of this tool.*",
"*The tool has returned an unsupported type. Please tell the user to check the definition and implementation of this tool.*"
+ self._build_same_tool_guidance(
func_tool_name, tool_call_streak
),
)

try:
Expand All @@ -902,7 +949,8 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None:
logger.warning(traceback.format_exc())
_append_tool_call_result(
func_tool_id,
f"error: {e!s}",
f"error: {e!s}"
+ self._build_same_tool_guidance(func_tool_name, tool_call_streak),
)

# yield the last tool call result
Expand Down
1 change: 0 additions & 1 deletion dashboard/src/components/provider/ProviderSourcesPanel.vue
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ const emitDeleteSource = (source) => emit('delete-provider-source', source)
}

.provider-source-list {
max-height: calc(100vh - 335px);
overflow-y: auto;
Comment thread
Soulter marked this conversation as resolved.
padding: 0;
background: transparent;
Expand Down
121 changes: 121 additions & 0 deletions tests/test_tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,32 @@ async def text_chat(self, **kwargs) -> LLMResponse:
)


class SequentialToolProvider(MockProvider):
def __init__(self, tool_sequence: list[str]):
super().__init__()
self.tool_sequence = tool_sequence

async def text_chat(self, **kwargs) -> LLMResponse:
self.call_count += 1
func_tool = kwargs.get("func_tool")
if func_tool is None or self.call_count > len(self.tool_sequence):
return LLMResponse(
role="assistant",
completion_text="这是我的最终回答",
usage=TokenUsage(input_other=10, output=5),
)

tool_name = self.tool_sequence[self.call_count - 1]
return LLMResponse(
role="assistant",
completion_text="",
tools_call_name=[tool_name],
tools_call_args=[{"query": f"step-{self.call_count}"}],
tools_call_ids=[f"call_{self.call_count}"],
usage=TokenUsage(input_other=10, output=5),
)


class MockHandoffProvider(MockToolCallProvider):
def __init__(self, handoff_tool_name: str):
super().__init__(handoff_tool_name, {"input": "delegate this task"})
Expand Down Expand Up @@ -538,6 +564,101 @@ def fake_save_image(
]


@pytest.mark.asyncio
async def test_same_tool_consecutive_results_include_escalating_guidance(
runner, mock_tool_executor, mock_hooks
):
provider = SequentialToolProvider(["test_tool"] * 5)
tool = FunctionTool(
name="test_tool",
description="测试工具",
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
handler=AsyncMock(),
)
request = ProviderRequest(
prompt="请连续执行工具",
func_tool=ToolSet(tools=[tool]),
contexts=[],
)

await runner.reset(
provider=provider,
request=request,
run_context=ContextWrapper(context=None),
tool_executor=mock_tool_executor,
agent_hooks=mock_hooks,
streaming=False,
)

async for _ in runner.step_until_done(6):
pass

tool_messages = [
m for m in runner.run_context.messages if getattr(m, "role", None) == "tool"
]
assert len(tool_messages) == 5

Comment thread
Soulter marked this conversation as resolved.
tool_contents = [str(message.content) for message in tool_messages]
assert "same tool" not in tool_contents[0]
assert "By the way" in tool_contents[1]
assert "2 times consecutively" in tool_contents[1]
assert "Important" in tool_contents[2]
assert "3 times consecutively" in tool_contents[2]
assert "Important" in tool_contents[4]
assert "5 times consecutively" in tool_contents[4]
assert "very high" in tool_contents[4]


@pytest.mark.asyncio
async def test_same_tool_streak_resets_after_switching_tools(
runner, mock_tool_executor, mock_hooks
):
provider = SequentialToolProvider(
["test_tool", "other_tool", "test_tool", "test_tool"]
)
tool_a = FunctionTool(
name="test_tool",
description="测试工具 A",
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
handler=AsyncMock(),
)
tool_b = FunctionTool(
name="other_tool",
description="测试工具 B",
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
handler=AsyncMock(),
)
request = ProviderRequest(
prompt="切换工具后再重复",
func_tool=ToolSet(tools=[tool_a, tool_b]),
contexts=[],
)

await runner.reset(
provider=provider,
request=request,
run_context=ContextWrapper(context=None),
tool_executor=mock_tool_executor,
agent_hooks=mock_hooks,
streaming=False,
)

async for _ in runner.step_until_done(5):
pass

tool_messages = [
m for m in runner.run_context.messages if getattr(m, "role", None) == "tool"
]
assert len(tool_messages) == 4

tool_contents = [str(message.content) for message in tool_messages]
assert "same tool" not in tool_contents[0]
assert "same tool" not in tool_contents[1]
assert "same tool" not in tool_contents[2]
assert "By the way" in tool_contents[3]
assert "`test_tool` 2 times consecutively" in tool_contents[3]


@pytest.mark.asyncio
async def test_fallback_provider_used_when_primary_raises(
runner, provider_request, mock_tool_executor, mock_hooks
Expand Down
Loading