Skip to content

Commit 09e6174

Browse files
committed
refactor: streamline mcp lifecycle and init errors
1 parent a8e0b9d commit 09e6174

1 file changed

Lines changed: 46 additions & 58 deletions

File tree

astrbot/core/provider/func_tool_manager.py

Lines changed: 46 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,18 @@
2626
MAX_MCP_TIMEOUT_SECONDS = 300.0
2727

2828

29-
class MCPInitTimeoutError(asyncio.TimeoutError):
29+
class MCPInitError(Exception):
30+
"""Base exception for MCP initialization failures."""
31+
32+
33+
class MCPInitTimeoutError(asyncio.TimeoutError, MCPInitError):
3034
"""Raised when MCP client initialization exceeds the configured timeout."""
3135

3236

37+
class MCPAllServicesFailedError(MCPInitError):
38+
"""Raised when all configured MCP services fail to initialize."""
39+
40+
3341
@dataclass
3442
class MCPInitSummary:
3543
total: int
@@ -80,6 +88,20 @@ def _resolve_timeout(
8088
return timeout_value
8189

8290

91+
def _resolve_mcp_timeout(
92+
*,
93+
timeout: float | int | str | None = None,
94+
init_phase: bool,
95+
) -> float:
96+
env_name = MCP_INIT_TIMEOUT_ENV if init_phase else ENABLE_MCP_TIMEOUT_ENV
97+
default = (
98+
DEFAULT_MCP_INIT_TIMEOUT_SECONDS
99+
if init_phase
100+
else DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS
101+
)
102+
return _resolve_timeout(timeout=timeout, env_name=env_name, default=default)
103+
104+
83105
SUPPORTED_TYPES = [
84106
"string",
85107
"number",
@@ -300,10 +322,7 @@ async def init_mcp_clients(self) -> MCPInitSummary:
300322
open(mcp_json_file, encoding="utf-8"),
301323
)["mcpServers"]
302324

303-
init_timeout = _resolve_timeout(
304-
env_name=MCP_INIT_TIMEOUT_ENV,
305-
default=DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
306-
)
325+
init_timeout = _resolve_mcp_timeout(init_phase=True)
307326
timeout_display = f"{init_timeout:g}"
308327

309328
active_configs: list[tuple[str, dict, asyncio.Event]] = []
@@ -358,7 +377,7 @@ async def init_mcp_clients(self) -> MCPInitSummary:
358377
)
359378
logger.info(f"MCP 服务初始化完成: {summary.success}/{summary.total} 成功")
360379
if summary.total > 0 and summary.success == 0:
361-
raise RuntimeError(
380+
raise MCPAllServicesFailedError(
362381
"全部 MCP 服务初始化失败,请检查 mcp_server.json 配置和服务器可用性。"
363382
)
364383
return summary
@@ -382,14 +401,21 @@ async def _run_mcp_client(
382401
if runtime and runtime.lifecycle_task is current_task:
383402
self.mcp_server_runtime.pop(name, None)
384403

385-
async def _start_mcp_client_with_timeout(
404+
async def _start_mcp_server(
386405
self,
387406
name: str,
388407
cfg: dict,
389-
shutdown_event: asyncio.Event,
408+
*,
409+
shutdown_event: asyncio.Event | None = None,
390410
timeout: float,
391-
) -> asyncio.Task[None]:
392-
"""启动 MCP 客户端:先初始化,成功后再启动长生命周期任务。"""
411+
) -> None:
412+
"""Initialize MCP server with timeout and register task/event together."""
413+
if name in self.mcp_server_runtime:
414+
return
415+
416+
if shutdown_event is None:
417+
shutdown_event = asyncio.Event()
418+
393419
try:
394420
await asyncio.wait_for(
395421
self._init_mcp_client(name, cfg),
@@ -405,32 +431,10 @@ async def _start_mcp_client_with_timeout(
405431
await self._terminate_mcp_client(name)
406432
raise
407433

408-
return asyncio.create_task(
434+
lifecycle_task = asyncio.create_task(
409435
self._run_mcp_client(name, shutdown_event),
410436
name=f"mcp-client:{name}",
411437
)
412-
413-
async def _start_mcp_server(
414-
self,
415-
name: str,
416-
cfg: dict,
417-
*,
418-
shutdown_event: asyncio.Event | None = None,
419-
timeout: float,
420-
) -> None:
421-
"""Initialize MCP server with timeout and register task/event together."""
422-
if name in self.mcp_client_dict:
423-
return
424-
425-
if shutdown_event is None:
426-
shutdown_event = asyncio.Event()
427-
428-
lifecycle_task = await self._start_mcp_client_with_timeout(
429-
name=name,
430-
cfg=cfg,
431-
shutdown_event=shutdown_event,
432-
timeout=timeout,
433-
)
434438
self.mcp_server_runtime[name] = _MCPServerRuntime(
435439
shutdown_event=shutdown_event,
436440
lifecycle_task=lifecycle_task,
@@ -537,14 +541,10 @@ async def enable_mcp_server(
537541
MCPInitTimeoutError: If initialization does not complete within timeout.
538542
Exception: If there is an error during initialization.
539543
"""
540-
if name in self.mcp_client_dict:
544+
if name in self.mcp_server_runtime:
541545
logger.info(f"MCP 服务 {name} 已存在,跳过重复启用。")
542546
return
543-
timeout_value = _resolve_timeout(
544-
timeout=timeout,
545-
env_name=ENABLE_MCP_TIMEOUT_ENV,
546-
default=DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS,
547-
)
547+
timeout_value = _resolve_mcp_timeout(timeout=timeout, init_phase=False)
548548
await self._start_mcp_server(
549549
name=name,
550550
cfg=config,
@@ -566,19 +566,12 @@ async def disable_mcp_server(
566566
"""
567567
if name:
568568
runtime = self.mcp_server_runtime.get(name)
569-
if runtime is None and name not in self.mcp_client_dict:
569+
if runtime is None:
570570
return
571571

572572
try:
573-
if runtime is not None:
574-
runtime.shutdown_event.set()
575-
await self._wait_mcp_lifecycle_task(runtime.lifecycle_task, timeout)
576-
else:
577-
client = self.mcp_client_dict.get(name)
578-
if client is not None:
579-
await asyncio.wait_for(
580-
client.running_event.wait(), timeout=timeout
581-
)
573+
runtime.shutdown_event.set()
574+
await self._wait_mcp_lifecycle_task(runtime.lifecycle_task, timeout)
582575
finally:
583576
self.mcp_server_runtime.pop(name, None)
584577
self.func_list = [
@@ -587,22 +580,17 @@ async def disable_mcp_server(
587580
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
588581
]
589582
else:
590-
for runtime in self.mcp_server_runtime.values():
583+
runtimes = list(self.mcp_server_runtime.values())
584+
for runtime in runtimes:
591585
runtime.shutdown_event.set()
592586

593587
lifecycle_tasks: list[asyncio.Task[None]] = [
594-
runtime.lifecycle_task for runtime in self.mcp_server_runtime.values()
595-
]
596-
597-
running_events = [
598-
client.running_event.wait()
599-
for client_name, client in self.mcp_client_dict.items()
600-
if client_name not in self.mcp_server_runtime
588+
runtime.lifecycle_task for runtime in runtimes
601589
]
602590
# waiting for all clients to finish
603591
try:
604592
await asyncio.wait_for(
605-
asyncio.gather(*running_events, *lifecycle_tasks),
593+
asyncio.gather(*lifecycle_tasks),
606594
timeout=timeout,
607595
)
608596
except asyncio.TimeoutError:

0 commit comments

Comments
 (0)