Skip to content

Commit a058bef

Browse files
committed
fix: harden MCP init state handling and timeout parsing
1 parent a16d05c commit a058bef

1 file changed

Lines changed: 131 additions & 60 deletions

File tree

astrbot/core/provider/func_tool_manager.py

Lines changed: 131 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -19,9 +19,36 @@
1919

2020
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
2121

22-
# MCP 服务初始化的默认超时时间(秒)。
23-
# 可通过环境变量 ASTRBOT_MCP_INIT_TIMEOUT 覆盖,适配慢速或远程 MCP 服务器。
24-
MCP_INIT_TIMEOUT: float = float(os.getenv("ASTRBOT_MCP_INIT_TIMEOUT", "20.0"))
22+
DEFAULT_MCP_INIT_TIMEOUT_SECONDS = 20.0
23+
DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS = 30.0
24+
MCP_INIT_TIMEOUT_ENV = "ASTRBOT_MCP_INIT_TIMEOUT"
25+
ENABLE_MCP_TIMEOUT_ENV = "ASTRBOT_MCP_ENABLE_TIMEOUT"
26+
27+
28+
def _resolve_timeout(
29+
timeout: float | int | str | None = None,
30+
*,
31+
env_name: str = MCP_INIT_TIMEOUT_ENV,
32+
default: float = DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
33+
) -> float:
34+
if timeout is None:
35+
timeout = os.getenv(env_name, str(default))
36+
37+
try:
38+
timeout_value = float(timeout)
39+
except (TypeError, ValueError):
40+
logger.warning(
41+
f"超时配置 {env_name}={timeout!r} 无效,使用默认值 {default:g} 秒。"
42+
)
43+
return default
44+
45+
if timeout_value <= 0:
46+
logger.warning(
47+
f"超时配置 {env_name}={timeout_value:g} 必须大于 0,使用默认值 {default:g} 秒。"
48+
)
49+
return default
50+
51+
return timeout_value
2552

2653

2754
@dataclass
@@ -31,7 +58,7 @@ class _McpClientInfo:
3158
name: str
3259
cfg: dict
3360
shutdown_event: asyncio.Event = field(default_factory=asyncio.Event)
34-
ready_event: asyncio.Event = field(default_factory=asyncio.Event)
61+
init_future: asyncio.Future[bool] | None = None
3562
task: asyncio.Task | None = None
3663

3764

@@ -197,6 +224,31 @@ def get_full_tool_set(self) -> ToolSet:
197224
tool_set = ToolSet(self.func_list.copy())
198225
return tool_set
199226

227+
@staticmethod
def _log_safe_mcp_debug_config(cfg: dict) -> None:
    """Log a redacted, debug-level summary of one MCP server config.

    Only a summary is logged so that sensitive values embedded in
    ``command`` / ``args`` / ``url`` are not written to the log:

    * ``command`` configs: the executable and the number of arguments.
    * ``url`` configs: scheme, hostname and port only.

    This helper is best-effort and does not raise on a malformed URL.

    Args:
        cfg: The MCP server configuration dict. Configs with neither a
            ``command`` nor a ``url`` key produce no output.
    """
    if "command" in cfg:
        cmd = cfg["command"]
        # A list/tuple command contributes only its first element.
        executable = str(cmd[0] if isinstance(cmd, (list, tuple)) and cmd else cmd)
        args_val = cfg.get("args", [])
        args_count = (
            len(args_val)
            if isinstance(args_val, (list, tuple))
            else (0 if args_val is None else 1)
        )
        logger.debug(f" 命令可执行文件: {executable}, 参数数量: {args_count}")
        return

    if "url" in cfg:
        try:
            # urlparse raises ValueError on e.g. unmatched "[" in the netloc;
            # a debug helper must not crash the caller's failure reporting.
            parsed = urllib.parse.urlparse(str(cfg["url"]))
            host = parsed.hostname or ""
        except ValueError:
            logger.debug(" 主机: <无法解析的 URL>")
            return

        scheme = parsed.scheme or "unknown"
        try:
            # Accessing .port raises ValueError for a non-numeric or
            # out-of-range port; fall back to omitting it.
            port = f":{parsed.port}" if parsed.port else ""
        except ValueError:
            port = ""
        logger.debug(f" 主机: {scheme}://{host}{port}")
251+
200252
async def init_mcp_clients(self) -> None:
201253
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
202254
```
@@ -230,15 +282,25 @@ async def init_mcp_clients(self) -> None:
230282
open(mcp_json_file, encoding="utf-8"),
231283
)["mcpServers"]
232284

285+
init_timeout = _resolve_timeout(
286+
env_name=MCP_INIT_TIMEOUT_ENV,
287+
default=DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
288+
)
289+
timeout_display = f"{init_timeout:g}"
290+
233291
# 用 _McpClientInfo 跟踪每个客户端的事件和任务
234292
client_infos: dict[str, _McpClientInfo] = {}
235293

236294
for name, cfg in mcp_server_json_obj.items():
237295
if cfg.get("active", True):
238296
info = _McpClientInfo(name=name, cfg=cfg)
297+
info.init_future = asyncio.get_running_loop().create_future()
239298
info.task = asyncio.create_task(
240299
self._init_mcp_client_task_wrapper(
241-
name, cfg, info.shutdown_event, info.ready_event
300+
name,
301+
cfg,
302+
info.shutdown_event,
303+
info.init_future,
242304
),
243305
)
244306
client_infos[name] = info
@@ -247,59 +309,52 @@ async def init_mcp_clients(self) -> None:
247309
if client_infos:
248310
logger.info(f"等待 {len(client_infos)} 个 MCP 服务初始化...")
249311

250-
# 只等待初始化完成信号,不等待整个 task 的生命周期结束
251-
wait_tasks = [
252-
asyncio.create_task(info.ready_event.wait())
312+
init_futures = [
313+
info.init_future
253314
for info in client_infos.values()
315+
if info.init_future is not None
254316
]
255-
try:
256-
await asyncio.wait_for(
257-
asyncio.gather(*wait_tasks, return_exceptions=True),
258-
timeout=MCP_INIT_TIMEOUT,
317+
_, pending_futures = await asyncio.wait(
318+
init_futures,
319+
timeout=init_timeout,
320+
)
321+
322+
if pending_futures:
323+
logger.warning(
324+
f"MCP 服务初始化超时({timeout_display}秒),部分服务可能未完全加载。"
325+
"建议检查 MCP 服务器配置和网络连接。"
259326
)
260-
except asyncio.TimeoutError:
261-
# 取消尚未完成的 wait_tasks,避免悬挂
262-
for t in wait_tasks:
263-
t.cancel()
264-
await asyncio.gather(*wait_tasks, return_exceptions=True)
265327

266328
success_count = 0
267329
failed_services: list[str] = []
268330
cancelled_tasks: list[asyncio.Task] = []
269331

270332
for info in client_infos.values():
271-
if not info.ready_event.is_set():
333+
if info.init_future in pending_futures:
272334
# 超时,初始化未完成,取消 task
273335
logger.error(f"MCP 服务 {info.name} 初始化超时")
274-
info.task.cancel()
275-
cancelled_tasks.append(info.task)
336+
if info.task is not None:
337+
info.task.cancel()
338+
cancelled_tasks.append(info.task)
339+
failed_services.append(info.name)
340+
self.mcp_client_event.pop(info.name, None)
341+
elif info.init_future is None:
342+
logger.error(f"MCP 服务 {info.name} 初始化状态异常")
276343
failed_services.append(info.name)
277-
elif info.task.done() and info.task.exception() is not None:
344+
self.mcp_client_event.pop(info.name, None)
345+
elif info.init_future.cancelled():
346+
logger.error(f"MCP 服务 {info.name} 初始化已取消")
347+
failed_services.append(info.name)
348+
self.mcp_client_event.pop(info.name, None)
349+
elif info.init_future.exception() is not None:
278350
# 初始化期间抛出异常(已在 wrapper 中记录,此处只记录配置摘要)
279-
exc = info.task.exception()
351+
exc = info.init_future.exception()
280352
logger.error(f"MCP 服务 {info.name} 初始化失败: {exc}")
281-
cfg = info.cfg
282-
if "command" in cfg:
283-
cmd = cfg["command"]
284-
executable = str(
285-
cmd[0] if isinstance(cmd, (list, tuple)) and cmd else cmd
286-
)
287-
args_val = cfg.get("args", [])
288-
args_count = (
289-
len(args_val)
290-
if isinstance(args_val, (list, tuple))
291-
else (0 if args_val is None else 1)
292-
)
293-
logger.debug(
294-
f" 命令可执行文件: {executable}, 参数数量: {args_count}"
295-
)
296-
elif "url" in cfg:
297-
parsed = urllib.parse.urlparse(cfg["url"])
298-
# 只记录 scheme + hostname + port,不记录 userinfo 和查询参数
299-
host = parsed.hostname or ""
300-
port = f":{parsed.port}" if parsed.port else ""
301-
logger.debug(f" 主机: {parsed.scheme}://{host}{port}")
353+
if info.task is not None and info.task.done():
354+
info.task.exception()
355+
self._log_safe_mcp_debug_config(info.cfg)
302356
failed_services.append(info.name)
357+
self.mcp_client_event.pop(info.name, None)
303358
else:
304359
success_count += 1
305360

@@ -320,25 +375,29 @@ async def _init_mcp_client_task_wrapper(
320375
name: str,
321376
cfg: dict,
322377
shutdown_event: asyncio.Event,
323-
ready_event: asyncio.Event,
378+
init_future: asyncio.Future[bool] | None = None,
324379
) -> None:
325-
"""初始化 MCP 客户端的包装函数。
326-
327-
初始化完成后立即 set ready_event,让 init_mcp_clients 可以
328-
及时返回,而无需等待整个客户端的生命周期结束。
329-
"""
380+
"""初始化 MCP 客户端的包装函数。"""
381+
initialized = False
330382
try:
331383
await self._init_mcp_client(name, cfg)
332-
ready_event.set()
384+
initialized = True
385+
if init_future and not init_future.done():
386+
init_future.set_result(True)
333387
await shutdown_event.wait()
334388
logger.info(f"收到 MCP 客户端 {name} 终止信号")
335389
except asyncio.CancelledError:
336-
ready_event.set() # 确保取消时也能解除等待
337-
raise
338-
except Exception:
339-
ready_event.set() # 确保初始化失败时也能解除等待
340-
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
390+
if not initialized and init_future and not init_future.done():
391+
init_future.set_exception(
392+
asyncio.TimeoutError(f"MCP 客户端 {name} 初始化超时"),
393+
)
341394
raise
395+
except Exception as e:
396+
if init_future and not init_future.done():
397+
init_future.set_exception(e)
398+
if not initialized:
399+
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
400+
raise
342401
finally:
343402
await self._terminate_mcp_client(name)
344403

@@ -417,7 +476,7 @@ async def enable_mcp_server(
417476
name: str,
418477
config: dict,
419478
shutdown_event: asyncio.Event | None = None,
420-
timeout: float = MCP_INIT_TIMEOUT,
479+
timeout: float | int | str | None = None,
421480
) -> None:
422481
"""Enable a new MCP server and initialize it.
423482
@@ -436,19 +495,31 @@ async def enable_mcp_server(
436495
if not shutdown_event:
437496
shutdown_event = asyncio.Event()
438497

439-
ready_event = asyncio.Event()
498+
timeout_value = _resolve_timeout(
499+
timeout=timeout,
500+
env_name=ENABLE_MCP_TIMEOUT_ENV,
501+
default=DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS,
502+
)
503+
init_future = asyncio.get_running_loop().create_future()
440504

441505
init_task = asyncio.create_task(
442506
self._init_mcp_client_task_wrapper(
443-
name, config, shutdown_event, ready_event
507+
name,
508+
config,
509+
shutdown_event,
510+
init_future,
444511
),
445512
)
446513
try:
447-
await asyncio.wait_for(ready_event.wait(), timeout=timeout)
514+
await asyncio.wait_for(init_future, timeout=timeout_value)
448515
except asyncio.TimeoutError:
449516
init_task.cancel()
450517
await asyncio.gather(init_task, return_exceptions=True)
451-
raise TimeoutError(f"MCP 服务 {name} 初始化超时({timeout} 秒)")
518+
raise TimeoutError(f"MCP 服务 {name} 初始化超时({timeout_value:g} 秒)")
519+
except Exception:
520+
# 消费 task 异常,避免未观察到的后台任务异常告警
521+
await asyncio.gather(init_task, return_exceptions=True)
522+
raise
452523

453524
# 如果初始化期间 task 已结束并带有异常,向上抛出
454525
if init_task.done() and init_task.exception() is not None:

0 commit comments

Comments
 (0)