-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
fix(agent): add configurable repeated-reply convergence guard #6921
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -87,17 +87,38 @@ def _build_tool_result_status_message( | |
| return status_msg | ||
|
|
||
|
|
||
| def _normalize_repeat_reply_guard_threshold(value: int) -> int: | ||
| try: | ||
| parsed = int(value) | ||
| except (TypeError, ValueError): | ||
| return 0 | ||
| return max(0, parsed) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
|
|
||
def _build_chain_signature(msg_chain: MessageChain) -> str:
    """Return a whitespace-normalized text signature for a message chain.

    Used to detect repeated identical replies: runs of whitespace are
    collapsed to single spaces so cosmetically different chains compare equal.
    An empty (or whitespace-only) chain yields "".
    """
    text = msg_chain.get_plain_text(with_other_comps_mark=True).strip()
    return re.sub(r"\s+", " ", text) if text else ""
|
|
||
|
|
||
| async def run_agent( | ||
| agent_runner: AgentRunner, | ||
| max_step: int = 30, | ||
| show_tool_use: bool = True, | ||
| show_tool_call_result: bool = False, | ||
| stream_to_general: bool = False, | ||
| show_reasoning: bool = False, | ||
| repeat_reply_guard_threshold: int = 3, | ||
| ) -> AsyncGenerator[MessageChain | None, None]: | ||
| step_idx = 0 | ||
| astr_event = agent_runner.run_context.context.event | ||
| tool_name_by_call_id: dict[str, str] = {} | ||
| guard_threshold = _normalize_repeat_reply_guard_threshold( | ||
| repeat_reply_guard_threshold | ||
| ) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| guard_last_signature = "" | ||
| guard_repeat_count = 0 | ||
| while step_idx < max_step + 1: | ||
| step_idx += 1 | ||
|
|
||
|
|
@@ -193,6 +214,38 @@ async def run_agent( | |
| await astr_event.send(chain) | ||
| continue | ||
|
|
||
| if resp.type == "llm_result" and guard_threshold > 0: | ||
| chain_signature = _build_chain_signature(resp.data["chain"]) | ||
| if chain_signature: | ||
| if chain_signature == guard_last_signature: | ||
| guard_repeat_count += 1 | ||
| else: | ||
| guard_last_signature = chain_signature | ||
| guard_repeat_count = 1 | ||
|
|
||
| if guard_repeat_count >= guard_threshold: | ||
| logger.warning( | ||
| "Agent repeated identical llm_result %d times; forcing convergence. threshold=%d", | ||
| guard_repeat_count, | ||
| guard_threshold, | ||
| ) | ||
| if not agent_runner.done(): | ||
| if agent_runner.req: | ||
| agent_runner.req.func_tool = None | ||
| agent_runner.run_context.messages.append( | ||
| Message( | ||
| role="user", | ||
| content=( | ||
| "检测到你连续多次输出相同回复。" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 建议使用英文 |
||
| "请停止重复,基于已有信息给出最终答复," | ||
| "不要再次调用工具。" | ||
| ), | ||
| ) | ||
| ) | ||
| # Jump to the same convergence path as max-step limit. | ||
| step_idx = max_step | ||
| continue | ||
|
|
||
| if stream_to_general and resp.type == "streaming_delta": | ||
| continue | ||
|
|
||
|
|
@@ -288,6 +341,7 @@ async def run_live_agent( | |
| show_tool_use: bool = True, | ||
| show_tool_call_result: bool = False, | ||
| show_reasoning: bool = False, | ||
| repeat_reply_guard_threshold: int = 3, | ||
| ) -> AsyncGenerator[MessageChain | None, None]: | ||
| """Live Mode 的 Agent 运行器,支持流式 TTS | ||
|
|
||
|
|
@@ -311,6 +365,7 @@ async def run_live_agent( | |
| show_tool_call_result=show_tool_call_result, | ||
| stream_to_general=False, | ||
| show_reasoning=show_reasoning, | ||
| repeat_reply_guard_threshold=repeat_reply_guard_threshold, | ||
| ): | ||
| yield chain | ||
| return | ||
|
|
@@ -343,6 +398,7 @@ async def run_live_agent( | |
| show_tool_use, | ||
| show_tool_call_result, | ||
| show_reasoning, | ||
| repeat_reply_guard_threshold, | ||
| ) | ||
| ) | ||
|
|
||
|
|
@@ -430,6 +486,7 @@ async def _run_agent_feeder( | |
| show_tool_use: bool, | ||
| show_tool_call_result: bool, | ||
| show_reasoning: bool, | ||
| repeat_reply_guard_threshold: int, | ||
| ) -> None: | ||
| """运行 Agent 并将文本输出分句放入队列""" | ||
| buffer = "" | ||
|
|
@@ -441,6 +498,7 @@ async def _run_agent_feeder( | |
| show_tool_call_result=show_tool_call_result, | ||
| stream_to_general=False, | ||
| show_reasoning=show_reasoning, | ||
| repeat_reply_guard_threshold=repeat_reply_guard_threshold, | ||
| ): | ||
| if chain is None: | ||
| continue | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -64,6 +64,17 @@ async def initialize(self, ctx: PipelineContext) -> None: | |||||||||||||||||||||||||||||||
| self.tool_schema_mode = "full" | ||||||||||||||||||||||||||||||||
| if isinstance(self.max_step, bool): # workaround: #2622 | ||||||||||||||||||||||||||||||||
| self.max_step = 30 | ||||||||||||||||||||||||||||||||
| self.repeat_reply_guard_threshold: int = settings.get( | ||||||||||||||||||||||||||||||||
| "repeat_reply_guard_threshold", 3 | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| if isinstance(self.repeat_reply_guard_threshold, bool): | ||||||||||||||||||||||||||||||||
| self.repeat_reply_guard_threshold = 3 | ||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||
| self.repeat_reply_guard_threshold = int(self.repeat_reply_guard_threshold) | ||||||||||||||||||||||||||||||||
| except (TypeError, ValueError): | ||||||||||||||||||||||||||||||||
| self.repeat_reply_guard_threshold = 3 | ||||||||||||||||||||||||||||||||
| if self.repeat_reply_guard_threshold < 0: | ||||||||||||||||||||||||||||||||
|
sourcery-ai[bot] marked this conversation as resolved.
Outdated
|
||||||||||||||||||||||||||||||||
| self.repeat_reply_guard_threshold = 0 | ||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这部分用于规范化
Suggested change
|
||||||||||||||||||||||||||||||||
| self.show_tool_use: bool = settings.get("show_tool_use_status", True) | ||||||||||||||||||||||||||||||||
| self.show_tool_call_result: bool = settings.get("show_tool_call_result", False) | ||||||||||||||||||||||||||||||||
| self.show_reasoning = settings.get("display_reasoning_text", False) | ||||||||||||||||||||||||||||||||
|
|
@@ -274,6 +285,7 @@ async def process( | |||||||||||||||||||||||||||||||
| self.show_tool_use, | ||||||||||||||||||||||||||||||||
| self.show_tool_call_result, | ||||||||||||||||||||||||||||||||
| show_reasoning=self.show_reasoning, | ||||||||||||||||||||||||||||||||
| repeat_reply_guard_threshold=self.repeat_reply_guard_threshold, | ||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
@@ -304,6 +316,7 @@ async def process( | |||||||||||||||||||||||||||||||
| self.show_tool_use, | ||||||||||||||||||||||||||||||||
| self.show_tool_call_result, | ||||||||||||||||||||||||||||||||
| show_reasoning=self.show_reasoning, | ||||||||||||||||||||||||||||||||
| repeat_reply_guard_threshold=self.repeat_reply_guard_threshold, | ||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
@@ -334,6 +347,7 @@ async def process( | |||||||||||||||||||||||||||||||
| self.show_tool_call_result, | ||||||||||||||||||||||||||||||||
| stream_to_general, | ||||||||||||||||||||||||||||||||
| show_reasoning=self.show_reasoning, | ||||||||||||||||||||||||||||||||
| repeat_reply_guard_threshold=self.repeat_reply_guard_threshold, | ||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||
| yield | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,148 @@ | ||
| from types import SimpleNamespace | ||
|
|
||
| import pytest | ||
|
|
||
| from astrbot.core.astr_agent_run_util import run_agent | ||
| from astrbot.core.message.message_event_result import MessageChain | ||
|
|
||
|
|
||
def _llm_result_response(text: str):
    """Build a fake ``llm_result`` agent response carrying *text* as its chain."""
    chain = MessageChain().message(text)
    return SimpleNamespace(type="llm_result", data={"chain": chain})
|
|
||
|
|
||
| class _DummyTrace: | ||
| def record(self, *args, **kwargs) -> None: | ||
| return None | ||
|
|
||
|
|
||
| class _DummyEvent: | ||
| def __init__(self) -> None: | ||
| self._extras: dict = {} | ||
| self._stopped = False | ||
| self.result_texts: list[str] = [] | ||
| self.trace = _DummyTrace() | ||
|
|
||
| def is_stopped(self) -> bool: | ||
| return self._stopped | ||
|
|
||
| def get_extra(self, key: str, default=None): | ||
| return self._extras.get(key, default) | ||
|
|
||
| def set_extra(self, key: str, value) -> None: | ||
| self._extras[key] = value | ||
|
|
||
| def set_result(self, result) -> None: | ||
| self.result_texts.append(result.get_plain_text(with_other_comps_mark=True)) | ||
|
|
||
| def clear_result(self) -> None: | ||
| return None | ||
|
|
||
| def get_platform_name(self) -> str: | ||
| return "slack" | ||
|
|
||
| def get_platform_id(self) -> str: | ||
| return "slack" | ||
|
|
||
| async def send(self, _msg_chain) -> None: | ||
| return None | ||
|
|
||
|
|
||
| class _FakeRunner: | ||
| def __init__(self, steps: list[list[SimpleNamespace]]) -> None: | ||
| self._steps = steps | ||
| self._step_idx = 0 | ||
| self._done = False | ||
| self.streaming = False | ||
| self.req = SimpleNamespace(func_tool=object()) | ||
| self.run_context = SimpleNamespace( | ||
| context=SimpleNamespace(event=_DummyEvent()), | ||
| messages=[], | ||
| ) | ||
| self.stats = SimpleNamespace(to_dict=lambda: {}) | ||
|
|
||
| def done(self) -> bool: | ||
| return self._done | ||
|
|
||
| def request_stop(self) -> None: | ||
| self.run_context.context.event.set_extra("agent_stop_requested", True) | ||
|
|
||
| def was_aborted(self) -> bool: | ||
| return False | ||
|
|
||
| async def step(self): | ||
| if self._step_idx >= len(self._steps): | ||
| self._done = True | ||
| return | ||
|
|
||
| current = self._steps[self._step_idx] | ||
| self._step_idx += 1 | ||
| for resp in current: | ||
| yield resp | ||
|
|
||
| if self._step_idx >= len(self._steps): | ||
| self._done = True | ||
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_repeat_reply_guard_forces_convergence():
    """Three identical replies trip the guard: tools stripped, nudge injected."""
    steps = [[_llm_result_response("重复输出")] for _ in range(3)]
    steps.append([_llm_result_response("最终答案")])
    runner = _FakeRunner(steps)

    agent_iter = run_agent(
        runner,
        max_step=8,
        show_tool_use=False,
        show_tool_call_result=False,
        repeat_reply_guard_threshold=3,
    )
    async for _chain in agent_iter:
        pass

    event = runner.run_context.context.event
    # The third duplicate is suppressed; convergence yields the final answer.
    assert event.result_texts == [
        "重复输出",
        "重复输出",
        "最终答案",
    ]
    # Guard must disable further tool calls.
    assert runner.req.func_tool is None
    # Guard must inject a user-role nudge message into the history.
    nudges = [
        msg
        for msg in runner.run_context.messages
        if msg.role == "user" and "检测到你连续多次输出相同回复" in str(msg.content)
    ]
    assert nudges
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_repeat_reply_guard_can_be_disabled_with_zero_threshold():
    """Threshold 0 disables the guard: every reply passes through untouched."""
    steps = [[_llm_result_response("重复输出")] for _ in range(3)]
    steps.append([_llm_result_response("最终答案")])
    runner = _FakeRunner(steps)
    tool_before = runner.req.func_tool

    async for _chain in run_agent(
        runner,
        max_step=8,
        show_tool_use=False,
        show_tool_call_result=False,
        repeat_reply_guard_threshold=0,
    ):
        pass

    event = runner.run_context.context.event
    # All duplicates are delivered and the tool set is left intact.
    assert event.result_texts == [
        "重复输出",
        "重复输出",
        "重复输出",
        "最终答案",
    ]
    assert runner.req.func_tool is tool_before
Uh oh!
There was an error while loading. Please reload this page.