Skip to content

Commit 88274d9

Browse files
committed
refactor: simplify mcp runtime registry and timeout flow
1 parent 807fb4a commit 88274d9

1 file changed

Lines changed: 74 additions & 72 deletions

File tree

astrbot/core/provider/func_tool_manager.py

Lines changed: 74 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import urllib.parse
88
from collections.abc import AsyncGenerator, Awaitable, Callable
9+
from dataclasses import dataclass
910
from typing import Any
1011

1112
import aiohttp
@@ -29,12 +30,26 @@ class MCPInitTimeoutError(asyncio.TimeoutError):
2930
"""Raised when MCP client initialization exceeds the configured timeout."""
3031

3132

33+
@dataclass
34+
class MCPInitSummary:
35+
total: int
36+
success: int
37+
failed: list[str]
38+
39+
40+
@dataclass
41+
class _MCPServerRuntime:
42+
shutdown_event: asyncio.Event
43+
lifecycle_task: asyncio.Task[None]
44+
45+
3246
def _resolve_timeout(
3347
timeout: float | int | str | None = None,
3448
*,
3549
env_name: str = MCP_INIT_TIMEOUT_ENV,
3650
default: float = DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
3751
) -> float:
52+
"""Resolve timeout with precedence: explicit argument > env value > default."""
3853
source = f"环境变量 {env_name}"
3954
if timeout is None:
4055
timeout = os.getenv(env_name, str(default))
@@ -156,8 +171,7 @@ def __init__(self) -> None:
156171
self.func_list: list[FuncTool] = []
157172
self.mcp_client_dict: dict[str, MCPClient] = {}
158173
"""MCP 服务列表"""
159-
self.mcp_client_event: dict[str, asyncio.Event] = {}
160-
self.mcp_client_task: dict[str, asyncio.Task[None]] = {}
174+
self.mcp_server_runtime: dict[str, _MCPServerRuntime] = {}
161175

162176
def empty(self) -> bool:
163177
return len(self.func_list) == 0
@@ -253,7 +267,7 @@ def _log_safe_mcp_debug_config(cfg: dict) -> None:
253267
port = ""
254268
logger.debug(f" 主机: {scheme}://{host}{port}")
255269

256-
async def init_mcp_clients(self) -> None:
270+
async def init_mcp_clients(self) -> MCPInitSummary:
257271
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
258272
```
259273
{
@@ -299,7 +313,7 @@ async def init_mcp_clients(self) -> None:
299313
active_configs.append((name, cfg, shutdown_event))
300314

301315
if not active_configs:
302-
return
316+
return MCPInitSummary(total=0, success=0, failed=[])
303317

304318
logger.info(f"等待 {len(active_configs)} 个 MCP 服务初始化...")
305319

@@ -308,8 +322,6 @@ async def init_mcp_clients(self) -> None:
308322
self._start_mcp_server(
309323
name=name,
310324
cfg=cfg,
311-
timeout_env=MCP_INIT_TIMEOUT_ENV,
312-
default_timeout=DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
313325
shutdown_event=shutdown_event,
314326
timeout=init_timeout,
315327
),
@@ -330,8 +342,7 @@ async def init_mcp_clients(self) -> None:
330342
logger.error(f"MCP 服务 {name} 初始化失败: {result}")
331343
self._log_safe_mcp_debug_config(cfg)
332344
failed_services.append(name)
333-
self.mcp_client_event.pop(name, None)
334-
self.mcp_client_task.pop(name, None)
345+
self.mcp_server_runtime.pop(name, None)
335346
continue
336347

337348
success_count += 1
@@ -342,7 +353,15 @@ async def init_mcp_clients(self) -> None:
342353
f"请检查配置文件 mcp_server.json 和服务器可用性。"
343354
)
344355

345-
logger.info(f"MCP 服务初始化完成: {success_count}/{len(active_configs)} 成功")
356+
summary = MCPInitSummary(
357+
total=len(active_configs), success=success_count, failed=failed_services
358+
)
359+
logger.info(f"MCP 服务初始化完成: {summary.success}/{summary.total} 成功")
360+
if summary.total > 0 and summary.success == 0:
361+
raise RuntimeError(
362+
"全部 MCP 服务初始化失败,请检查 mcp_server.json 配置和服务器可用性。"
363+
)
364+
return summary
346365

347366
async def _run_mcp_client(
348367
self,
@@ -359,9 +378,9 @@ async def _run_mcp_client(
359378
finally:
360379
await self._terminate_mcp_client(name)
361380
current_task = asyncio.current_task()
362-
if self.mcp_client_task.get(name) is current_task:
363-
self.mcp_client_task.pop(name, None)
364-
self.mcp_client_event.pop(name, None)
381+
runtime = self.mcp_server_runtime.get(name)
382+
if runtime and runtime.lifecycle_task is current_task:
383+
self.mcp_server_runtime.pop(name, None)
365384

366385
async def _start_mcp_client_with_timeout(
367386
self,
@@ -396,10 +415,8 @@ async def _start_mcp_server(
396415
name: str,
397416
cfg: dict,
398417
*,
399-
timeout_env: str,
400-
default_timeout: float,
401418
shutdown_event: asyncio.Event | None = None,
402-
timeout: float | int | str | None = None,
419+
timeout: float,
403420
) -> None:
404421
"""Initialize MCP server with timeout and register task/event together."""
405422
if name in self.mcp_client_dict:
@@ -408,36 +425,28 @@ async def _start_mcp_server(
408425
if shutdown_event is None:
409426
shutdown_event = asyncio.Event()
410427

411-
timeout_value = _resolve_timeout(
412-
timeout=timeout,
413-
env_name=timeout_env,
414-
default=default_timeout,
415-
)
416428
lifecycle_task = await self._start_mcp_client_with_timeout(
417429
name=name,
418430
cfg=cfg,
419431
shutdown_event=shutdown_event,
420-
timeout=timeout_value,
432+
timeout=timeout,
433+
)
434+
self.mcp_server_runtime[name] = _MCPServerRuntime(
435+
shutdown_event=shutdown_event,
436+
lifecycle_task=lifecycle_task,
421437
)
422-
self.mcp_client_task[name] = lifecycle_task
423-
self.mcp_client_event[name] = shutdown_event
424438

425-
async def _wait_mcp_lifecycle_task(self, name: str, timeout: float) -> None:
439+
async def _wait_mcp_lifecycle_task(
440+
self, lifecycle_task: asyncio.Task[None], timeout: float
441+
) -> None:
426442
"""Wait for the lifecycle task to finish; on timeout, cancel it, await its exit, and re-raise."""
427-
task = self.mcp_client_task.get(name)
428-
if task is not None:
429-
try:
430-
await asyncio.wait_for(asyncio.shield(task), timeout=timeout)
431-
except asyncio.TimeoutError:
432-
if not task.done():
433-
task.cancel()
434-
await asyncio.gather(task, return_exceptions=True)
435-
raise
436-
return
437-
438-
client = self.mcp_client_dict.get(name)
439-
if client is not None:
440-
await asyncio.wait_for(client.running_event.wait(), timeout=timeout)
443+
try:
444+
await asyncio.wait_for(asyncio.shield(lifecycle_task), timeout=timeout)
445+
except asyncio.TimeoutError:
446+
if not lifecycle_task.done():
447+
lifecycle_task.cancel()
448+
await asyncio.gather(lifecycle_task, return_exceptions=True)
449+
raise
441450

442451
async def _init_mcp_client(self, name: str, config: dict) -> None:
443452
"""初始化单个MCP客户端"""
@@ -529,14 +538,18 @@ async def enable_mcp_server(
529538
Exception: If there is an error during initialization.
530539
"""
531540
if name in self.mcp_client_dict:
541+
logger.info(f"MCP 服务 {name} 已存在,跳过重复启用。")
532542
return
543+
timeout_value = _resolve_timeout(
544+
timeout=timeout,
545+
env_name=ENABLE_MCP_TIMEOUT_ENV,
546+
default=DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS,
547+
)
533548
await self._start_mcp_server(
534549
name=name,
535550
cfg=config,
536-
timeout_env=ENABLE_MCP_TIMEOUT_ENV,
537-
default_timeout=DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS,
538551
shutdown_event=shutdown_event,
539-
timeout=timeout,
552+
timeout=timeout_value,
540553
)
541554

542555
async def disable_mcp_server(
@@ -552,49 +565,39 @@ async def disable_mcp_server(
552565
553566
"""
554567
if name:
555-
event = self.mcp_client_event.get(name)
556-
task = self.mcp_client_task.get(name)
557-
if event is None and task is None and name not in self.mcp_client_dict:
568+
runtime = self.mcp_server_runtime.get(name)
569+
if runtime is None and name not in self.mcp_client_dict:
558570
return
559571

560-
if event is not None:
561-
event.set()
562-
elif task is not None:
563-
logger.warning(
564-
f"MCP 服务 {name} 缺少 shutdown event,直接取消生命周期任务。"
565-
)
566-
if not task.done():
567-
task.cancel()
568-
await asyncio.gather(task, return_exceptions=True)
569-
570572
try:
571-
if event is not None or task is None:
572-
await self._wait_mcp_lifecycle_task(name, timeout)
573+
if runtime is not None:
574+
runtime.shutdown_event.set()
575+
await self._wait_mcp_lifecycle_task(runtime.lifecycle_task, timeout)
576+
else:
577+
client = self.mcp_client_dict.get(name)
578+
if client is not None:
579+
await asyncio.wait_for(
580+
client.running_event.wait(), timeout=timeout
581+
)
573582
finally:
574-
self.mcp_client_event.pop(name, None)
575-
self.mcp_client_task.pop(name, None)
583+
self.mcp_server_runtime.pop(name, None)
576584
self.func_list = [
577585
f
578586
for f in self.func_list
579587
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
580588
]
581589
else:
582-
for _, event in list(self.mcp_client_event.items()):
583-
event.set()
584-
585-
lifecycle_tasks: list[asyncio.Task[None]] = []
586-
for task_name, task in list(self.mcp_client_task.items()):
587-
if task_name not in self.mcp_client_event and not task.done():
588-
logger.warning(
589-
f"MCP 服务 {task_name} 缺少 shutdown event,直接取消生命周期任务。"
590-
)
591-
task.cancel()
592-
lifecycle_tasks.append(task)
590+
for runtime in self.mcp_server_runtime.values():
591+
runtime.shutdown_event.set()
592+
593+
lifecycle_tasks: list[asyncio.Task[None]] = [
594+
runtime.lifecycle_task for runtime in self.mcp_server_runtime.values()
595+
]
593596

594597
running_events = [
595598
client.running_event.wait()
596599
for client_name, client in self.mcp_client_dict.items()
597-
if client_name not in self.mcp_client_task
600+
if client_name not in self.mcp_server_runtime
598601
]
599602
# waiting for all clients to finish
600603
try:
@@ -610,8 +613,7 @@ async def disable_mcp_server(
610613
await asyncio.gather(*lifecycle_tasks, return_exceptions=True)
611614
raise
612615
finally:
613-
self.mcp_client_event.clear()
614-
self.mcp_client_task.clear()
616+
self.mcp_server_runtime.clear()
615617
self.mcp_client_dict.clear()
616618
self.func_list = [
617619
f for f in self.func_list if not isinstance(f, MCPTool)

0 commit comments

Comments
 (0)