Skip to content

Commit f25f2ca

Browse files
committed
feat: implement llm guidance for repetition tool call
fixes: #7387
1 parent b0b6816 commit f25f2ca

3 files changed

Lines changed: 173 additions & 5 deletions

File tree

astrbot/core/agent/runners/tool_loop_agent_runner.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ async def reset(
209209
self._abort_signal = asyncio.Event()
210210
self._pending_follow_ups: list[FollowUpTicket] = []
211211
self._follow_up_seq = 0
212+
self._last_tool_name: str | None = None
213+
self._same_tool_streak = 0
212214

213215
# These two are used for tool schema mode handling
214216
# We now have two modes:
@@ -427,6 +429,41 @@ def _merge_follow_up_notice(self, content: str) -> str:
427429
return content
428430
return f"{content}{notice}"
429431

432+
def _track_tool_call_streak(self, tool_name: str) -> int:
    """Record a call to ``tool_name`` and return its consecutive-call count.

    Args:
        tool_name: Name of the tool being invoked now.

    Returns:
        How many times in a row ``tool_name`` has been called, counting
        this call. A name different from the previous one restarts at 1.
    """
    is_repeat = tool_name == self._last_tool_name
    # Remember the most recent tool so the next call can be compared to it.
    self._last_tool_name = tool_name
    self._same_tool_streak = self._same_tool_streak + 1 if is_repeat else 1
    return self._same_tool_streak
439+
440+
def _build_same_tool_guidance(self, tool_name: str, streak: int) -> str:
    """Build the notice appended to a tool result when one tool keeps repeating.

    The wording escalates with the streak length: a mild hint at 2
    consecutive calls, a firmer notice from 3, and the strongest one
    from 5.

    Args:
        tool_name: Name of the tool that was just called.
        streak: Number of consecutive calls to that tool, counting this one.

    Returns:
        An empty string when ``streak`` is below 2, otherwise a
        ``[SYSTEM NOTICE]`` suffix that starts with a blank line so it can
        be concatenated directly onto the tool result text.
    """
    # A single call is not a repetition. The guard used to be ``streak < 3``,
    # which made the mild "By the way" hint below unreachable.
    if streak < 2:
        return ""

    if streak >= 5:
        return (
            "\n\n[SYSTEM NOTICE] Important: you have executed the same tool "
            f"`{tool_name}` {streak} times consecutively. Repetition is now very "
            "high. Continue only if each call is clearly producing new information. "
            "Otherwise, change strategy, adjust arguments, or explain the limitation "
            "to the user."
        )

    if streak >= 3:
        return (
            "\n\n[SYSTEM NOTICE] Important: you have executed the same tool "
            f"`{tool_name}` {streak} times consecutively. Unless this repetition is "
            "clearly necessary, stop repeating the same action and either switch "
            "tools, refine parameters, or summarize what is still missing."
        )

    # Only streak == 2 reaches this point: the mildest reminder.
    return (
        "\n\n[SYSTEM NOTICE] By the way, you have executed the same tool "
        f"`{tool_name}` {streak} times consecutively. Double-check whether another "
        "tool, different arguments, or a summary would move the task forward better."
    )
466+
430467
@override
431468
async def step(self):
432469
"""Process a single step of the agent.
@@ -712,6 +749,7 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None:
712749
llm_response.tools_call_args,
713750
llm_response.tools_call_ids,
714751
):
752+
tool_call_streak = self._track_tool_call_streak(func_tool_name)
715753
yield _HandleFunctionToolsResult.from_message_chain(
716754
MessageChain(
717755
type="tool_call",
@@ -861,7 +899,10 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None:
861899
if result_parts:
862900
_append_tool_call_result(
863901
func_tool_id,
864-
"\n\n".join(result_parts),
902+
"\n\n".join(result_parts)
903+
+ self._build_same_tool_guidance(
904+
func_tool_name, tool_call_streak
905+
),
865906
)
866907

867908
elif resp is None:
@@ -875,7 +916,10 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None:
875916
self.stats.end_time = time.time()
876917
_append_tool_call_result(
877918
func_tool_id,
878-
"The tool has no return value, or has sent the result directly to the user.",
919+
"The tool has no return value, or has sent the result directly to the user."
920+
+ self._build_same_tool_guidance(
921+
func_tool_name, tool_call_streak
922+
),
879923
)
880924
else:
881925
# 不应该出现其他类型
@@ -884,7 +928,10 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None:
884928
)
885929
_append_tool_call_result(
886930
func_tool_id,
887-
"*The tool has returned an unsupported type. Please tell the user to check the definition and implementation of this tool.*",
931+
"*The tool has returned an unsupported type. Please tell the user to check the definition and implementation of this tool.*"
932+
+ self._build_same_tool_guidance(
933+
func_tool_name, tool_call_streak
934+
),
888935
)
889936

890937
try:
@@ -902,7 +949,8 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None:
902949
logger.warning(traceback.format_exc())
903950
_append_tool_call_result(
904951
func_tool_id,
905-
f"error: {e!s}",
952+
f"error: {e!s}"
953+
+ self._build_same_tool_guidance(func_tool_name, tool_call_streak),
906954
)
907955

908956
# yield the last tool call result

dashboard/src/components/provider/ProviderSourcesPanel.vue

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,6 @@ const emitDeleteSource = (source) => emit('delete-provider-source', source)
224224
}
225225
226226
.provider-source-list {
227-
max-height: calc(100vh - 335px);
228227
overflow-y: auto;
229228
padding: 0;
230229
background: transparent;

tests/test_tool_loop_agent_runner.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,32 @@ async def text_chat(self, **kwargs) -> LLMResponse:
193193
)
194194

195195

196+
class SequentialToolProvider(MockProvider):
    """Mock provider that requests one scripted tool call per chat turn.

    Each ``text_chat`` call advances ``call_count`` (presumably initialised
    by ``MockProvider`` -- confirm against the base class) and asks for the
    tool at that position in ``tool_sequence``. Once the script is exhausted,
    or when no ``func_tool`` is passed, it returns a plain final answer.
    """

    def __init__(self, tool_sequence: list[str]):
        super().__init__()
        self.tool_sequence = tool_sequence

    async def text_chat(self, **kwargs) -> LLMResponse:
        """Return the next scripted tool call, or the final text answer."""
        self.call_count += 1
        turn = self.call_count
        script_exhausted = turn > len(self.tool_sequence)

        if kwargs.get("func_tool") is None or script_exhausted:
            return LLMResponse(
                role="assistant",
                completion_text="这是我的最终回答",
                usage=TokenUsage(input_other=10, output=5),
            )

        # ``turn`` is 1-based, the script list is 0-based.
        return LLMResponse(
            role="assistant",
            completion_text="",
            tools_call_name=[self.tool_sequence[turn - 1]],
            tools_call_args=[{"query": f"step-{turn}"}],
            tools_call_ids=[f"call_{turn}"],
            usage=TokenUsage(input_other=10, output=5),
        )
220+
221+
196222
class MockHandoffProvider(MockToolCallProvider):
197223
def __init__(self, handoff_tool_name: str):
198224
super().__init__(handoff_tool_name, {"input": "delegate this task"})
@@ -538,6 +564,101 @@ def fake_save_image(
538564
]
539565

540566

567+
@pytest.mark.asyncio
async def test_same_tool_consecutive_results_include_escalating_guidance(
    runner, mock_tool_executor, mock_hooks
):
    """Five back-to-back calls to one tool yield progressively firmer notices.

    Checks the first, second, third and fifth tool results; the fourth is
    not asserted on.
    """
    repeated_tool = FunctionTool(
        name="test_tool",
        description="测试工具",
        parameters={"type": "object", "properties": {"query": {"type": "string"}}},
        handler=AsyncMock(),
    )

    await runner.reset(
        provider=SequentialToolProvider(["test_tool"] * 5),
        request=ProviderRequest(
            prompt="请连续执行工具",
            func_tool=ToolSet(tools=[repeated_tool]),
            contexts=[],
        ),
        run_context=ContextWrapper(context=None),
        tool_executor=mock_tool_executor,
        agent_hooks=mock_hooks,
        streaming=False,
    )

    # Five tool-call turns plus one turn for the final answer.
    async for _ in runner.step_until_done(6):
        pass

    tool_results = [
        msg for msg in runner.run_context.messages if getattr(msg, "role", None) == "tool"
    ]
    assert len(tool_results) == 5

    rendered = [str(msg.content) for msg in tool_results]

    # First call: no repetition notice at all.
    assert "same tool" not in rendered[0]

    # Second call: the mild hint.
    assert "By the way" in rendered[1]
    assert "2 times consecutively" in rendered[1]

    # Third call: the firmer notice.
    assert "Important" in rendered[2]
    assert "3 times consecutively" in rendered[2]

    # Fifth call: the strongest notice.
    assert "Important" in rendered[4]
    assert "5 times consecutively" in rendered[4]
    assert "very high" in rendered[4]
610+
611+
612+
@pytest.mark.asyncio
async def test_same_tool_streak_resets_after_switching_tools(
    runner, mock_tool_executor, mock_hooks
):
    """Calling a different tool in between restarts the repetition count."""
    query_schema = {"type": "object", "properties": {"query": {"type": "string"}}}
    first_tool = FunctionTool(
        name="test_tool",
        description="测试工具 A",
        parameters={"type": "object", "properties": {"query": {"type": "string"}}},
        handler=AsyncMock(),
    )
    second_tool = FunctionTool(
        name="other_tool",
        description="测试工具 B",
        parameters=query_schema,
        handler=AsyncMock(),
    )
    call_script = ["test_tool", "other_tool", "test_tool", "test_tool"]

    await runner.reset(
        provider=SequentialToolProvider(call_script),
        request=ProviderRequest(
            prompt="切换工具后再重复",
            func_tool=ToolSet(tools=[first_tool, second_tool]),
            contexts=[],
        ),
        run_context=ContextWrapper(context=None),
        tool_executor=mock_tool_executor,
        agent_hooks=mock_hooks,
        streaming=False,
    )

    # Four tool-call turns plus one turn for the final answer.
    async for _ in runner.step_until_done(5):
        pass

    tool_results = [
        msg for msg in runner.run_context.messages if getattr(msg, "role", None) == "tool"
    ]
    assert len(tool_results) == 4

    rendered = [str(msg.content) for msg in tool_results]

    # The first three calls each start a new streak, so none carries a notice.
    assert "same tool" not in rendered[0]
    assert "same tool" not in rendered[1]
    assert "same tool" not in rendered[2]

    # The fourth call is the second `test_tool` in a row.
    assert "By the way" in rendered[3]
    assert "`test_tool` 2 times consecutively" in rendered[3]
660+
661+
541662
@pytest.mark.asyncio
542663
async def test_fallback_provider_used_when_primary_raises(
543664
runner, provider_request, mock_tool_executor, mock_hooks

0 commit comments

Comments
 (0)