Skip to content

Commit 8132ce2

Browse files
committed
fix: correctly synchronize MCP client initialization
1 parent 38e99cf commit 8132ce2

File tree

1 file changed

+59
-18
lines changed

1 file changed

+59
-18
lines changed

astrbot/core/provider/func_tool_manager.py

Lines changed: 59 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -214,39 +214,68 @@ async def init_mcp_clients(self) -> None:
214214
)["mcpServers"]
215215

216216
tasks: dict[str, asyncio.Task] = {}
217+
ready_futures: dict[str, asyncio.Future] = {}
217218

218219
for name, cfg in mcp_server_json_obj.items():
219220
if cfg.get("active", True):
220221
event = asyncio.Event()
222+
ready_future = asyncio.get_running_loop().create_future()
221223
task = asyncio.create_task(
222-
self._init_mcp_client_task_wrapper(name, cfg, event),
224+
self._init_mcp_client_task_wrapper(
225+
name,
226+
cfg,
227+
event,
228+
ready_future,
229+
),
223230
)
224231
tasks[name] = task
232+
ready_futures[name] = ready_future
225233
self.mcp_client_event[name] = event
226234

227-
if tasks:
228-
logger.info(f"等待 {len(tasks)} 个 MCP 服务初始化...")
235+
if ready_futures:
236+
logger.info(f"等待 {len(ready_futures)} 个 MCP 服务初始化...")
229237

230-
done, pending = await asyncio.wait(tasks.values(), timeout=20.0)
238+
_, pending_futures = await asyncio.wait(
239+
ready_futures.values(),
240+
timeout=20.0,
241+
)
242+
243+
pending_services = {
244+
name
245+
for name, ready_future in ready_futures.items()
246+
if ready_future in pending_futures
247+
}
231248

232-
if pending:
249+
if pending_services:
233250
logger.warning(
234251
"MCP 服务初始化超时(20秒),部分服务可能未完全加载。"
235252
"建议检查 MCP 服务器配置和网络连接。"
236253
)
237-
for task in pending:
254+
for name in pending_services:
255+
task = tasks[name]
238256
task.cancel()
257+
await asyncio.gather(
258+
*(tasks[name] for name in pending_services),
259+
return_exceptions=True,
260+
)
239261

240262
success_count = 0
241263
failed_services: list[str] = []
242264

243-
for name, task in tasks.items():
244-
if task in pending:
265+
for name, ready_future in ready_futures.items():
266+
if name in pending_services:
245267
logger.error(f"MCP 服务 {name} 初始化超时")
246268
failed_services.append(name)
269+
self.mcp_client_event.pop(name, None)
270+
continue
271+
272+
if ready_future.cancelled():
273+
logger.error(f"MCP 服务 {name} 初始化已取消")
274+
failed_services.append(name)
275+
self.mcp_client_event.pop(name, None)
247276
continue
248277

249-
exc = task.exception()
278+
exc = ready_future.exception()
250279
if exc is not None:
251280
logger.error(f"MCP 服务 {name} 初始化失败: {exc}")
252281
# 仅在 debug 级别输出完整配置,避免在生产日志中泄露敏感信息
@@ -259,6 +288,7 @@ async def init_mcp_clients(self) -> None:
259288
parsed = urllib.parse.urlparse(cfg["url"])
260289
logger.debug(f" 主机: {parsed.scheme}://{parsed.netloc}")
261290
failed_services.append(name)
291+
self.mcp_client_event.pop(name, None)
262292
else:
263293
success_count += 1
264294

@@ -275,15 +305,26 @@ async def _init_mcp_client_task_wrapper(
275305
name: str,
276306
cfg: dict,
277307
event: asyncio.Event,
308+
ready_future: asyncio.Future | None = None,
278309
) -> None:
279310
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
280311
initialized = False
281312
try:
282313
await self._init_mcp_client(name, cfg)
283314
initialized = True
315+
if ready_future and not ready_future.done():
316+
ready_future.set_result(True)
284317
await event.wait()
285318
logger.info(f"收到 MCP 客户端 {name} 终止信号")
286-
except Exception:
319+
except asyncio.CancelledError:
320+
if ready_future and not ready_future.done():
321+
ready_future.set_exception(
322+
asyncio.TimeoutError("MCP 客户端初始化超时"),
323+
)
324+
raise
325+
except Exception as e:
326+
if ready_future and not ready_future.done():
327+
ready_future.set_exception(e)
287328
if not initialized:
288329
# 初始化阶段失败,记录错误并向上抛出让 task.exception() 捕获
289330
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
@@ -387,22 +428,22 @@ async def enable_mcp_server(
387428
if not event:
388429
event = asyncio.Event()
389430
if not ready_future:
390-
ready_future = asyncio.Future()
431+
ready_future = asyncio.get_running_loop().create_future()
391432
if name in self.mcp_client_dict:
392433
return
393-
asyncio.create_task(
434+
init_task = asyncio.create_task(
394435
self._init_mcp_client_task_wrapper(name, config, event, ready_future),
395436
)
396437
try:
397438
await asyncio.wait_for(ready_future, timeout=timeout)
398-
finally:
439+
except asyncio.TimeoutError:
440+
init_task.cancel()
441+
await asyncio.gather(init_task, return_exceptions=True)
442+
self.mcp_client_event.pop(name, None)
443+
raise
444+
else:
399445
self.mcp_client_event[name] = event
400446

401-
if ready_future.done() and ready_future.exception():
402-
exc = ready_future.exception()
403-
if exc is not None:
404-
raise exc
405-
406447
async def disable_mcp_server(
407448
self,
408449
name: str | None = None,

0 commit comments

Comments
 (0)