Skip to content

Commit 807fb4a

Browse files
committed
fix: harden mcp shutdown and timeout source logging
1 parent 7f39a0a commit 807fb4a

1 file changed

Lines changed: 87 additions & 41 deletions

File tree

astrbot/core/provider/func_tool_manager.py

Lines changed: 87 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -35,26 +35,29 @@ def _resolve_timeout(
3535
env_name: str = MCP_INIT_TIMEOUT_ENV,
3636
default: float = DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
3737
) -> float:
38+
source = f"环境变量 {env_name}"
3839
if timeout is None:
3940
timeout = os.getenv(env_name, str(default))
41+
else:
42+
source = "显式参数 timeout"
4043

4144
try:
4245
timeout_value = float(timeout)
4346
except (TypeError, ValueError):
4447
logger.warning(
45-
f"超时配置 {env_name}={timeout!r} 无效,使用默认值 {default:g} 秒。"
48+
f"超时配置{source}={timeout!r} 无效,使用默认值 {default:g} 秒。"
4649
)
4750
return default
4851

4952
if timeout_value <= 0:
5053
logger.warning(
51-
f"超时配置 {env_name}={timeout_value:g} 必须大于 0,使用默认值 {default:g} 秒。"
54+
f"超时配置{source}={timeout_value:g} 必须大于 0,使用默认值 {default:g} 秒。"
5255
)
5356
return default
5457

5558
if timeout_value > MAX_MCP_TIMEOUT_SECONDS:
5659
logger.warning(
57-
f"超时配置 {env_name}={timeout_value:g} 过大,已限制为最大值 "
60+
f"超时配置{source}={timeout_value:g} 过大,已限制为最大值 "
5861
f"{MAX_MCP_TIMEOUT_SECONDS:g} 秒,以避免长时间等待。"
5962
)
6063
return MAX_MCP_TIMEOUT_SECONDS
@@ -302,9 +305,11 @@ async def init_mcp_clients(self) -> None:
302305

303306
init_tasks = [
304307
asyncio.create_task(
305-
self._start_mcp_client_with_timeout(
308+
self._start_mcp_server(
306309
name=name,
307310
cfg=cfg,
311+
timeout_env=MCP_INIT_TIMEOUT_ENV,
312+
default_timeout=DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
308313
shutdown_event=shutdown_event,
309314
timeout=init_timeout,
310315
),
@@ -317,9 +322,7 @@ async def init_mcp_clients(self) -> None:
317322
success_count = 0
318323
failed_services: list[str] = []
319324

320-
for (name, cfg, shutdown_event), result in zip(
321-
active_configs, results, strict=False
322-
):
325+
for (name, cfg, _), result in zip(active_configs, results, strict=False):
323326
if isinstance(result, Exception):
324327
if isinstance(result, MCPInitTimeoutError):
325328
logger.error(f"MCP 服务 {name} 初始化超时({timeout_display}秒)")
@@ -331,8 +334,6 @@ async def init_mcp_clients(self) -> None:
331334
self.mcp_client_task.pop(name, None)
332335
continue
333336

334-
self.mcp_client_task[name] = result
335-
self.mcp_client_event[name] = shutdown_event
336337
success_count += 1
337338

338339
if failed_services:
@@ -390,6 +391,54 @@ async def _start_mcp_client_with_timeout(
390391
name=f"mcp-client:{name}",
391392
)
392393

394+
async def _start_mcp_server(
395+
self,
396+
name: str,
397+
cfg: dict,
398+
*,
399+
timeout_env: str,
400+
default_timeout: float,
401+
shutdown_event: asyncio.Event | None = None,
402+
timeout: float | int | str | None = None,
403+
) -> None:
404+
"""Initialize MCP server with timeout and register task/event together."""
405+
if name in self.mcp_client_dict:
406+
return
407+
408+
if shutdown_event is None:
409+
shutdown_event = asyncio.Event()
410+
411+
timeout_value = _resolve_timeout(
412+
timeout=timeout,
413+
env_name=timeout_env,
414+
default=default_timeout,
415+
)
416+
lifecycle_task = await self._start_mcp_client_with_timeout(
417+
name=name,
418+
cfg=cfg,
419+
shutdown_event=shutdown_event,
420+
timeout=timeout_value,
421+
)
422+
self.mcp_client_task[name] = lifecycle_task
423+
self.mcp_client_event[name] = shutdown_event
424+
425+
async def _wait_mcp_lifecycle_task(self, name: str, timeout: float) -> None:
426+
"""Wait for lifecycle task first; fallback to client running_event."""
427+
task = self.mcp_client_task.get(name)
428+
if task is not None:
429+
try:
430+
await asyncio.wait_for(asyncio.shield(task), timeout=timeout)
431+
except asyncio.TimeoutError:
432+
if not task.done():
433+
task.cancel()
434+
await asyncio.gather(task, return_exceptions=True)
435+
raise
436+
return
437+
438+
client = self.mcp_client_dict.get(name)
439+
if client is not None:
440+
await asyncio.wait_for(client.running_event.wait(), timeout=timeout)
441+
393442
async def _init_mcp_client(self, name: str, config: dict) -> None:
394443
"""初始化单个MCP客户端"""
395444
# 先清理之前的客户端,如果存在
@@ -481,25 +530,15 @@ async def enable_mcp_server(
481530
"""
482531
if name in self.mcp_client_dict:
483532
return
484-
if not shutdown_event:
485-
shutdown_event = asyncio.Event()
486-
487-
timeout_value = _resolve_timeout(
488-
timeout=timeout,
489-
env_name=ENABLE_MCP_TIMEOUT_ENV,
490-
default=DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS,
491-
)
492-
lifecycle_task = await self._start_mcp_client_with_timeout(
533+
await self._start_mcp_server(
493534
name=name,
494535
cfg=config,
536+
timeout_env=ENABLE_MCP_TIMEOUT_ENV,
537+
default_timeout=DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS,
495538
shutdown_event=shutdown_event,
496-
timeout=timeout_value,
539+
timeout=timeout,
497540
)
498541

499-
# 初始化成功后再注册,避免失败时暴露无效的 event
500-
self.mcp_client_task[name] = lifecycle_task
501-
self.mcp_client_event[name] = shutdown_event
502-
503542
async def disable_mcp_server(
504543
self,
505544
name: str | None = None,
@@ -518,23 +557,19 @@ async def disable_mcp_server(
518557
if event is None and task is None and name not in self.mcp_client_dict:
519558
return
520559

521-
if event:
560+
if event is not None:
522561
event.set()
562+
elif task is not None:
563+
logger.warning(
564+
f"MCP 服务 {name} 缺少 shutdown event,直接取消生命周期任务。"
565+
)
566+
if not task.done():
567+
task.cancel()
568+
await asyncio.gather(task, return_exceptions=True)
523569

524570
try:
525-
if task is not None:
526-
await asyncio.wait_for(asyncio.shield(task), timeout=timeout)
527-
else:
528-
client = self.mcp_client_dict.get(name)
529-
if client is not None:
530-
await asyncio.wait_for(
531-
client.running_event.wait(), timeout=timeout
532-
)
533-
except asyncio.TimeoutError:
534-
if task is not None and not task.done():
535-
task.cancel()
536-
await asyncio.gather(task, return_exceptions=True)
537-
raise
571+
if event is not None or task is None:
572+
await self._wait_mcp_lifecycle_task(name, timeout)
538573
finally:
539574
self.mcp_client_event.pop(name, None)
540575
self.mcp_client_task.pop(name, None)
@@ -544,12 +579,23 @@ async def disable_mcp_server(
544579
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
545580
]
546581
else:
547-
running_events = [
548-
client.running_event.wait() for client in self.mcp_client_dict.values()
549-
]
550-
lifecycle_tasks = list(self.mcp_client_task.values())
551582
for _, event in list(self.mcp_client_event.items()):
552583
event.set()
584+
585+
lifecycle_tasks: list[asyncio.Task[None]] = []
586+
for task_name, task in list(self.mcp_client_task.items()):
587+
if task_name not in self.mcp_client_event and not task.done():
588+
logger.warning(
589+
f"MCP 服务 {task_name} 缺少 shutdown event,直接取消生命周期任务。"
590+
)
591+
task.cancel()
592+
lifecycle_tasks.append(task)
593+
594+
running_events = [
595+
client.running_event.wait()
596+
for client_name, client in self.mcp_client_dict.items()
597+
if client_name not in self.mcp_client_task
598+
]
553599
# waiting for all clients to finish
554600
try:
555601
await asyncio.wait_for(

0 commit comments

Comments
 (0)