2626MAX_MCP_TIMEOUT_SECONDS = 300.0
2727
2828
29- class MCPInitTimeoutError (asyncio .TimeoutError ):
29+ class MCPInitError (Exception ):
30+ """Base exception for MCP initialization failures."""
31+
32+
33+ class MCPInitTimeoutError (asyncio .TimeoutError , MCPInitError ):
3034 """Raised when MCP client initialization exceeds the configured timeout."""
3135
3236
37+ class MCPAllServicesFailedError (MCPInitError ):
38+ """Raised when all configured MCP services fail to initialize."""
39+
40+
3341@dataclass
3442class MCPInitSummary :
3543 total : int
@@ -80,6 +88,20 @@ def _resolve_timeout(
8088 return timeout_value
8189
8290
91+ def _resolve_mcp_timeout (
92+ * ,
93+ timeout : float | int | str | None = None ,
94+ init_phase : bool ,
95+ ) -> float :
96+ env_name = MCP_INIT_TIMEOUT_ENV if init_phase else ENABLE_MCP_TIMEOUT_ENV
97+ default = (
98+ DEFAULT_MCP_INIT_TIMEOUT_SECONDS
99+ if init_phase
100+ else DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS
101+ )
102+ return _resolve_timeout (timeout = timeout , env_name = env_name , default = default )
103+
104+
83105SUPPORTED_TYPES = [
84106 "string" ,
85107 "number" ,
@@ -300,10 +322,7 @@ async def init_mcp_clients(self) -> MCPInitSummary:
300322 open (mcp_json_file , encoding = "utf-8" ),
301323 )["mcpServers" ]
302324
303- init_timeout = _resolve_timeout (
304- env_name = MCP_INIT_TIMEOUT_ENV ,
305- default = DEFAULT_MCP_INIT_TIMEOUT_SECONDS ,
306- )
325+ init_timeout = _resolve_mcp_timeout (init_phase = True )
307326 timeout_display = f"{ init_timeout :g} "
308327
309328 active_configs : list [tuple [str , dict , asyncio .Event ]] = []
@@ -358,7 +377,7 @@ async def init_mcp_clients(self) -> MCPInitSummary:
358377 )
359378 logger .info (f"MCP 服务初始化完成: { summary .success } /{ summary .total } 成功" )
360379 if summary .total > 0 and summary .success == 0 :
361- raise RuntimeError (
380+ raise MCPAllServicesFailedError (
362381 "全部 MCP 服务初始化失败,请检查 mcp_server.json 配置和服务器可用性。"
363382 )
364383 return summary
@@ -382,14 +401,21 @@ async def _run_mcp_client(
382401 if runtime and runtime .lifecycle_task is current_task :
383402 self .mcp_server_runtime .pop (name , None )
384403
385- async def _start_mcp_client_with_timeout (
404+ async def _start_mcp_server (
386405 self ,
387406 name : str ,
388407 cfg : dict ,
389- shutdown_event : asyncio .Event ,
408+ * ,
409+ shutdown_event : asyncio .Event | None = None ,
390410 timeout : float ,
391- ) -> asyncio .Task [None ]:
392- """启动 MCP 客户端:先初始化,成功后再启动长生命周期任务。"""
411+ ) -> None :
412+ """Initialize MCP server with timeout and register task/event together."""
413+ if name in self .mcp_server_runtime :
414+ return
415+
416+ if shutdown_event is None :
417+ shutdown_event = asyncio .Event ()
418+
393419 try :
394420 await asyncio .wait_for (
395421 self ._init_mcp_client (name , cfg ),
@@ -405,32 +431,10 @@ async def _start_mcp_client_with_timeout(
405431 await self ._terminate_mcp_client (name )
406432 raise
407433
408- return asyncio .create_task (
434+ lifecycle_task = asyncio .create_task (
409435 self ._run_mcp_client (name , shutdown_event ),
410436 name = f"mcp-client:{ name } " ,
411437 )
412-
413- async def _start_mcp_server (
414- self ,
415- name : str ,
416- cfg : dict ,
417- * ,
418- shutdown_event : asyncio .Event | None = None ,
419- timeout : float ,
420- ) -> None :
421- """Initialize MCP server with timeout and register task/event together."""
422- if name in self .mcp_client_dict :
423- return
424-
425- if shutdown_event is None :
426- shutdown_event = asyncio .Event ()
427-
428- lifecycle_task = await self ._start_mcp_client_with_timeout (
429- name = name ,
430- cfg = cfg ,
431- shutdown_event = shutdown_event ,
432- timeout = timeout ,
433- )
434438 self .mcp_server_runtime [name ] = _MCPServerRuntime (
435439 shutdown_event = shutdown_event ,
436440 lifecycle_task = lifecycle_task ,
@@ -537,14 +541,10 @@ async def enable_mcp_server(
537541 MCPInitTimeoutError: If initialization does not complete within timeout.
538542 Exception: If there is an error during initialization.
539543 """
540- if name in self .mcp_client_dict :
544+ if name in self .mcp_server_runtime :
541545 logger .info (f"MCP 服务 { name } 已存在,跳过重复启用。" )
542546 return
543- timeout_value = _resolve_timeout (
544- timeout = timeout ,
545- env_name = ENABLE_MCP_TIMEOUT_ENV ,
546- default = DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS ,
547- )
547+ timeout_value = _resolve_mcp_timeout (timeout = timeout , init_phase = False )
548548 await self ._start_mcp_server (
549549 name = name ,
550550 cfg = config ,
@@ -566,19 +566,12 @@ async def disable_mcp_server(
566566 """
567567 if name :
568568 runtime = self .mcp_server_runtime .get (name )
569- if runtime is None and name not in self . mcp_client_dict :
569+ if runtime is None :
570570 return
571571
572572 try :
573- if runtime is not None :
574- runtime .shutdown_event .set ()
575- await self ._wait_mcp_lifecycle_task (runtime .lifecycle_task , timeout )
576- else :
577- client = self .mcp_client_dict .get (name )
578- if client is not None :
579- await asyncio .wait_for (
580- client .running_event .wait (), timeout = timeout
581- )
573+ runtime .shutdown_event .set ()
574+ await self ._wait_mcp_lifecycle_task (runtime .lifecycle_task , timeout )
582575 finally :
583576 self .mcp_server_runtime .pop (name , None )
584577 self .func_list = [
@@ -587,22 +580,17 @@ async def disable_mcp_server(
587580 if not (isinstance (f , MCPTool ) and f .mcp_server_name == name )
588581 ]
589582 else :
590- for runtime in self .mcp_server_runtime .values ():
583+ runtimes = list (self .mcp_server_runtime .values ())
584+ for runtime in runtimes :
591585 runtime .shutdown_event .set ()
592586
593587 lifecycle_tasks : list [asyncio .Task [None ]] = [
594- runtime .lifecycle_task for runtime in self .mcp_server_runtime .values ()
595- ]
596-
597- running_events = [
598- client .running_event .wait ()
599- for client_name , client in self .mcp_client_dict .items ()
600- if client_name not in self .mcp_server_runtime
588+ runtime .lifecycle_task for runtime in runtimes
601589 ]
602590 # waiting for all clients to finish
603591 try :
604592 await asyncio .wait_for (
605- asyncio .gather (* running_events , * lifecycle_tasks ),
593+ asyncio .gather (* lifecycle_tasks ),
606594 timeout = timeout ,
607595 )
608596 except asyncio .TimeoutError :
0 commit comments