2525MAX_MCP_TIMEOUT_SECONDS = 300.0
2626
2727
28+ class MCPInitTimeoutError (asyncio .TimeoutError ):
29+ """Raised when MCP client initialization exceeds the configured timeout."""
30+
31+
2832def _resolve_timeout (
2933 timeout : float | int | str | None = None ,
3034 * ,
@@ -150,6 +154,7 @@ def __init__(self) -> None:
150154 self .mcp_client_dict : dict [str , MCPClient ] = {}
151155 """MCP 服务列表"""
152156 self .mcp_client_event : dict [str , asyncio .Event ] = {}
157+ self .mcp_client_task : dict [str , asyncio .Task [None ]] = {}
153158
154159 def empty (self ) -> bool :
155160 return len (self .func_list ) == 0
@@ -316,15 +321,17 @@ async def init_mcp_clients(self) -> None:
316321 active_configs , results , strict = False
317322 ):
318323 if isinstance (result , Exception ):
319- if isinstance (result , TimeoutError ):
324+ if isinstance (result , MCPInitTimeoutError ):
320325 logger .error (f"MCP 服务 { name } 初始化超时({ timeout_display } 秒)" )
321326 else :
322327 logger .error (f"MCP 服务 { name } 初始化失败: { result } " )
323328 self ._log_safe_mcp_debug_config (cfg )
324329 failed_services .append (name )
325330 self .mcp_client_event .pop (name , None )
331+ self .mcp_client_task .pop (name , None )
326332 continue
327333
334+ self .mcp_client_task [name ] = result
328335 self .mcp_client_event [name ] = shutdown_event
329336 success_count += 1
330337
@@ -350,23 +357,29 @@ async def _run_mcp_client(
350357 raise
351358 finally :
352359 await self ._terminate_mcp_client (name )
360+ current_task = asyncio .current_task ()
361+ if self .mcp_client_task .get (name ) is current_task :
362+ self .mcp_client_task .pop (name , None )
363+ self .mcp_client_event .pop (name , None )
353364
354365 async def _start_mcp_client_with_timeout (
355366 self ,
356367 name : str ,
357368 cfg : dict ,
358369 shutdown_event : asyncio .Event ,
359370 timeout : float ,
360- ) -> asyncio .Task :
371+ ) -> asyncio .Task [ None ] :
361372 """启动 MCP 客户端:先初始化,成功后再启动长生命周期任务。"""
362373 try :
363374 await asyncio .wait_for (
364375 self ._init_mcp_client (name , cfg ),
365376 timeout = timeout ,
366377 )
367- except asyncio .TimeoutError :
378+ except asyncio .TimeoutError as exc :
368379 await self ._terminate_mcp_client (name )
369- raise TimeoutError (f"MCP 服务 { name } 初始化超时({ timeout :g} 秒)" )
380+ raise MCPInitTimeoutError (
381+ f"MCP 服务 { name } 初始化超时({ timeout :g} 秒)"
382+ ) from exc
370383 except Exception :
371384 logger .error (f"初始化 MCP 客户端 { name } 失败" , exc_info = True )
372385 await self ._terminate_mcp_client (name )
@@ -463,7 +476,7 @@ async def enable_mcp_server(
463476 timeout: Timeout in seconds for initialization.
464477
465478 Raises:
466- TimeoutError : If initialization does not complete within the timeout.
479+ MCPInitTimeoutError : If initialization does not complete within timeout.
467480 Exception: If there is an error during initialization.
468481 """
469482 if name in self .mcp_client_dict :
@@ -476,14 +489,15 @@ async def enable_mcp_server(
476489 env_name = ENABLE_MCP_TIMEOUT_ENV ,
477490 default = DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS ,
478491 )
479- await self ._start_mcp_client_with_timeout (
492+ lifecycle_task = await self ._start_mcp_client_with_timeout (
480493 name = name ,
481494 cfg = config ,
482495 shutdown_event = shutdown_event ,
483496 timeout = timeout_value ,
484497 )
485498
486499 # 初始化成功后再注册,避免失败时暴露无效的 event
500+ self .mcp_client_task [name ] = lifecycle_task
487501 self .mcp_client_event [name ] = shutdown_event
488502
489503 async def disable_mcp_server (
@@ -499,17 +513,31 @@ async def disable_mcp_server(
499513
500514 """
501515 if name :
502- if name not in self .mcp_client_event :
503- return
504- client = self .mcp_client_dict .get (name )
505- self .mcp_client_event [name ].set ()
506- if not client :
516+ event = self .mcp_client_event .get (name )
517+ task = self .mcp_client_task .get (name )
518+ if event is None and task is None and name not in self .mcp_client_dict :
507519 return
508- client_running_event = client .running_event
520+
521+ if event :
522+ event .set ()
523+
509524 try :
510- await asyncio .wait_for (client_running_event .wait (), timeout = timeout )
525+ if task is not None :
526+ await asyncio .wait_for (asyncio .shield (task ), timeout = timeout )
527+ else :
528+ client = self .mcp_client_dict .get (name )
529+ if client is not None :
530+ await asyncio .wait_for (
531+ client .running_event .wait (), timeout = timeout
532+ )
533+ except asyncio .TimeoutError :
534+ if task is not None and not task .done ():
535+ task .cancel ()
536+ await asyncio .gather (task , return_exceptions = True )
537+ raise
511538 finally :
512539 self .mcp_client_event .pop (name , None )
540+ self .mcp_client_task .pop (name , None )
513541 self .func_list = [
514542 f
515543 for f in self .func_list
@@ -519,13 +547,25 @@ async def disable_mcp_server(
519547 running_events = [
520548 client .running_event .wait () for client in self .mcp_client_dict .values ()
521549 ]
522- for key , event in self .mcp_client_event .items ():
550+ lifecycle_tasks = list (self .mcp_client_task .values ())
551+ for _ , event in list (self .mcp_client_event .items ()):
523552 event .set ()
524553 # waiting for all clients to finish
525554 try :
526- await asyncio .wait_for (asyncio .gather (* running_events ), timeout = timeout )
555+ await asyncio .wait_for (
556+ asyncio .gather (* running_events , * lifecycle_tasks ),
557+ timeout = timeout ,
558+ )
559+ except asyncio .TimeoutError :
560+ for task in lifecycle_tasks :
561+ if not task .done ():
562+ task .cancel ()
563+ if lifecycle_tasks :
564+ await asyncio .gather (* lifecycle_tasks , return_exceptions = True )
565+ raise
527566 finally :
528567 self .mcp_client_event .clear ()
568+ self .mcp_client_task .clear ()
529569 self .mcp_client_dict .clear ()
530570 self .func_list = [
531571 f for f in self .func_list if not isinstance (f , MCPTool )
0 commit comments