Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions astrbot/core/astr_agent_run_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,38 @@ def _build_tool_result_status_message(
return status_msg


def _normalize_repeat_reply_guard_threshold(value: int) -> int:
try:
parsed = int(value)
except (TypeError, ValueError):
return 0
return max(0, parsed)
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

这个 _normalize_repeat_reply_guard_threshold 标准化函数是多余的。因为 repeat_reply_guard_threshold 参数在传入 run_agent 之前,已经在 InternalAgentSubStage.initialize 中被处理过了。为了避免代码重复并简化逻辑,建议移除此函数。



def _build_chain_signature(msg_chain: MessageChain) -> str:
signature = msg_chain.get_plain_text(with_other_comps_mark=True).strip()
if not signature:
return ""
return re.sub(r"\s+", " ", signature)


async def run_agent(
agent_runner: AgentRunner,
max_step: int = 30,
show_tool_use: bool = True,
show_tool_call_result: bool = False,
stream_to_general: bool = False,
show_reasoning: bool = False,
repeat_reply_guard_threshold: int = 3,
) -> AsyncGenerator[MessageChain | None, None]:
step_idx = 0
astr_event = agent_runner.run_context.context.event
tool_name_by_call_id: dict[str, str] = {}
guard_threshold = _normalize_repeat_reply_guard_threshold(
repeat_reply_guard_threshold
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

由于 repeat_reply_guard_threshold 的值在传入此函数前已经经过了标准化处理,因此这里对 _normalize_repeat_reply_guard_threshold 的调用是不必要的。你可以直接进行赋值。

    guard_threshold = repeat_reply_guard_threshold

guard_last_signature = ""
guard_repeat_count = 0
while step_idx < max_step + 1:
step_idx += 1

Expand Down Expand Up @@ -193,6 +214,38 @@ async def run_agent(
await astr_event.send(chain)
continue

if resp.type == "llm_result" and guard_threshold > 0:
chain_signature = _build_chain_signature(resp.data["chain"])
if chain_signature:
if chain_signature == guard_last_signature:
guard_repeat_count += 1
else:
guard_last_signature = chain_signature
guard_repeat_count = 1

if guard_repeat_count >= guard_threshold:
logger.warning(
"Agent repeated identical llm_result %d times; forcing convergence. threshold=%d",
guard_repeat_count,
guard_threshold,
)
if not agent_runner.done():
if agent_runner.req:
agent_runner.req.func_tool = None
agent_runner.run_context.messages.append(
Message(
role="user",
content=(
"检测到你连续多次输出相同回复。"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议使用英文

"请停止重复,基于已有信息给出最终答复,"
"不要再次调用工具。"
),
)
)
# Jump to the same convergence path as max-step limit.
step_idx = max_step
continue

if stream_to_general and resp.type == "streaming_delta":
continue

Expand Down Expand Up @@ -288,6 +341,7 @@ async def run_live_agent(
show_tool_use: bool = True,
show_tool_call_result: bool = False,
show_reasoning: bool = False,
repeat_reply_guard_threshold: int = 3,
) -> AsyncGenerator[MessageChain | None, None]:
"""Live Mode 的 Agent 运行器,支持流式 TTS

Expand All @@ -311,6 +365,7 @@ async def run_live_agent(
show_tool_call_result=show_tool_call_result,
stream_to_general=False,
show_reasoning=show_reasoning,
repeat_reply_guard_threshold=repeat_reply_guard_threshold,
):
yield chain
return
Expand Down Expand Up @@ -343,6 +398,7 @@ async def run_live_agent(
show_tool_use,
show_tool_call_result,
show_reasoning,
repeat_reply_guard_threshold,
)
)

Expand Down Expand Up @@ -430,6 +486,7 @@ async def _run_agent_feeder(
show_tool_use: bool,
show_tool_call_result: bool,
show_reasoning: bool,
repeat_reply_guard_threshold: int,
) -> None:
"""运行 Agent 并将文本输出分句放入队列"""
buffer = ""
Expand All @@ -441,6 +498,7 @@ async def _run_agent_feeder(
show_tool_call_result=show_tool_call_result,
stream_to_general=False,
show_reasoning=show_reasoning,
repeat_reply_guard_threshold=repeat_reply_guard_threshold,
):
if chain is None:
continue
Expand Down
12 changes: 12 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@
"unsupported_streaming_strategy": "realtime_segmenting",
"reachability_check": False,
"max_agent_step": 30,
"repeat_reply_guard_threshold": 3,
"tool_call_timeout": 120,
"tool_schema_mode": "full",
"llm_safety_mode": True,
Expand Down Expand Up @@ -2685,6 +2686,9 @@ class ChatProviderTemplate(TypedDict):
"max_agent_step": {
"type": "int",
},
"repeat_reply_guard_threshold": {
"type": "int",
},
"tool_call_timeout": {
"type": "int",
},
Expand Down Expand Up @@ -3430,6 +3434,14 @@ class ChatProviderTemplate(TypedDict):
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.repeat_reply_guard_threshold": {
"description": "连续相同回复拦截阈值",
"type": "int",
"hint": "同一轮 Agent 运行中连续出现相同回复达到该次数时,将触发防循环收敛。设置为 0 可关闭。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@ async def initialize(self, ctx: PipelineContext) -> None:
self.tool_schema_mode = "full"
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.repeat_reply_guard_threshold: int = settings.get(
"repeat_reply_guard_threshold", 3
)
if isinstance(self.repeat_reply_guard_threshold, bool):
self.repeat_reply_guard_threshold = 3
try:
self.repeat_reply_guard_threshold = int(self.repeat_reply_guard_threshold)
except (TypeError, ValueError):
self.repeat_reply_guard_threshold = 3
if self.repeat_reply_guard_threshold < 0:
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
self.repeat_reply_guard_threshold = 0
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

这部分用于规范化 repeat_reply_guard_threshold 的逻辑有些冗长,可以简化以提高可读性。

Suggested change
if isinstance(self.repeat_reply_guard_threshold, bool):
self.repeat_reply_guard_threshold = 3
try:
self.repeat_reply_guard_threshold = int(self.repeat_reply_guard_threshold)
except (TypeError, ValueError):
self.repeat_reply_guard_threshold = 3
if self.repeat_reply_guard_threshold < 0:
self.repeat_reply_guard_threshold = 0
try:
if isinstance(self.repeat_reply_guard_threshold, bool):
raise TypeError
parsed_val = int(self.repeat_reply_guard_threshold)
self.repeat_reply_guard_threshold = max(0, parsed_val)
except (TypeError, ValueError):
self.repeat_reply_guard_threshold = 3

self.show_tool_use: bool = settings.get("show_tool_use_status", True)
self.show_tool_call_result: bool = settings.get("show_tool_call_result", False)
self.show_reasoning = settings.get("display_reasoning_text", False)
Expand Down Expand Up @@ -274,6 +285,7 @@ async def process(
self.show_tool_use,
self.show_tool_call_result,
show_reasoning=self.show_reasoning,
repeat_reply_guard_threshold=self.repeat_reply_guard_threshold,
),
),
)
Expand Down Expand Up @@ -304,6 +316,7 @@ async def process(
self.show_tool_use,
self.show_tool_call_result,
show_reasoning=self.show_reasoning,
repeat_reply_guard_threshold=self.repeat_reply_guard_threshold,
),
),
)
Expand Down Expand Up @@ -334,6 +347,7 @@ async def process(
self.show_tool_call_result,
stream_to_general,
show_reasoning=self.show_reasoning,
repeat_reply_guard_threshold=self.repeat_reply_guard_threshold,
):
yield

Expand Down
148 changes: 148 additions & 0 deletions tests/unit/test_astr_agent_run_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from types import SimpleNamespace

import pytest

from astrbot.core.astr_agent_run_util import run_agent
from astrbot.core.message.message_event_result import MessageChain


def _llm_result_response(text: str):
return SimpleNamespace(
type="llm_result",
data={"chain": MessageChain().message(text)},
)


class _DummyTrace:
def record(self, *args, **kwargs) -> None:
return None


class _DummyEvent:
def __init__(self) -> None:
self._extras: dict = {}
self._stopped = False
self.result_texts: list[str] = []
self.trace = _DummyTrace()

def is_stopped(self) -> bool:
return self._stopped

def get_extra(self, key: str, default=None):
return self._extras.get(key, default)

def set_extra(self, key: str, value) -> None:
self._extras[key] = value

def set_result(self, result) -> None:
self.result_texts.append(result.get_plain_text(with_other_comps_mark=True))

def clear_result(self) -> None:
return None

def get_platform_name(self) -> str:
return "slack"

def get_platform_id(self) -> str:
return "slack"

async def send(self, _msg_chain) -> None:
return None


class _FakeRunner:
def __init__(self, steps: list[list[SimpleNamespace]]) -> None:
self._steps = steps
self._step_idx = 0
self._done = False
self.streaming = False
self.req = SimpleNamespace(func_tool=object())
self.run_context = SimpleNamespace(
context=SimpleNamespace(event=_DummyEvent()),
messages=[],
)
self.stats = SimpleNamespace(to_dict=lambda: {})

def done(self) -> bool:
return self._done

def request_stop(self) -> None:
self.run_context.context.event.set_extra("agent_stop_requested", True)

def was_aborted(self) -> bool:
return False

async def step(self):
if self._step_idx >= len(self._steps):
self._done = True
return

current = self._steps[self._step_idx]
self._step_idx += 1
for resp in current:
yield resp

if self._step_idx >= len(self._steps):
self._done = True


@pytest.mark.asyncio
async def test_repeat_reply_guard_forces_convergence():
runner = _FakeRunner(
[
[_llm_result_response("重复输出")],
[_llm_result_response("重复输出")],
[_llm_result_response("重复输出")],
[_llm_result_response("最终答案")],
]
)

async for _ in run_agent(
runner,
max_step=8,
show_tool_use=False,
show_tool_call_result=False,
repeat_reply_guard_threshold=3,
):
pass

assert runner.run_context.context.event.result_texts == [
"重复输出",
"重复输出",
"最终答案",
]
assert runner.req.func_tool is None
assert any(
msg.role == "user" and "检测到你连续多次输出相同回复" in str(msg.content)
for msg in runner.run_context.messages
)


@pytest.mark.asyncio
async def test_repeat_reply_guard_can_be_disabled_with_zero_threshold():
runner = _FakeRunner(
[
[_llm_result_response("重复输出")],
[_llm_result_response("重复输出")],
[_llm_result_response("重复输出")],
[_llm_result_response("最终答案")],
]
)
original_func_tool = runner.req.func_tool

async for _ in run_agent(
runner,
max_step=8,
show_tool_use=False,
show_tool_call_result=False,
repeat_reply_guard_threshold=0,
):
pass

assert runner.run_context.context.event.result_texts == [
"重复输出",
"重复输出",
"重复输出",
"最终答案",
]
assert runner.req.func_tool is original_func_tool