Skip to content

Commit 072a75d

Browse files
committed
refactor: simplify MCP init lifecycle orchestration
1 parent 7b40501 commit 072a75d

1 file changed

Lines changed: 77 additions & 111 deletions

File tree

astrbot/core/provider/func_tool_manager.py

Lines changed: 77 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ class _McpClientInfo:
5858
name: str
5959
cfg: dict
6060
shutdown_event: asyncio.Event = field(default_factory=asyncio.Event)
61-
init_future: asyncio.Future[bool] | None = None
6261
task: asyncio.Task | None = None
6362

6463

@@ -294,115 +293,100 @@ async def init_mcp_clients(self) -> None:
294293
for name, cfg in mcp_server_json_obj.items():
295294
if cfg.get("active", True):
296295
info = _McpClientInfo(name=name, cfg=cfg)
297-
info.init_future = asyncio.get_running_loop().create_future()
298-
info.task = asyncio.create_task(
299-
self._init_mcp_client_task_wrapper(
300-
name,
301-
cfg,
302-
info.shutdown_event,
303-
info.init_future,
304-
),
305-
)
306296
client_infos[name] = info
307-
self.mcp_client_event[name] = info.shutdown_event
308297

309-
if client_infos:
310-
logger.info(f"等待 {len(client_infos)} 个 MCP 服务初始化...")
298+
if not client_infos:
299+
return
311300

312-
init_futures = [
313-
info.init_future
314-
for info in client_infos.values()
315-
if info.init_future is not None
316-
]
317-
_, pending_futures = await asyncio.wait(
318-
init_futures,
319-
timeout=init_timeout,
301+
logger.info(f"等待 {len(client_infos)} 个 MCP 服务初始化...")
302+
303+
init_tasks = [
304+
asyncio.create_task(
305+
self._start_mcp_client_with_timeout(
306+
name=info.name,
307+
cfg=info.cfg,
308+
shutdown_event=info.shutdown_event,
309+
timeout=init_timeout,
310+
),
311+
name=f"mcp-init:{info.name}",
320312
)
313+
for info in client_infos.values()
314+
]
315+
results = await asyncio.gather(*init_tasks, return_exceptions=True)
321316

322-
if pending_futures:
323-
logger.warning(
324-
f"MCP 服务初始化超时({timeout_display}秒),部分服务可能未完全加载。"
325-
"建议检查 MCP 服务器配置和网络连接。"
326-
)
317+
success_count = 0
318+
failed_services: list[str] = []
327319

328-
success_count = 0
329-
failed_services: list[str] = []
330-
cancelled_tasks: list[asyncio.Task] = []
331-
332-
for info in client_infos.values():
333-
if info.init_future in pending_futures:
334-
# 超时,初始化未完成,取消 task
335-
logger.error(f"MCP 服务 {info.name} 初始化超时")
336-
self._log_safe_mcp_debug_config(info.cfg)
337-
if info.task is not None:
338-
info.task.cancel()
339-
cancelled_tasks.append(info.task)
340-
failed_services.append(info.name)
341-
self.mcp_client_event.pop(info.name, None)
342-
elif info.init_future is None:
343-
logger.error(f"MCP 服务 {info.name} 初始化状态异常")
344-
failed_services.append(info.name)
345-
self.mcp_client_event.pop(info.name, None)
346-
elif info.init_future.cancelled():
347-
logger.error(f"MCP 服务 {info.name} 初始化已取消")
348-
failed_services.append(info.name)
349-
self.mcp_client_event.pop(info.name, None)
350-
elif info.init_future.exception() is not None:
351-
# 初始化期间抛出异常(已在 wrapper 中记录,此处只记录配置摘要)
352-
exc = info.init_future.exception()
353-
logger.error(f"MCP 服务 {info.name} 初始化失败: {exc}")
354-
if info.task is not None and info.task.done():
355-
info.task.exception()
356-
self._log_safe_mcp_debug_config(info.cfg)
357-
failed_services.append(info.name)
358-
self.mcp_client_event.pop(info.name, None)
320+
for info, result in zip(client_infos.values(), results, strict=False):
321+
if isinstance(result, Exception):
322+
if isinstance(result, TimeoutError):
323+
logger.error(
324+
f"MCP 服务 {info.name} 初始化超时({timeout_display}秒)"
325+
)
359326
else:
360-
success_count += 1
361-
362-
# 等待已取消的任务真正结束,避免残留后台任务和未观察到的异常
363-
if cancelled_tasks:
364-
await asyncio.gather(*cancelled_tasks, return_exceptions=True)
327+
logger.error(f"MCP 服务 {info.name} 初始化失败: {result}")
328+
self._log_safe_mcp_debug_config(info.cfg)
329+
failed_services.append(info.name)
330+
self.mcp_client_event.pop(info.name, None)
331+
continue
332+
333+
info.task = result
334+
self.mcp_client_event[info.name] = info.shutdown_event
335+
success_count += 1
336+
337+
if failed_services:
338+
logger.warning(
339+
f"以下 MCP 服务初始化失败: {', '.join(failed_services)}。"
340+
f"请检查配置文件 mcp_server.json 和服务器可用性。"
341+
)
365342

366-
if failed_services:
367-
logger.warning(
368-
f"以下 MCP 服务初始化失败: {', '.join(failed_services)}。"
369-
f"请检查配置文件 mcp_server.json 和服务器可用性。"
370-
)
343+
logger.info(f"MCP 服务初始化完成: {success_count}/{len(client_infos)} 成功")
371344

372-
logger.info(f"MCP 服务初始化完成: {success_count}/{len(client_infos)} 成功")
345+
async def _init_mcp_client_once(self, name: str, cfg: dict) -> None:
    """Run the MCP client initialization flow a single time.

    This is a thin delegate to ``_init_mcp_client``: it makes exactly one
    call and adds no error handling of its own, so any exception raised
    during initialization propagates unchanged to the caller.

    Args:
        name: Name of the MCP server being initialized.
        cfg: Configuration mapping for that server.
    """
    await self._init_mcp_client(name, cfg)
373348

374-
async def _init_mcp_client_task_wrapper(
349+
async def _run_mcp_client(
375350
self,
376351
name: str,
377-
cfg: dict,
378352
shutdown_event: asyncio.Event,
379-
init_future: asyncio.Future[bool] | None = None,
380353
) -> None:
381-
"""初始化 MCP 客户端的包装函数。"""
382-
initialized = False
354+
"""MCP 长生命周期任务:等待关闭信号并执行清理。"""
383355
try:
384-
await self._init_mcp_client(name, cfg)
385-
initialized = True
386-
if init_future and not init_future.done():
387-
init_future.set_result(True)
388356
await shutdown_event.wait()
389357
logger.info(f"收到 MCP 客户端 {name} 终止信号")
390358
except asyncio.CancelledError:
391-
if not initialized and init_future and not init_future.done():
392-
init_future.set_exception(
393-
asyncio.TimeoutError(f"MCP 客户端 {name} 初始化超时"),
394-
)
359+
logger.debug(f"MCP 客户端 {name} 任务被取消")
395360
raise
396-
except Exception as e:
397-
if init_future and not init_future.done():
398-
init_future.set_exception(e)
399-
if not initialized:
400-
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
401-
raise
402-
logger.error(f"MCP 客户端 {name} 初始化后运行异常: {e}", exc_info=True)
403361
finally:
404362
await self._terminate_mcp_client(name)
405363

364+
async def _start_mcp_client_with_timeout(
    self,
    name: str,
    cfg: dict,
    shutdown_event: asyncio.Event,
    timeout: float,
) -> asyncio.Task:
    """Initialize an MCP client, then start its long-lived lifecycle task.

    Initialization is bounded by ``timeout``. Only after it succeeds is the
    lifecycle task (``_run_mcp_client``) created. On every failure path,
    including cancellation, ``_terminate_mcp_client`` is awaited before the
    error propagates so a partially initialized client is not left behind.

    Args:
        name: Name of the MCP server.
        cfg: Configuration mapping passed to the initialization routine.
        shutdown_event: Event the lifecycle task waits on before cleaning up.
        timeout: Maximum number of seconds to wait for initialization.

    Returns:
        The task running ``_run_mcp_client``. The event loop only keeps weak
        references to tasks, so the caller should keep a reference to it.

    Raises:
        TimeoutError: Initialization did not finish within ``timeout``.
        asyncio.CancelledError: The calling task was cancelled during
            initialization.
        Exception: Any other error raised by initialization is re-raised
            unchanged after being logged.
    """
    try:
        await asyncio.wait_for(
            self._init_mcp_client_once(name, cfg),
            timeout=timeout,
        )
    except asyncio.TimeoutError as exc:
        await self._terminate_mcp_client(name)
        # Re-raise as the builtin TimeoutError with a service-specific
        # message, keeping the original exception as the explicit cause.
        raise TimeoutError(f"MCP 服务 {name} 初始化超时({timeout:g} 秒)") from exc
    except asyncio.CancelledError:
        # CancelledError derives from BaseException, so the generic handler
        # below never sees it; clean up here before propagating the cancel.
        await self._terminate_mcp_client(name)
        raise
    except Exception:
        logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
        await self._terminate_mcp_client(name)
        # Bare raise keeps the original exception and traceback intact.
        raise

    return asyncio.create_task(
        self._run_mcp_client(name, shutdown_event),
        name=f"mcp-client:{name}",
    )
389+
406390
async def _init_mcp_client(self, name: str, config: dict) -> None:
407391
"""初始化单个MCP客户端"""
408392
# 先清理之前的客户端,如果存在
@@ -502,30 +486,12 @@ async def enable_mcp_server(
502486
env_name=ENABLE_MCP_TIMEOUT_ENV,
503487
default=DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS,
504488
)
505-
init_future = asyncio.get_running_loop().create_future()
506-
507-
init_task = asyncio.create_task(
508-
self._init_mcp_client_task_wrapper(
509-
name,
510-
config,
511-
shutdown_event,
512-
init_future,
513-
),
489+
await self._start_mcp_client_with_timeout(
490+
name=name,
491+
cfg=config,
492+
shutdown_event=shutdown_event,
493+
timeout=timeout_value,
514494
)
515-
try:
516-
await asyncio.wait_for(init_future, timeout=timeout_value)
517-
except asyncio.TimeoutError:
518-
init_task.cancel()
519-
await asyncio.gather(init_task, return_exceptions=True)
520-
raise TimeoutError(f"MCP 服务 {name} 初始化超时({timeout_value:g} 秒)")
521-
except Exception:
522-
# 消费 task 异常,避免未观察到的后台任务异常告警
523-
await asyncio.gather(init_task, return_exceptions=True)
524-
raise
525-
526-
# 如果初始化期间 task 已结束并带有异常,向上抛出
527-
if init_task.done() and init_task.exception() is not None:
528-
raise init_task.exception()
529495

530496
# 初始化成功后再注册,避免失败时暴露无效的 event
531497
self.mcp_client_event[name] = shutdown_event

0 commit comments

Comments
 (0)