Skip to content

Commit 7f39a0a

Browse files
committed
fix: refine mcp timeout handling and lifecycle task tracking
1 parent 55a64d6 commit 7f39a0a

1 file changed

Lines changed: 55 additions & 15 deletions

File tree

astrbot/core/provider/func_tool_manager.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
MAX_MCP_TIMEOUT_SECONDS = 300.0
2626

2727

28+
class MCPInitTimeoutError(asyncio.TimeoutError):
29+
"""Raised when MCP client initialization exceeds the configured timeout."""
30+
31+
2832
def _resolve_timeout(
2933
timeout: float | int | str | None = None,
3034
*,
@@ -150,6 +154,7 @@ def __init__(self) -> None:
150154
self.mcp_client_dict: dict[str, MCPClient] = {}
151155
"""MCP 服务列表"""
152156
self.mcp_client_event: dict[str, asyncio.Event] = {}
157+
self.mcp_client_task: dict[str, asyncio.Task[None]] = {}
153158

154159
def empty(self) -> bool:
155160
return len(self.func_list) == 0
@@ -316,15 +321,17 @@ async def init_mcp_clients(self) -> None:
316321
active_configs, results, strict=False
317322
):
318323
if isinstance(result, Exception):
319-
if isinstance(result, TimeoutError):
324+
if isinstance(result, MCPInitTimeoutError):
320325
logger.error(f"MCP 服务 {name} 初始化超时({timeout_display}秒)")
321326
else:
322327
logger.error(f"MCP 服务 {name} 初始化失败: {result}")
323328
self._log_safe_mcp_debug_config(cfg)
324329
failed_services.append(name)
325330
self.mcp_client_event.pop(name, None)
331+
self.mcp_client_task.pop(name, None)
326332
continue
327333

334+
self.mcp_client_task[name] = result
328335
self.mcp_client_event[name] = shutdown_event
329336
success_count += 1
330337

@@ -350,23 +357,29 @@ async def _run_mcp_client(
350357
raise
351358
finally:
352359
await self._terminate_mcp_client(name)
360+
current_task = asyncio.current_task()
361+
if self.mcp_client_task.get(name) is current_task:
362+
self.mcp_client_task.pop(name, None)
363+
self.mcp_client_event.pop(name, None)
353364

354365
async def _start_mcp_client_with_timeout(
355366
self,
356367
name: str,
357368
cfg: dict,
358369
shutdown_event: asyncio.Event,
359370
timeout: float,
360-
) -> asyncio.Task:
371+
) -> asyncio.Task[None]:
361372
"""启动 MCP 客户端:先初始化,成功后再启动长生命周期任务。"""
362373
try:
363374
await asyncio.wait_for(
364375
self._init_mcp_client(name, cfg),
365376
timeout=timeout,
366377
)
367-
except asyncio.TimeoutError:
378+
except asyncio.TimeoutError as exc:
368379
await self._terminate_mcp_client(name)
369-
raise TimeoutError(f"MCP 服务 {name} 初始化超时({timeout:g} 秒)")
380+
raise MCPInitTimeoutError(
381+
f"MCP 服务 {name} 初始化超时({timeout:g} 秒)"
382+
) from exc
370383
except Exception:
371384
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
372385
await self._terminate_mcp_client(name)
@@ -463,7 +476,7 @@ async def enable_mcp_server(
463476
timeout: Timeout in seconds for initialization.
464477
465478
Raises:
466-
TimeoutError: If initialization does not complete within the timeout.
479+
MCPInitTimeoutError: If initialization does not complete within timeout.
467480
Exception: If there is an error during initialization.
468481
"""
469482
if name in self.mcp_client_dict:
@@ -476,14 +489,15 @@ async def enable_mcp_server(
476489
env_name=ENABLE_MCP_TIMEOUT_ENV,
477490
default=DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS,
478491
)
479-
await self._start_mcp_client_with_timeout(
492+
lifecycle_task = await self._start_mcp_client_with_timeout(
480493
name=name,
481494
cfg=config,
482495
shutdown_event=shutdown_event,
483496
timeout=timeout_value,
484497
)
485498

486499
# 初始化成功后再注册,避免失败时暴露无效的 event
500+
self.mcp_client_task[name] = lifecycle_task
487501
self.mcp_client_event[name] = shutdown_event
488502

489503
async def disable_mcp_server(
@@ -499,17 +513,31 @@ async def disable_mcp_server(
499513
500514
"""
501515
if name:
502-
if name not in self.mcp_client_event:
503-
return
504-
client = self.mcp_client_dict.get(name)
505-
self.mcp_client_event[name].set()
506-
if not client:
516+
event = self.mcp_client_event.get(name)
517+
task = self.mcp_client_task.get(name)
518+
if event is None and task is None and name not in self.mcp_client_dict:
507519
return
508-
client_running_event = client.running_event
520+
521+
if event:
522+
event.set()
523+
509524
try:
510-
await asyncio.wait_for(client_running_event.wait(), timeout=timeout)
525+
if task is not None:
526+
await asyncio.wait_for(asyncio.shield(task), timeout=timeout)
527+
else:
528+
client = self.mcp_client_dict.get(name)
529+
if client is not None:
530+
await asyncio.wait_for(
531+
client.running_event.wait(), timeout=timeout
532+
)
533+
except asyncio.TimeoutError:
534+
if task is not None and not task.done():
535+
task.cancel()
536+
await asyncio.gather(task, return_exceptions=True)
537+
raise
511538
finally:
512539
self.mcp_client_event.pop(name, None)
540+
self.mcp_client_task.pop(name, None)
513541
self.func_list = [
514542
f
515543
for f in self.func_list
@@ -519,13 +547,25 @@ async def disable_mcp_server(
519547
running_events = [
520548
client.running_event.wait() for client in self.mcp_client_dict.values()
521549
]
522-
for key, event in self.mcp_client_event.items():
550+
lifecycle_tasks = list(self.mcp_client_task.values())
551+
for _, event in list(self.mcp_client_event.items()):
523552
event.set()
524553
# waiting for all clients to finish
525554
try:
526-
await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout)
555+
await asyncio.wait_for(
556+
asyncio.gather(*running_events, *lifecycle_tasks),
557+
timeout=timeout,
558+
)
559+
except asyncio.TimeoutError:
560+
for task in lifecycle_tasks:
561+
if not task.done():
562+
task.cancel()
563+
if lifecycle_tasks:
564+
await asyncio.gather(*lifecycle_tasks, return_exceptions=True)
565+
raise
527566
finally:
528567
self.mcp_client_event.clear()
568+
self.mcp_client_task.clear()
529569
self.mcp_client_dict.clear()
530570
self.func_list = [
531571
f for f in self.func_list if not isinstance(f, MCPTool)

0 commit comments

Comments
 (0)