@@ -35,26 +35,29 @@ def _resolve_timeout(
3535 env_name : str = MCP_INIT_TIMEOUT_ENV ,
3636 default : float = DEFAULT_MCP_INIT_TIMEOUT_SECONDS ,
3737) -> float :
38+ source = f"环境变量 { env_name } "
3839 if timeout is None :
3940 timeout = os .getenv (env_name , str (default ))
41+ else :
42+ source = "显式参数 timeout"
4043
4144 try :
4245 timeout_value = float (timeout )
4346 except (TypeError , ValueError ):
4447 logger .warning (
45- f"超时配置 { env_name } ={ timeout !r} 无效,使用默认值 { default :g} 秒。"
48+ f"超时配置( { source } ) ={ timeout !r} 无效,使用默认值 { default :g} 秒。"
4649 )
4750 return default
4851
4952 if timeout_value <= 0 :
5053 logger .warning (
51- f"超时配置 { env_name } ={ timeout_value :g} 必须大于 0,使用默认值 { default :g} 秒。"
54+ f"超时配置( { source } ) ={ timeout_value :g} 必须大于 0,使用默认值 { default :g} 秒。"
5255 )
5356 return default
5457
5558 if timeout_value > MAX_MCP_TIMEOUT_SECONDS :
5659 logger .warning (
57- f"超时配置 { env_name } ={ timeout_value :g} 过大,已限制为最大值 "
60+ f"超时配置( { source } ) ={ timeout_value :g} 过大,已限制为最大值 "
5861 f"{ MAX_MCP_TIMEOUT_SECONDS :g} 秒,以避免长时间等待。"
5962 )
6063 return MAX_MCP_TIMEOUT_SECONDS
@@ -302,9 +305,11 @@ async def init_mcp_clients(self) -> None:
302305
303306 init_tasks = [
304307 asyncio .create_task (
305- self ._start_mcp_client_with_timeout (
308+ self ._start_mcp_server (
306309 name = name ,
307310 cfg = cfg ,
311+ timeout_env = MCP_INIT_TIMEOUT_ENV ,
312+ default_timeout = DEFAULT_MCP_INIT_TIMEOUT_SECONDS ,
308313 shutdown_event = shutdown_event ,
309314 timeout = init_timeout ,
310315 ),
@@ -317,9 +322,7 @@ async def init_mcp_clients(self) -> None:
317322 success_count = 0
318323 failed_services : list [str ] = []
319324
320- for (name , cfg , shutdown_event ), result in zip (
321- active_configs , results , strict = False
322- ):
325+ for (name , cfg , _ ), result in zip (active_configs , results , strict = False ):
323326 if isinstance (result , Exception ):
324327 if isinstance (result , MCPInitTimeoutError ):
325328 logger .error (f"MCP 服务 { name } 初始化超时({ timeout_display } 秒)" )
@@ -331,8 +334,6 @@ async def init_mcp_clients(self) -> None:
331334 self .mcp_client_task .pop (name , None )
332335 continue
333336
334- self .mcp_client_task [name ] = result
335- self .mcp_client_event [name ] = shutdown_event
336337 success_count += 1
337338
338339 if failed_services :
@@ -390,6 +391,54 @@ async def _start_mcp_client_with_timeout(
390391 name = f"mcp-client:{ name } " ,
391392 )
392393
async def _start_mcp_server(
    self,
    name: str,
    cfg: dict,
    *,
    timeout_env: str,
    default_timeout: float,
    shutdown_event: asyncio.Event | None = None,
    timeout: float | int | str | None = None,
) -> None:
    """Start one MCP server under a timeout and register its task and event.

    Returns immediately, without touching any registry, when ``name`` is
    already a key of ``self.mcp_client_dict``.

    Args:
        name: Key used for the server in the manager's registries.
        cfg: Server configuration, forwarded unchanged to
            ``_start_mcp_client_with_timeout``.
        timeout_env: Environment variable name handed to ``_resolve_timeout``.
        default_timeout: Default value handed to ``_resolve_timeout``.
        shutdown_event: Event to register for this server; a new
            ``asyncio.Event`` is created when ``None`` is given.
        timeout: Raw timeout value; it is passed through ``_resolve_timeout``
            before being used.

    Any exception raised while starting the client propagates to the caller,
    and in that case neither ``mcp_client_task`` nor ``mcp_client_event`` is
    written.
    """
    already_registered = name in self.mcp_client_dict
    if already_registered:
        return

    stop_signal = asyncio.Event() if shutdown_event is None else shutdown_event

    effective_timeout = _resolve_timeout(
        timeout=timeout,
        env_name=timeout_env,
        default=default_timeout,
    )

    started_task = await self._start_mcp_client_with_timeout(
        name=name,
        cfg=cfg,
        shutdown_event=stop_signal,
        timeout=effective_timeout,
    )

    # Register only after the start call has returned, so a failed start
    # never leaves a task/event entry behind.
    self.mcp_client_event[name] = stop_signal
    self.mcp_client_task[name] = started_task
424+
425+ async def _wait_mcp_lifecycle_task (self , name : str , timeout : float ) -> None :
426+ """Wait for lifecycle task first; fallback to client running_event."""
427+ task = self .mcp_client_task .get (name )
428+ if task is not None :
429+ try :
430+ await asyncio .wait_for (asyncio .shield (task ), timeout = timeout )
431+ except asyncio .TimeoutError :
432+ if not task .done ():
433+ task .cancel ()
434+ await asyncio .gather (task , return_exceptions = True )
435+ raise
436+ return
437+
438+ client = self .mcp_client_dict .get (name )
439+ if client is not None :
440+ await asyncio .wait_for (client .running_event .wait (), timeout = timeout )
441+
393442 async def _init_mcp_client (self , name : str , config : dict ) -> None :
394443 """初始化单个MCP客户端"""
395444 # 先清理之前的客户端,如果存在
@@ -481,25 +530,15 @@ async def enable_mcp_server(
481530 """
482531 if name in self .mcp_client_dict :
483532 return
484- if not shutdown_event :
485- shutdown_event = asyncio .Event ()
486-
487- timeout_value = _resolve_timeout (
488- timeout = timeout ,
489- env_name = ENABLE_MCP_TIMEOUT_ENV ,
490- default = DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS ,
491- )
492- lifecycle_task = await self ._start_mcp_client_with_timeout (
533+ await self ._start_mcp_server (
493534 name = name ,
494535 cfg = config ,
536+ timeout_env = ENABLE_MCP_TIMEOUT_ENV ,
537+ default_timeout = DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS ,
495538 shutdown_event = shutdown_event ,
496- timeout = timeout_value ,
539+ timeout = timeout ,
497540 )
498541
499- # 初始化成功后再注册,避免失败时暴露无效的 event
500- self .mcp_client_task [name ] = lifecycle_task
501- self .mcp_client_event [name ] = shutdown_event
502-
503542 async def disable_mcp_server (
504543 self ,
505544 name : str | None = None ,
@@ -518,23 +557,19 @@ async def disable_mcp_server(
518557 if event is None and task is None and name not in self .mcp_client_dict :
519558 return
520559
521- if event :
560+ if event is not None :
522561 event .set ()
562+ elif task is not None :
563+ logger .warning (
564+ f"MCP 服务 { name } 缺少 shutdown event,直接取消生命周期任务。"
565+ )
566+ if not task .done ():
567+ task .cancel ()
568+ await asyncio .gather (task , return_exceptions = True )
523569
524570 try :
525- if task is not None :
526- await asyncio .wait_for (asyncio .shield (task ), timeout = timeout )
527- else :
528- client = self .mcp_client_dict .get (name )
529- if client is not None :
530- await asyncio .wait_for (
531- client .running_event .wait (), timeout = timeout
532- )
533- except asyncio .TimeoutError :
534- if task is not None and not task .done ():
535- task .cancel ()
536- await asyncio .gather (task , return_exceptions = True )
537- raise
571+ if event is not None or task is None :
572+ await self ._wait_mcp_lifecycle_task (name , timeout )
538573 finally :
539574 self .mcp_client_event .pop (name , None )
540575 self .mcp_client_task .pop (name , None )
@@ -544,12 +579,23 @@ async def disable_mcp_server(
544579 if not (isinstance (f , MCPTool ) and f .mcp_server_name == name )
545580 ]
546581 else :
547- running_events = [
548- client .running_event .wait () for client in self .mcp_client_dict .values ()
549- ]
550- lifecycle_tasks = list (self .mcp_client_task .values ())
551582 for _ , event in list (self .mcp_client_event .items ()):
552583 event .set ()
584+
585+ lifecycle_tasks : list [asyncio .Task [None ]] = []
586+ for task_name , task in list (self .mcp_client_task .items ()):
587+ if task_name not in self .mcp_client_event and not task .done ():
588+ logger .warning (
589+ f"MCP 服务 { task_name } 缺少 shutdown event,直接取消生命周期任务。"
590+ )
591+ task .cancel ()
592+ lifecycle_tasks .append (task )
593+
594+ running_events = [
595+ client .running_event .wait ()
596+ for client_name , client in self .mcp_client_dict .items ()
597+ if client_name not in self .mcp_client_task
598+ ]
553599 # waiting for all clients to finish
554600 try :
555601 await asyncio .wait_for (
0 commit comments