66import os
77import urllib .parse
88from collections .abc import AsyncGenerator , Awaitable , Callable
9+ from dataclasses import dataclass
910from typing import Any
1011
1112import aiohttp
@@ -29,12 +30,26 @@ class MCPInitTimeoutError(asyncio.TimeoutError):
2930 """Raised when MCP client initialization exceeds the configured timeout."""
3031
3132
33+ @dataclass
34+ class MCPInitSummary :
35+ total : int
36+ success : int
37+ failed : list [str ]
38+
39+
40+ @dataclass
41+ class _MCPServerRuntime :
42+ shutdown_event : asyncio .Event
43+ lifecycle_task : asyncio .Task [None ]
44+
45+
3246def _resolve_timeout (
3347 timeout : float | int | str | None = None ,
3448 * ,
3549 env_name : str = MCP_INIT_TIMEOUT_ENV ,
3650 default : float = DEFAULT_MCP_INIT_TIMEOUT_SECONDS ,
3751) -> float :
52+ """Resolve timeout with precedence: explicit argument > env value > default."""
3853 source = f"环境变量 { env_name } "
3954 if timeout is None :
4055 timeout = os .getenv (env_name , str (default ))
@@ -156,8 +171,7 @@ def __init__(self) -> None:
156171 self .func_list : list [FuncTool ] = []
157172 self .mcp_client_dict : dict [str , MCPClient ] = {}
158173 """MCP 服务列表"""
159- self .mcp_client_event : dict [str , asyncio .Event ] = {}
160- self .mcp_client_task : dict [str , asyncio .Task [None ]] = {}
174+ self .mcp_server_runtime : dict [str , _MCPServerRuntime ] = {}
161175
162176 def empty (self ) -> bool :
163177 return len (self .func_list ) == 0
@@ -253,7 +267,7 @@ def _log_safe_mcp_debug_config(cfg: dict) -> None:
253267 port = ""
254268 logger .debug (f" 主机: { scheme } ://{ host } { port } " )
255269
256- async def init_mcp_clients (self ) -> None :
270+ async def init_mcp_clients (self ) -> MCPInitSummary :
257271 """从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
258272 ```
259273 {
@@ -299,7 +313,7 @@ async def init_mcp_clients(self) -> None:
299313 active_configs .append ((name , cfg , shutdown_event ))
300314
301315 if not active_configs :
302- return
316+ return MCPInitSummary ( total = 0 , success = 0 , failed = [])
303317
304318 logger .info (f"等待 { len (active_configs )} 个 MCP 服务初始化..." )
305319
@@ -308,8 +322,6 @@ async def init_mcp_clients(self) -> None:
308322 self ._start_mcp_server (
309323 name = name ,
310324 cfg = cfg ,
311- timeout_env = MCP_INIT_TIMEOUT_ENV ,
312- default_timeout = DEFAULT_MCP_INIT_TIMEOUT_SECONDS ,
313325 shutdown_event = shutdown_event ,
314326 timeout = init_timeout ,
315327 ),
@@ -330,8 +342,7 @@ async def init_mcp_clients(self) -> None:
330342 logger .error (f"MCP 服务 { name } 初始化失败: { result } " )
331343 self ._log_safe_mcp_debug_config (cfg )
332344 failed_services .append (name )
333- self .mcp_client_event .pop (name , None )
334- self .mcp_client_task .pop (name , None )
345+ self .mcp_server_runtime .pop (name , None )
335346 continue
336347
337348 success_count += 1
@@ -342,7 +353,15 @@ async def init_mcp_clients(self) -> None:
342353 f"请检查配置文件 mcp_server.json 和服务器可用性。"
343354 )
344355
345- logger .info (f"MCP 服务初始化完成: { success_count } /{ len (active_configs )} 成功" )
356+ summary = MCPInitSummary (
357+ total = len (active_configs ), success = success_count , failed = failed_services
358+ )
359+ logger .info (f"MCP 服务初始化完成: { summary .success } /{ summary .total } 成功" )
360+ if summary .total > 0 and summary .success == 0 :
361+ raise RuntimeError (
362+ "全部 MCP 服务初始化失败,请检查 mcp_server.json 配置和服务器可用性。"
363+ )
364+ return summary
346365
347366 async def _run_mcp_client (
348367 self ,
@@ -359,9 +378,9 @@ async def _run_mcp_client(
359378 finally :
360379 await self ._terminate_mcp_client (name )
361380 current_task = asyncio .current_task ()
362- if self .mcp_client_task .get (name ) is current_task :
363- self . mcp_client_task . pop ( name , None )
364- self .mcp_client_event .pop (name , None )
381+ runtime = self .mcp_server_runtime .get (name )
382+ if runtime and runtime . lifecycle_task is current_task :
383+ self .mcp_server_runtime .pop (name , None )
365384
366385 async def _start_mcp_client_with_timeout (
367386 self ,
@@ -396,10 +415,8 @@ async def _start_mcp_server(
396415 name : str ,
397416 cfg : dict ,
398417 * ,
399- timeout_env : str ,
400- default_timeout : float ,
401418 shutdown_event : asyncio .Event | None = None ,
402- timeout : float | int | str | None = None ,
419+ timeout : float ,
403420 ) -> None :
404421 """Initialize MCP server with timeout and register task/event together."""
405422 if name in self .mcp_client_dict :
@@ -408,36 +425,28 @@ async def _start_mcp_server(
408425 if shutdown_event is None :
409426 shutdown_event = asyncio .Event ()
410427
411- timeout_value = _resolve_timeout (
412- timeout = timeout ,
413- env_name = timeout_env ,
414- default = default_timeout ,
415- )
416428 lifecycle_task = await self ._start_mcp_client_with_timeout (
417429 name = name ,
418430 cfg = cfg ,
419431 shutdown_event = shutdown_event ,
420- timeout = timeout_value ,
432+ timeout = timeout ,
433+ )
434+ self .mcp_server_runtime [name ] = _MCPServerRuntime (
435+ shutdown_event = shutdown_event ,
436+ lifecycle_task = lifecycle_task ,
421437 )
422- self .mcp_client_task [name ] = lifecycle_task
423- self .mcp_client_event [name ] = shutdown_event
424438
425- async def _wait_mcp_lifecycle_task (self , name : str , timeout : float ) -> None :
439+ async def _wait_mcp_lifecycle_task (
440+ self , lifecycle_task : asyncio .Task [None ], timeout : float
441+ ) -> None :
426442 """Wait for lifecycle task first; fallback to client running_event."""
427- task = self .mcp_client_task .get (name )
428- if task is not None :
429- try :
430- await asyncio .wait_for (asyncio .shield (task ), timeout = timeout )
431- except asyncio .TimeoutError :
432- if not task .done ():
433- task .cancel ()
434- await asyncio .gather (task , return_exceptions = True )
435- raise
436- return
437-
438- client = self .mcp_client_dict .get (name )
439- if client is not None :
440- await asyncio .wait_for (client .running_event .wait (), timeout = timeout )
443+ try :
444+ await asyncio .wait_for (asyncio .shield (lifecycle_task ), timeout = timeout )
445+ except asyncio .TimeoutError :
446+ if not lifecycle_task .done ():
447+ lifecycle_task .cancel ()
448+ await asyncio .gather (lifecycle_task , return_exceptions = True )
449+ raise
441450
442451 async def _init_mcp_client (self , name : str , config : dict ) -> None :
443452 """初始化单个MCP客户端"""
@@ -529,14 +538,18 @@ async def enable_mcp_server(
529538 Exception: If there is an error during initialization.
530539 """
531540 if name in self .mcp_client_dict :
541+ logger .info (f"MCP 服务 { name } 已存在,跳过重复启用。" )
532542 return
543+ timeout_value = _resolve_timeout (
544+ timeout = timeout ,
545+ env_name = ENABLE_MCP_TIMEOUT_ENV ,
546+ default = DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS ,
547+ )
533548 await self ._start_mcp_server (
534549 name = name ,
535550 cfg = config ,
536- timeout_env = ENABLE_MCP_TIMEOUT_ENV ,
537- default_timeout = DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS ,
538551 shutdown_event = shutdown_event ,
539- timeout = timeout ,
552+ timeout = timeout_value ,
540553 )
541554
542555 async def disable_mcp_server (
@@ -552,49 +565,39 @@ async def disable_mcp_server(
552565
553566 """
554567 if name :
555- event = self .mcp_client_event .get (name )
556- task = self .mcp_client_task .get (name )
557- if event is None and task is None and name not in self .mcp_client_dict :
568+ runtime = self .mcp_server_runtime .get (name )
569+ if runtime is None and name not in self .mcp_client_dict :
558570 return
559571
560- if event is not None :
561- event .set ()
562- elif task is not None :
563- logger .warning (
564- f"MCP 服务 { name } 缺少 shutdown event,直接取消生命周期任务。"
565- )
566- if not task .done ():
567- task .cancel ()
568- await asyncio .gather (task , return_exceptions = True )
569-
570572 try :
571- if event is not None or task is None :
572- await self ._wait_mcp_lifecycle_task (name , timeout )
573+ if runtime is not None :
574+ runtime .shutdown_event .set ()
575+ await self ._wait_mcp_lifecycle_task (runtime .lifecycle_task , timeout )
576+ else :
577+ client = self .mcp_client_dict .get (name )
578+ if client is not None :
579+ await asyncio .wait_for (
580+ client .running_event .wait (), timeout = timeout
581+ )
573582 finally :
574- self .mcp_client_event .pop (name , None )
575- self .mcp_client_task .pop (name , None )
583+ self .mcp_server_runtime .pop (name , None )
576584 self .func_list = [
577585 f
578586 for f in self .func_list
579587 if not (isinstance (f , MCPTool ) and f .mcp_server_name == name )
580588 ]
581589 else :
582- for _ , event in list (self .mcp_client_event .items ()):
583- event .set ()
584-
585- lifecycle_tasks : list [asyncio .Task [None ]] = []
586- for task_name , task in list (self .mcp_client_task .items ()):
587- if task_name not in self .mcp_client_event and not task .done ():
588- logger .warning (
589- f"MCP 服务 { task_name } 缺少 shutdown event,直接取消生命周期任务。"
590- )
591- task .cancel ()
592- lifecycle_tasks .append (task )
590+ for runtime in self .mcp_server_runtime .values ():
591+ runtime .shutdown_event .set ()
592+
593+ lifecycle_tasks : list [asyncio .Task [None ]] = [
594+ runtime .lifecycle_task for runtime in self .mcp_server_runtime .values ()
595+ ]
593596
594597 running_events = [
595598 client .running_event .wait ()
596599 for client_name , client in self .mcp_client_dict .items ()
597- if client_name not in self .mcp_client_task
600+ if client_name not in self .mcp_server_runtime
598601 ]
599602 # waiting for all clients to finish
600603 try :
@@ -610,8 +613,7 @@ async def disable_mcp_server(
610613 await asyncio .gather (* lifecycle_tasks , return_exceptions = True )
611614 raise
612615 finally :
613- self .mcp_client_event .clear ()
614- self .mcp_client_task .clear ()
616+ self .mcp_server_runtime .clear ()
615617 self .mcp_client_dict .clear ()
616618 self .func_list = [
617619 f for f in self .func_list if not isinstance (f , MCPTool )
0 commit comments