Skip to content

Commit bb1b6a3

Browse files
committed
feat: implement tool call approval mechanism with dynamic code strategy
1 parent 28bfb3b commit bb1b6a3

7 files changed

Lines changed: 482 additions & 0 deletions

File tree

astrbot/core/agent/run_context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ class ContextWrapper(Generic[TContext]):
1717
messages: list[Message] = Field(default_factory=list)
1818
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
1919
tool_call_timeout: int = 60 # Default tool call timeout in seconds
20+
tool_call_approval: dict[str, Any] = Field(default_factory=dict)
21+
"""Tool call approval runtime configuration."""
2022

2123

2224
NoContext = ContextWrapper[None]

astrbot/core/agent/runners/tool_loop_agent_runner.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
3838
from ..response import AgentResponseData, AgentStats
3939
from ..run_context import ContextWrapper, TContext
40+
from ..tool_call_approval import (
41+
ToolCallApprovalContext,
42+
request_tool_call_approval,
43+
)
4044
from ..tool_executor import BaseFunctionToolExecutor
4145
from .base import AgentResponse, AgentState, BaseAgentRunner
4246

@@ -659,6 +663,41 @@ async def _handle_function_tools(
659663
# 如果没有 handler(如 MCP 工具),使用所有参数
660664
valid_params = func_tool_args
661665

666+
approval_cfg = self.run_context.tool_call_approval
667+
if approval_cfg.get("enable", False):
668+
event = getattr(self.run_context.context, "event", None)
669+
if event is None:
670+
tool_call_result_blocks.append(
671+
ToolCallMessageSegment(
672+
role="tool",
673+
tool_call_id=func_tool_id,
674+
content=(
675+
f"error: tool call approval is enabled, but event context is unavailable for `{func_tool_name}`."
676+
),
677+
),
678+
)
679+
continue
680+
approval_result = await request_tool_call_approval(
681+
config=approval_cfg,
682+
ctx=ToolCallApprovalContext(
683+
event=event,
684+
tool_name=func_tool_name,
685+
tool_args=valid_params,
686+
tool_call_id=func_tool_id,
687+
),
688+
)
689+
if not approval_result.approved:
690+
tool_call_result_blocks.append(
691+
ToolCallMessageSegment(
692+
role="tool",
693+
tool_call_id=func_tool_id,
694+
content=approval_result.to_tool_result_text(
695+
func_tool_name
696+
),
697+
),
698+
)
699+
continue
700+
662701
try:
663702
await self.agent_hooks.on_tool_start(
664703
self.run_context,
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
from __future__ import annotations
2+
3+
import secrets
4+
import string
5+
import typing as T
6+
from abc import ABC, abstractmethod
7+
from dataclasses import dataclass
8+
9+
from astrbot import logger
10+
from astrbot.core.message.message_event_result import MessageChain
11+
from astrbot.core.platform.astr_message_event import AstrMessageEvent
12+
from astrbot.core.utils.session_waiter import (
13+
FILTERS,
14+
DefaultSessionFilter,
15+
SessionController,
16+
SessionWaiter,
17+
)
18+
19+
ApprovalReason = T.Literal[
20+
"approved",
21+
"rejected",
22+
"timeout",
23+
"unsupported_strategy",
24+
"error",
25+
]
26+
27+
28+
@dataclass(slots=True)
29+
class ToolCallApprovalContext:
30+
event: AstrMessageEvent
31+
tool_name: str
32+
tool_args: dict[str, T.Any]
33+
tool_call_id: str
34+
35+
36+
@dataclass(slots=True)
37+
class ToolCallApprovalResult:
38+
approved: bool
39+
reason: ApprovalReason
40+
detail: str = ""
41+
42+
def to_tool_result_text(self, tool_name: str) -> str:
43+
if self.approved:
44+
return f"tool call approval passed: {tool_name}"
45+
if self.reason == "timeout":
46+
return (
47+
f"error: tool call approval timed out for `{tool_name}`. "
48+
"The tool call was cancelled."
49+
)
50+
if self.reason == "unsupported_strategy":
51+
return (
52+
f"error: tool call approval strategy is unsupported for `{tool_name}`. "
53+
"The tool call was cancelled."
54+
)
55+
if self.reason == "error":
56+
return (
57+
f"error: tool call approval failed for `{tool_name}` ({self.detail}). "
58+
"The tool call was cancelled."
59+
)
60+
return (
61+
f"error: user rejected tool call approval for `{tool_name}`. "
62+
"The tool call was cancelled."
63+
)
64+
65+
66+
class BaseToolCallApprovalStrategy(ABC):
67+
@property
68+
@abstractmethod
69+
def name(self) -> str: ...
70+
71+
@abstractmethod
72+
async def request(
73+
self,
74+
ctx: ToolCallApprovalContext,
75+
config: dict[str, T.Any],
76+
) -> ToolCallApprovalResult: ...
77+
78+
79+
class DynamicCodeApprovalStrategy(BaseToolCallApprovalStrategy):
80+
@property
81+
def name(self) -> str:
82+
return "dynamic_code"
83+
84+
async def request(
85+
self,
86+
ctx: ToolCallApprovalContext,
87+
config: dict[str, T.Any],
88+
) -> ToolCallApprovalResult:
89+
timeout_seconds = _safe_int(config.get("timeout", 60), default=60, minimum=1)
90+
dynamic_cfg = config.get("dynamic_code", {})
91+
if not isinstance(dynamic_cfg, dict):
92+
dynamic_cfg = {}
93+
code_length = _safe_int(dynamic_cfg.get("code_length", 6), default=6, minimum=4)
94+
case_sensitive = bool(dynamic_cfg.get("case_sensitive", False))
95+
96+
code = "".join(secrets.choice(string.digits) for _ in range(code_length))
97+
98+
await ctx.event.send(
99+
MessageChain().message(
100+
"Tool call needs your approval before execution.\n"
101+
f"Tool: `{ctx.tool_name}`\n"
102+
f"Approval code: `{code}`\n"
103+
"Please send this code to continue. "
104+
"Any other message will cancel this tool call."
105+
)
106+
)
107+
108+
try:
109+
result = await _wait_for_code_input(
110+
event=ctx.event,
111+
expected_code=code,
112+
timeout=timeout_seconds,
113+
case_sensitive=case_sensitive,
114+
)
115+
except Exception as exc: # noqa: BLE001
116+
logger.error(
117+
"Tool call approval failed unexpectedly for %s: %s",
118+
ctx.tool_name,
119+
exc,
120+
exc_info=True,
121+
)
122+
return ToolCallApprovalResult(
123+
approved=False,
124+
reason="error",
125+
detail=str(exc),
126+
)
127+
128+
if not result.approved:
129+
if result.reason == "timeout":
130+
await ctx.event.send(
131+
MessageChain().message(
132+
f"Tool call `{ctx.tool_name}` approval timed out. This call was cancelled."
133+
)
134+
)
135+
else:
136+
await ctx.event.send(
137+
MessageChain().message(
138+
f"Tool call `{ctx.tool_name}` was cancelled."
139+
)
140+
)
141+
return result
142+
143+
144+
_STRATEGY_REGISTRY: dict[str, BaseToolCallApprovalStrategy] = {}
145+
146+
147+
def register_tool_call_approval_strategy(
148+
strategy: BaseToolCallApprovalStrategy,
149+
) -> None:
150+
_STRATEGY_REGISTRY[strategy.name] = strategy
151+
152+
153+
def _register_builtin_strategies() -> None:
154+
register_tool_call_approval_strategy(DynamicCodeApprovalStrategy())
155+
156+
157+
_register_builtin_strategies()
158+
159+
160+
async def request_tool_call_approval(
161+
*,
162+
config: dict[str, T.Any] | None,
163+
ctx: ToolCallApprovalContext,
164+
) -> ToolCallApprovalResult:
165+
if not config or not bool(config.get("enable", False)):
166+
return ToolCallApprovalResult(approved=True, reason="approved")
167+
168+
strategy_name = (
169+
str(config.get("strategy", "dynamic_code")).strip() or "dynamic_code"
170+
)
171+
strategy = _STRATEGY_REGISTRY.get(strategy_name)
172+
if not strategy:
173+
logger.warning("Unsupported tool call approval strategy: %s", strategy_name)
174+
return ToolCallApprovalResult(
175+
approved=False,
176+
reason="unsupported_strategy",
177+
detail=strategy_name,
178+
)
179+
return await strategy.request(ctx, config)
180+
181+
182+
async def _wait_for_code_input(
183+
*,
184+
event: AstrMessageEvent,
185+
expected_code: str,
186+
timeout: int,
187+
case_sensitive: bool,
188+
) -> ToolCallApprovalResult:
189+
session_filter = DefaultSessionFilter()
190+
FILTERS.append(session_filter)
191+
waiter = SessionWaiter(
192+
session_filter=session_filter,
193+
session_id=event.unified_msg_origin,
194+
record_history_chains=False,
195+
)
196+
197+
async def _handler(
198+
controller: SessionController, incoming: AstrMessageEvent
199+
) -> None:
200+
raw_input = (incoming.message_str or "").strip()
201+
if _is_code_match(
202+
expected=expected_code,
203+
actual=raw_input,
204+
case_sensitive=case_sensitive,
205+
):
206+
if not controller.future.done():
207+
controller.future.set_result(
208+
ToolCallApprovalResult(approved=True, reason="approved"),
209+
)
210+
else:
211+
if not controller.future.done():
212+
controller.future.set_result(
213+
ToolCallApprovalResult(
214+
approved=False,
215+
reason="rejected",
216+
detail=raw_input,
217+
)
218+
)
219+
controller.stop()
220+
221+
try:
222+
result = await waiter.register_wait(handler=_handler, timeout=timeout)
223+
except TimeoutError:
224+
return ToolCallApprovalResult(approved=False, reason="timeout")
225+
226+
if isinstance(result, ToolCallApprovalResult):
227+
return result
228+
return ToolCallApprovalResult(
229+
approved=False,
230+
reason="error",
231+
detail=f"Invalid approval result type: {type(result).__name__}",
232+
)
233+
234+
235+
def _is_code_match(*, expected: str, actual: str, case_sensitive: bool) -> bool:
236+
if case_sensitive:
237+
return actual == expected
238+
return actual.casefold() == expected.casefold()
239+
240+
241+
def _safe_int(value: T.Any, *, default: int, minimum: int) -> int:
242+
try:
243+
parsed = int(value)
244+
if parsed < minimum:
245+
return minimum
246+
return parsed
247+
except Exception: # noqa: BLE001
248+
return default

astrbot/core/astr_main_agent.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ class MainAgentBuildConfig:
121121
timezone: str | None = None
122122
max_quoted_fallback_images: int = 20
123123
"""Maximum number of images injected from quoted-message fallback extraction."""
124+
tool_call_approval: dict = field(default_factory=dict)
125+
"""Tool call approval configuration."""
124126

125127

126128
@dataclass(slots=True)
@@ -1118,6 +1120,7 @@ async def build_main_agent(
11181120
run_context=AgentContextWrapper(
11191121
context=astr_agent_ctx,
11201122
tool_call_timeout=config.tool_call_timeout,
1123+
tool_call_approval=config.tool_call_approval,
11211124
),
11221125
tool_executor=FunctionToolExecutor(),
11231126
agent_hooks=MAIN_AGENT_HOOKS,

0 commit comments

Comments
 (0)