From 20908f8c99d736c9838e40b05940b7fc2d535943 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 22 Jan 2026 15:55:22 +0900 Subject: [PATCH 01/14] feat: add MCPServerManager for safely managing server lifecycle --- examples/mcp/manager_example/README.md | 71 +++++ examples/mcp/manager_example/app.py | 130 ++++++++++ examples/mcp/manager_example/mcp_server.py | 26 ++ src/agents/agent.py | 3 +- src/agents/mcp/__init__.py | 2 + src/agents/mcp/manager.py | 288 +++++++++++++++++++++ src/agents/mcp/server.py | 2 + tests/mcp/test_mcp_server_manager.py | 127 +++++++++ 8 files changed, 648 insertions(+), 1 deletion(-) create mode 100644 examples/mcp/manager_example/README.md create mode 100644 examples/mcp/manager_example/app.py create mode 100644 examples/mcp/manager_example/mcp_server.py create mode 100644 src/agents/mcp/manager.py create mode 100644 tests/mcp/test_mcp_server_manager.py diff --git a/examples/mcp/manager_example/README.md b/examples/mcp/manager_example/README.md new file mode 100644 index 0000000000..e465c3f8de --- /dev/null +++ b/examples/mcp/manager_example/README.md @@ -0,0 +1,71 @@ +# MCP Manager Example (FastAPI) + +This example shows how to use `MCPServerManager` to keep MCP server lifecycle +management in a single task inside a FastAPI app with the Streamable HTTP +transport. + +## Run the MCP server (Streamable HTTP) + +``` +uv run python examples/mcp/manager_example/mcp_server.py +``` + +The server listens at `http://localhost:8000/mcp` by default. + +You can override the host/port with: + +``` +export STREAMABLE_HTTP_HOST=127.0.0.1 +export STREAMABLE_HTTP_PORT=8000 +``` + +This example also configures an inactive MCP server at +`http://localhost:8001/mcp` to demonstrate how the manager drops failed +servers. You can override it with: + +``` +export INACTIVE_MCP_SERVER_URL=http://localhost:8001/mcp +``` + +## Run the FastAPI app + +``` +uv run python examples/mcp/manager_example/app.py +``` + +The app listens at `http://127.0.0.1:9001`. + +## Toggle MCP manager usage + +By default, the app uses `MCPServerManager`. To disable it: + +``` +export USE_MCP_MANAGER=0 +``` + +## Try the endpoints + +``` +curl http://127.0.0.1:9001/health +curl http://127.0.0.1:9001/tools +curl -X POST http://127.0.0.1:9001/add \ + -H 'Content-Type: application/json' \ + -d '{"a": 2, "b": 3}' +``` + +Reconnect failed MCP servers (manager must be enabled): + +``` +curl -X POST http://127.0.0.1:9001/reconnect \ + -H 'Content-Type: application/json' \ + -d '{"failed_only": true}' +``` + +To use `/run`, set `OPENAI_API_KEY`: + +``` +export OPENAI_API_KEY=... +curl -X POST http://127.0.0.1:9001/run \ + -H 'Content-Type: application/json' \ + -d '{"input": "Add 4 and 9."}' +``` diff --git a/examples/mcp/manager_example/app.py b/examples/mcp/manager_example/app.py new file mode 100644 index 0000000000..cae0eb7501 --- /dev/null +++ b/examples/mcp/manager_example/app.py @@ -0,0 +1,130 @@ +import os +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel + +from agents import Agent, Runner +from agents.mcp import MCPServer, MCPServerManager, MCPServerStreamableHttp +from agents.model_settings import ModelSettings + +MCP_SERVER_URL = os.getenv("MCP_SERVER_URL", "http://localhost:8000/mcp") +INACTIVE_MCP_SERVER_URL = os.getenv("INACTIVE_MCP_SERVER_URL", "http://localhost:8001/mcp") +APP_HOST = "127.0.0.1" +APP_PORT = 9001 +USE_MCP_MANAGER = os.getenv("USE_MCP_MANAGER", "1") != "0" + + +class AddRequest(BaseModel): + a: int + b: int + + +class RunRequest(BaseModel): + input: str + + +class ReconnectRequest(BaseModel): + failed_only: bool = True + + +@asynccontextmanager +async def lifespan(app: FastAPI): + server = MCPServerStreamableHttp({"url": MCP_SERVER_URL}) + inactive_server = MCPServerStreamableHttp({"url": INACTIVE_MCP_SERVER_URL}) + servers = [server, inactive_server] + if USE_MCP_MANAGER: + async with MCPServerManager( + servers=servers, + connect_in_parallel=True, + ) as manager: + app.state.mcp_manager = manager + app.state.mcp_servers = servers + yield + return + + await server.connect() + app.state.mcp_servers = servers + app.state.active_servers = [server] + try: + yield + finally: + await server.cleanup() + + +app = FastAPI(lifespan=lifespan) + + +@app.get("/health") +async def health() -> dict[str, object]: + if USE_MCP_MANAGER: + manager: MCPServerManager = app.state.mcp_manager + return { + "connected_servers": [server.name for server in manager.active_servers], + "failed_servers": [server.name for server in manager.failed_servers], + } + + active_servers = _get_active_servers() + return { + "connected_servers": [server.name for server in active_servers], + "failed_servers": [], + } + + +@app.get("/tools") +async def list_tools() -> dict[str, object]: + active_servers = _get_active_servers() + if not active_servers: + return {"tools": []} + tools = await active_servers[0].list_tools() + return {"tools": [tool.name for tool in tools]} + + +@app.post("/add") +async def add(req: AddRequest) -> dict[str, object]: + active_servers = _get_active_servers() + if not active_servers: + raise HTTPException(status_code=503, detail="No MCP servers available") + result = await active_servers[0].call_tool("add", {"a": req.a, "b": req.b}) + return {"result": result.model_dump(mode="json")} + + +@app.post("/run") +async def run_agent(req: RunRequest) -> dict[str, object]: + if not os.getenv("OPENAI_API_KEY"): + raise HTTPException(status_code=400, detail="OPENAI_API_KEY is required") + + servers = _get_active_servers() + if not servers: + raise HTTPException(status_code=503, detail="No MCP servers available") + + agent = Agent( + name="FastAPI Agent", + instructions="Use the MCP tools when needed.", + mcp_servers=servers, + model_settings=ModelSettings(tool_choice="auto"), + ) + result = await Runner.run(starting_agent=agent, input=req.input) + return {"output": result.final_output} + + +@app.post("/reconnect") +async def reconnect(req: ReconnectRequest) -> dict[str, object]: + if not USE_MCP_MANAGER: + raise HTTPException(status_code=400, detail="MCPServerManager is disabled") + manager: MCPServerManager = app.state.mcp_manager + servers = await manager.reconnect(failed_only=req.failed_only) + return {"connected_servers": [server.name for server in servers]} + + +def _get_active_servers() -> list[MCPServer]: + if USE_MCP_MANAGER: + manager: MCPServerManager = app.state.mcp_manager + return list(manager.active_servers) + return list(app.state.active_servers) + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host=APP_HOST, port=APP_PORT) diff --git a/examples/mcp/manager_example/mcp_server.py b/examples/mcp/manager_example/mcp_server.py new file mode 100644 index 0000000000..a67c224994 --- /dev/null +++ b/examples/mcp/manager_example/mcp_server.py @@ -0,0 +1,26 @@ +import os + +from mcp.server.fastmcp import FastMCP + +STREAMABLE_HTTP_HOST = os.getenv("STREAMABLE_HTTP_HOST", "127.0.0.1") +STREAMABLE_HTTP_PORT = int(os.getenv("STREAMABLE_HTTP_PORT", "8000")) + +mcp = FastMCP( + "FastAPI Example Server", + host=STREAMABLE_HTTP_HOST, + port=STREAMABLE_HTTP_PORT, +) + + +@mcp.tool() +def add(a: int, b: int) -> int: + return a + b + + +@mcp.tool() +def echo(message: str) -> str: + return f"echo: {message}" + + +if __name__ == "__main__": + mcp.run(transport="streamable-http") diff --git a/src/agents/agent.py b/src/agents/agent.py index d8c7d19e20..7beed4a8c3 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -118,7 +118,8 @@ class AgentBase(Generic[TContext]): NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no - longer needed. + longer needed. Consider using `MCPServerManager` from `agents.mcp` to keep connect/cleanup + in the same task. """ mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig()) diff --git a/src/agents/mcp/__init__.py b/src/agents/mcp/__init__.py index da5a68b16a..ed64a03bdc 100644 --- a/src/agents/mcp/__init__.py +++ b/src/agents/mcp/__init__.py @@ -1,4 +1,5 @@ try: + from .manager import MCPServerManager from .server import ( MCPServer, MCPServerSse, @@ -28,6 +29,7 @@ "MCPServerStdioParams", "MCPServerStreamableHttp", "MCPServerStreamableHttpParams", + "MCPServerManager", "MCPUtil", "ToolFilter", "ToolFilterCallable", diff --git a/src/agents/mcp/manager.py b/src/agents/mcp/manager.py new file mode 100644 index 0000000000..2c4ac50c20 --- /dev/null +++ b/src/agents/mcp/manager.py @@ -0,0 +1,288 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable, Iterable +from contextlib import AbstractAsyncContextManager +from dataclasses import dataclass +from typing import Any + +from ..logger import logger +from .server import MCPServer + + +@dataclass +class _ServerCommand: + action: str + timeout_seconds: float | None + future: asyncio.Future[None] + + +class _ServerWorker: + def __init__( + self, + server: MCPServer, + connect_timeout_seconds: float | None, + cleanup_timeout_seconds: float | None, + ) -> None: + self._server = server + self._connect_timeout_seconds = connect_timeout_seconds + self._cleanup_timeout_seconds = cleanup_timeout_seconds + self._queue: asyncio.Queue[_ServerCommand] = asyncio.Queue() + self._task = asyncio.create_task(self._run()) + + @property + def is_done(self) -> bool: + return self._task.done() + + async def connect(self) -> None: + await self._submit("connect", self._connect_timeout_seconds) + + async def cleanup(self) -> None: + await self._submit("cleanup", self._cleanup_timeout_seconds) + + async def _submit(self, action: str, timeout_seconds: float | None) -> None: + loop = asyncio.get_running_loop() + future: asyncio.Future[None] = loop.create_future() + await self._queue.put( + _ServerCommand(action=action, timeout_seconds=timeout_seconds, future=future) + ) + await future + + async def _run(self) -> None: + while True: + command = await self._queue.get() + should_exit = command.action == "cleanup" + try: + if command.action == "connect": + await self._run_with_timeout(self._server.connect, command.timeout_seconds) + elif command.action == "cleanup": + await self._run_with_timeout(self._server.cleanup, command.timeout_seconds) + else: + raise ValueError(f"Unknown command: {command.action}") + if not command.future.cancelled(): + command.future.set_result(None) + except BaseException as exc: + if not command.future.cancelled(): + command.future.set_exception(exc) + if should_exit: + return + + async def _run_with_timeout( + self, func: Callable[[], Awaitable[Any]], timeout_seconds: float | None + ) -> None: + if timeout_seconds is None: + await func() + return + await asyncio.wait_for(func(), timeout=timeout_seconds) + + +class MCPServerManager(AbstractAsyncContextManager["MCPServerManager"]): + """Manage MCP server lifecycles and expose only connected servers. + + Use this helper to keep MCP connect/cleanup on the same task and avoid + run failures when a server is unavailable. The manager will attempt to + connect each server and then expose the connected subset via + `active_servers`. + + Basic usage: + async with MCPServerManager([server_a, server_b]) as manager: + agent = Agent( + name="Assistant", + instructions="...", + mcp_servers=manager.active_servers, + ) + + FastAPI lifespan example: + @asynccontextmanager + async def lifespan(app: FastAPI): + async with MCPServerManager([server_a, server_b]) as manager: + app.state.mcp_manager = manager + yield + + app = FastAPI(lifespan=lifespan) + + Important behaviors: + - `active_servers` only includes servers that connected successfully. + `failed_servers` holds the failures and `errors` maps servers to errors. + - `drop_failed_servers=True` removes failed servers from `active_servers` + (recommended). If False, `active_servers` will still include all servers. + - `strict=True` raises on the first connection failure. If False, failures + are recorded and the run can proceed with the remaining servers. + - `reconnect(failed_only=True)` retries failed servers and refreshes + `active_servers`. + - `connect_in_parallel=True` uses a dedicated worker task per server to + allow concurrent connects while preserving task affinity for cleanup. + """ + + def __init__( + self, + servers: Iterable[MCPServer], + *, + connect_timeout_seconds: float | None = 10.0, + cleanup_timeout_seconds: float | None = 10.0, + drop_failed_servers: bool = True, + strict: bool = False, + suppress_cancelled_error: bool = True, + connect_in_parallel: bool = False, + ) -> None: + self._all_servers = list(servers) + self._active_servers = list(servers) + self.connect_timeout_seconds = connect_timeout_seconds + self.cleanup_timeout_seconds = cleanup_timeout_seconds + self.drop_failed_servers = drop_failed_servers + self.strict = strict + self.suppress_cancelled_error = suppress_cancelled_error + self.connect_in_parallel = connect_in_parallel + self._workers: dict[MCPServer, _ServerWorker] = {} + + self.failed_servers: list[MCPServer] = [] + self.errors: dict[MCPServer, BaseException] = {} + + @property + def active_servers(self) -> list[MCPServer]: + """Return the active MCP servers after connection attempts.""" + return list(self._active_servers) + + @property + def all_servers(self) -> list[MCPServer]: + """Return all MCP servers managed by this instance.""" + return list(self._all_servers) + + async def __aenter__(self) -> MCPServerManager: + await self.connect_all() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool | None: + await self.cleanup_all() + return None + + async def connect_all(self) -> list[MCPServer]: + """Connect all servers in order and return the active list.""" + self.failed_servers = [] + self.errors = {} + + if self.connect_in_parallel: + await self._connect_all_parallel(self._active_servers) + else: + for server in self._active_servers: + await self._attempt_connect(server) + + self._refresh_active_servers() + + return self._active_servers + + async def reconnect(self, *, failed_only: bool = True) -> list[MCPServer]: + """Reconnect servers and return the active list. + + Args: + failed_only: If True, only retry servers that previously failed. + If False, retry all servers. + """ + if failed_only: + servers_to_retry = list(self.failed_servers) + else: + servers_to_retry = list(self._all_servers) + self.failed_servers = [] + self.errors = {} + + if self.connect_in_parallel: + await self._connect_all_parallel(servers_to_retry) + else: + for server in servers_to_retry: + await self._attempt_connect(server) + + self._refresh_active_servers() + return self._active_servers + + async def cleanup_all(self) -> None: + """Cleanup all servers in reverse order.""" + for server in reversed(self._all_servers): + try: + await self._cleanup_server(server) + except asyncio.CancelledError as exc: + if not self.suppress_cancelled_error: + raise + logger.debug(f"Cleanup cancelled for MCP server '{server.name}': {exc}") + self.errors[server] = exc + except Exception as exc: + logger.exception(f"Failed to cleanup MCP server '{server.name}': {exc}") + self.errors[server] = exc + + async def _run_with_timeout( + self, func: Callable[[], Awaitable[Any]], timeout_seconds: float | None + ) -> None: + if timeout_seconds is None: + await func() + return + await asyncio.wait_for(func(), timeout=timeout_seconds) + + async def _attempt_connect( + self, server: MCPServer, *, raise_on_error: bool | None = None + ) -> None: + if raise_on_error is None: + raise_on_error = self.strict + try: + await self._run_connect(server) + if server in self.failed_servers: + self.failed_servers.remove(server) + self.errors.pop(server, None) + except asyncio.CancelledError as exc: + if not self.suppress_cancelled_error: + raise + self._record_failure(server, exc, phase="connect") + except Exception as exc: + self._record_failure(server, exc, phase="connect") + if raise_on_error: + raise + + def _refresh_active_servers(self) -> None: + if self.drop_failed_servers: + failed = set(self.failed_servers) + self._active_servers = [server for server in self._all_servers if server not in failed] + else: + self._active_servers = list(self._all_servers) + + def _record_failure(self, server: MCPServer, exc: BaseException, phase: str) -> None: + logger.exception(f"Failed to {phase} MCP server '{server.name}': {exc}") + self.failed_servers.append(server) + self.errors[server] = exc + + async def _run_connect(self, server: MCPServer) -> None: + if self.connect_in_parallel: + worker = self._get_worker(server) + await worker.connect() + else: + await self._run_with_timeout(server.connect, self.connect_timeout_seconds) + + async def _cleanup_server(self, server: MCPServer) -> None: + if self.connect_in_parallel and server in self._workers: + worker = self._workers[server] + await worker.cleanup() + if worker.is_done: + self._workers.pop(server, None) + return + await self._run_with_timeout(server.cleanup, self.cleanup_timeout_seconds) + + async def _connect_all_parallel(self, servers: list[MCPServer]) -> None: + tasks = [ + asyncio.create_task(self._attempt_connect(server, raise_on_error=False)) + for server in servers + ] + await asyncio.gather(*tasks, return_exceptions=True) + if self.strict and self.failed_servers: + first_failure = self.failed_servers[0] + error = self.errors.get(first_failure) + if error is not None: + raise error + raise RuntimeError(f"Failed to connect MCP server '{first_failure.name}'") + + def _get_worker(self, server: MCPServer) -> _ServerWorker: + worker = self._workers.get(server) + if worker is None or worker.is_done: + worker = _ServerWorker( + server=server, + connect_timeout_seconds=self.connect_timeout_seconds, + cleanup_timeout_seconds=self.cleanup_timeout_seconds, + ) + self._workers[server] = worker + return worker diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 015b5b6f76..1468cf7f8f 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -451,6 +451,8 @@ async def cleanup(self): try: await self.exit_stack.aclose() + except asyncio.CancelledError as e: + logger.debug(f"Cleanup cancelled for MCP server '{self.name}': {e}") except BaseExceptionGroup as eg: # Extract HTTP errors from ExceptionGroup raised during cleanup # This happens when background tasks fail (e.g., HTTP errors) diff --git a/tests/mcp/test_mcp_server_manager.py b/tests/mcp/test_mcp_server_manager.py new file mode 100644 index 0000000000..71bfe174e8 --- /dev/null +++ b/tests/mcp/test_mcp_server_manager.py @@ -0,0 +1,127 @@ +import asyncio +from typing import Any + +import pytest +from mcp.types import CallToolResult, GetPromptResult, ListPromptsResult, Tool as MCPTool + +from agents.mcp import MCPServer, MCPServerManager +from agents.run_context import RunContextWrapper + + +class TaskBoundServer(MCPServer): + def __init__(self) -> None: + super().__init__() + self._connect_task: asyncio.Task[object] | None = None + self.cleaned = False + + @property + def name(self) -> str: + return "task-bound" + + async def connect(self) -> None: + self._connect_task = asyncio.current_task() + + async def cleanup(self) -> None: + if self._connect_task is None: + raise RuntimeError("Server was not connected") + if asyncio.current_task() is not self._connect_task: + raise RuntimeError("Attempted to exit cancel scope in a different task") + self.cleaned = True + + async def list_tools( + self, run_context: RunContextWrapper[Any] | None = None, agent: Any | None = None + ) -> list[MCPTool]: + raise NotImplementedError + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: + raise NotImplementedError + + async def list_prompts(self) -> ListPromptsResult: + raise NotImplementedError + + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None + ) -> GetPromptResult: + raise NotImplementedError + + +class FlakyServer(MCPServer): + def __init__(self, failures: int) -> None: + super().__init__() + self.failures_remaining = failures + self.connect_calls = 0 + + @property + def name(self) -> str: + return "flaky" + + async def connect(self) -> None: + self.connect_calls += 1 + if self.failures_remaining > 0: + self.failures_remaining -= 1 + raise RuntimeError("connect failed") + + async def cleanup(self) -> None: + return None + + async def list_tools( + self, run_context: RunContextWrapper[Any] | None = None, agent: Any | None = None + ) -> list[MCPTool]: + raise NotImplementedError + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: + raise NotImplementedError + + async def list_prompts(self) -> ListPromptsResult: + raise NotImplementedError + + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None + ) -> GetPromptResult: + raise NotImplementedError + + +@pytest.mark.asyncio +async def test_manager_keeps_connect_and_cleanup_in_same_task() -> None: + server = TaskBoundServer() + + async with MCPServerManager([server]) as manager: + assert manager.active_servers == [server] + + assert server.cleaned is True + + +@pytest.mark.asyncio +async def test_manager_connects_in_worker_tasks_when_parallel() -> None: + server = TaskBoundServer() + + async with MCPServerManager([server], connect_in_parallel=True) as manager: + assert manager.active_servers == [server] + assert server._connect_task is not None + assert server._connect_task is not asyncio.current_task() + + assert server.cleaned is True + + +@pytest.mark.asyncio +async def test_cross_task_cleanup_raises_without_manager() -> None: + server = TaskBoundServer() + + connect_task = asyncio.create_task(server.connect()) + await connect_task + + with pytest.raises(RuntimeError, match="cancel scope"): + await server.cleanup() + + +@pytest.mark.asyncio +async def test_manager_reconnect_failed_only() -> None: + server = FlakyServer(failures=1) + + async with MCPServerManager([server]) as manager: + assert manager.active_servers == [] + assert manager.failed_servers == [server] + + await manager.reconnect() + assert manager.active_servers == [server] + assert manager.failed_servers == [] From 48cb48eef60d86c1e9322c163847450dc15de2bc Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 22 Jan 2026 16:15:59 +0900 Subject: [PATCH 02/14] Fix MCPServerManager failure deduplication --- src/agents/mcp/manager.py | 30 ++++++++++++++++++++++++---- tests/mcp/test_mcp_server_manager.py | 20 +++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/src/agents/mcp/manager.py b/src/agents/mcp/manager.py index 2c4ac50c20..c53504acc5 100644 --- a/src/agents/mcp/manager.py +++ b/src/agents/mcp/manager.py @@ -136,6 +136,7 @@ def __init__( self._workers: dict[MCPServer, _ServerWorker] = {} self.failed_servers: list[MCPServer] = [] + self._failed_server_set: set[MCPServer] = set() self.errors: dict[MCPServer, BaseException] = {} @property @@ -159,6 +160,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool | None: async def connect_all(self) -> list[MCPServer]: """Connect all servers in order and return the active list.""" self.failed_servers = [] + self._failed_server_set = set() self.errors = {} if self.connect_in_parallel: @@ -179,10 +181,11 @@ async def reconnect(self, *, failed_only: bool = True) -> list[MCPServer]: If False, retry all servers. """ if failed_only: - servers_to_retry = list(self.failed_servers) + servers_to_retry = self._unique_servers(self.failed_servers) else: servers_to_retry = list(self._all_servers) self.failed_servers = [] + self._failed_server_set = set() self.errors = {} if self.connect_in_parallel: @@ -224,7 +227,7 @@ async def _attempt_connect( try: await self._run_connect(server) if server in self.failed_servers: - self.failed_servers.remove(server) + self._remove_failed_server(server) self.errors.pop(server, None) except asyncio.CancelledError as exc: if not self.suppress_cancelled_error: @@ -237,14 +240,16 @@ async def _attempt_connect( def _refresh_active_servers(self) -> None: if self.drop_failed_servers: - failed = set(self.failed_servers) + failed = set(self._failed_server_set) self._active_servers = [server for server in self._all_servers if server not in failed] else: self._active_servers = list(self._all_servers) def _record_failure(self, server: MCPServer, exc: BaseException, phase: str) -> None: logger.exception(f"Failed to {phase} MCP server '{server.name}': {exc}") - self.failed_servers.append(server) + if server not in self._failed_server_set: + self.failed_servers.append(server) + self._failed_server_set.add(server) self.errors[server] = exc async def _run_connect(self, server: MCPServer) -> None: @@ -286,3 +291,20 @@ def _get_worker(self, server: MCPServer) -> _ServerWorker: ) self._workers[server] = worker return worker + + def _remove_failed_server(self, server: MCPServer) -> None: + if server in self._failed_server_set: + self._failed_server_set.remove(server) + self.failed_servers = [ + failed_server for failed_server in self.failed_servers if failed_server != server + ] + + @staticmethod + def _unique_servers(servers: Iterable[MCPServer]) -> list[MCPServer]: + seen: set[MCPServer] = set() + unique: list[MCPServer] = [] + for server in servers: + if server not in seen: + seen.add(server) + unique.append(server) + return unique diff --git a/tests/mcp/test_mcp_server_manager.py b/tests/mcp/test_mcp_server_manager.py index 71bfe174e8..1c95b3dd8b 100644 --- a/tests/mcp/test_mcp_server_manager.py +++ b/tests/mcp/test_mcp_server_manager.py @@ -125,3 +125,23 @@ async def test_manager_reconnect_failed_only() -> None: await manager.reconnect() assert manager.active_servers == [server] assert manager.failed_servers == [] + + +@pytest.mark.asyncio +async def test_manager_reconnect_deduplicates_failures() -> None: + server = FlakyServer(failures=2) + + async with MCPServerManager([server], connect_in_parallel=True) as manager: + assert manager.active_servers == [] + assert manager.failed_servers == [server] + assert server.connect_calls == 1 + + await manager.reconnect() + assert manager.active_servers == [] + assert manager.failed_servers == [server] + assert server.connect_calls == 2 + + await manager.reconnect() + assert manager.active_servers == [server] + assert manager.failed_servers == [] + assert server.connect_calls == 3 From 0d2005ee56ea1549dabf6b45ad12f0ddd2dcffcb Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 22 Jan 2026 16:19:09 +0900 Subject: [PATCH 03/14] fix old python issues --- src/agents/mcp/manager.py | 48 ++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/src/agents/mcp/manager.py b/src/agents/mcp/manager.py index c53504acc5..5513188c47 100644 --- a/src/agents/mcp/manager.py +++ b/src/agents/mcp/manager.py @@ -54,9 +54,9 @@ async def _run(self) -> None: should_exit = command.action == "cleanup" try: if command.action == "connect": - await self._run_with_timeout(self._server.connect, command.timeout_seconds) + await _run_with_timeout_in_task(self._server.connect, command.timeout_seconds) elif command.action == "cleanup": - await self._run_with_timeout(self._server.cleanup, command.timeout_seconds) + await _run_with_timeout_in_task(self._server.cleanup, command.timeout_seconds) else: raise ValueError(f"Unknown command: {command.action}") if not command.future.cancelled(): @@ -67,13 +67,42 @@ async def _run(self) -> None: if should_exit: return - async def _run_with_timeout( - self, func: Callable[[], Awaitable[Any]], timeout_seconds: float | None - ) -> None: - if timeout_seconds is None: + +async def _run_with_timeout_in_task( + func: Callable[[], Awaitable[Any]], timeout_seconds: float | None +) -> None: + # Use an in-task timeout to preserve task affinity for MCP cleanup. + # asyncio.wait_for creates a new Task on Python < 3.11, which breaks + # libraries that require connect/cleanup in the same task (e.g. AnyIO cancel scopes). + if timeout_seconds is None: + await func() + return + timeout_context = getattr(asyncio, "timeout", None) + if timeout_context is not None: + async with timeout_context(timeout_seconds): await func() - return + return + task = asyncio.current_task() + if task is None: await asyncio.wait_for(func(), timeout=timeout_seconds) + return + timed_out = False + loop = asyncio.get_running_loop() + + def _cancel() -> None: + nonlocal timed_out + timed_out = True + task.cancel() + + handle = loop.call_later(timeout_seconds, _cancel) + try: + await func() + except asyncio.CancelledError as exc: + if timed_out: + raise asyncio.TimeoutError() from exc + raise + finally: + handle.cancel() class MCPServerManager(AbstractAsyncContextManager["MCPServerManager"]): @@ -214,10 +243,7 @@ async def cleanup_all(self) -> None: async def _run_with_timeout( self, func: Callable[[], Awaitable[Any]], timeout_seconds: float | None ) -> None: - if timeout_seconds is None: - await func() - return - await asyncio.wait_for(func(), timeout=timeout_seconds) + await _run_with_timeout_in_task(func, timeout_seconds) async def _attempt_connect( self, server: MCPServer, *, raise_on_error: bool | None = None From dcbe02f7e74fcdd6c2bb5497258133c237e195d7 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 22 Jan 2026 16:26:37 +0900 Subject: [PATCH 04/14] Fix MCPServerManager connect_all retries --- src/agents/mcp/manager.py | 5 +++-- tests/mcp/test_mcp_server_manager.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/agents/mcp/manager.py b/src/agents/mcp/manager.py index 5513188c47..3f9f74cb9b 100644 --- a/src/agents/mcp/manager.py +++ b/src/agents/mcp/manager.py @@ -192,10 +192,11 @@ async def connect_all(self) -> list[MCPServer]: self._failed_server_set = set() self.errors = {} + servers_to_connect = list(self._all_servers) if self.connect_in_parallel: - await self._connect_all_parallel(self._active_servers) + await self._connect_all_parallel(servers_to_connect) else: - for server in self._active_servers: + for server in servers_to_connect: await self._attempt_connect(server) self._refresh_active_servers() diff --git a/tests/mcp/test_mcp_server_manager.py b/tests/mcp/test_mcp_server_manager.py index 1c95b3dd8b..085425fa8c 100644 --- a/tests/mcp/test_mcp_server_manager.py +++ b/tests/mcp/test_mcp_server_manager.py @@ -145,3 +145,21 @@ async def test_manager_reconnect_deduplicates_failures() -> None: assert manager.active_servers == [server] assert manager.failed_servers == [] assert server.connect_calls == 3 + + +@pytest.mark.asyncio +async def test_manager_connect_all_retries_all_servers() -> None: + server = FlakyServer(failures=1) + manager = MCPServerManager([server]) + try: + await manager.connect_all() + assert manager.active_servers == [] + assert manager.failed_servers == [server] + assert server.connect_calls == 1 + + await manager.connect_all() + assert manager.active_servers == [server] + assert manager.failed_servers == [] + assert server.connect_calls == 2 + finally: + await manager.cleanup_all() From f4bbcbc08515cebbd2bb3ee4d02c703ad406d079 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 22 Jan 2026 16:29:01 +0900 Subject: [PATCH 05/14] Clean up MCP servers on strict connect failure --- src/agents/mcp/manager.py | 35 ++++++++++++++++++++++++---- tests/mcp/test_mcp_server_manager.py | 13 +++++++++++ 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/src/agents/mcp/manager.py b/src/agents/mcp/manager.py index 3f9f74cb9b..c579317926 100644 --- a/src/agents/mcp/manager.py +++ b/src/agents/mcp/manager.py @@ -193,11 +193,23 @@ async def connect_all(self) -> list[MCPServer]: self.errors = {} servers_to_connect = list(self._all_servers) - if self.connect_in_parallel: - await self._connect_all_parallel(servers_to_connect) - else: - for server in servers_to_connect: - await self._attempt_connect(server) + connected_servers: list[MCPServer] = [] + try: + if self.connect_in_parallel: + await self._connect_all_parallel(servers_to_connect) + else: + for server in servers_to_connect: + await self._attempt_connect(server) + if server not in self._failed_server_set: + connected_servers.append(server) + except Exception: + if self.connect_in_parallel: + connected_servers = [ + server for server in servers_to_connect if server not in self._failed_server_set + ] + await self._cleanup_connected_servers(connected_servers) + self._active_servers = [] + raise self._refresh_active_servers() @@ -295,6 +307,19 @@ async def _cleanup_server(self, server: MCPServer) -> None: return await self._run_with_timeout(server.cleanup, self.cleanup_timeout_seconds) + async def _cleanup_connected_servers(self, servers: Iterable[MCPServer]) -> None: + for server in reversed(list(servers)): + try: + await self._cleanup_server(server) + except asyncio.CancelledError as exc: + if not self.suppress_cancelled_error: + raise + logger.debug(f"Cleanup cancelled for MCP server '{server.name}': {exc}") + self.errors[server] = exc + except Exception as exc: + logger.exception(f"Failed to cleanup MCP server '{server.name}': {exc}") + self.errors[server] = exc + async def _connect_all_parallel(self, servers: list[MCPServer]) -> None: tasks = [ asyncio.create_task(self._attempt_connect(server, raise_on_error=False)) diff --git a/tests/mcp/test_mcp_server_manager.py b/tests/mcp/test_mcp_server_manager.py index 085425fa8c..e2ed78c865 100644 --- a/tests/mcp/test_mcp_server_manager.py +++ b/tests/mcp/test_mcp_server_manager.py @@ -163,3 +163,16 @@ async def test_manager_connect_all_retries_all_servers() -> None: assert server.connect_calls == 2 finally: await manager.cleanup_all() + + +@pytest.mark.asyncio +async def test_manager_strict_connect_cleans_up_connected_servers() -> None: + connected_server = TaskBoundServer() + failing_server = FlakyServer(failures=1) + manager = MCPServerManager([connected_server, failing_server], strict=True) + + with pytest.raises(RuntimeError, match="connect failed"): + await manager.connect_all() + + assert connected_server.cleaned is True + assert manager.active_servers == [] From db461ea8338691c18a3a3371da7775ebcdeb3822 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 22 Jan 2026 16:46:17 +0900 Subject: [PATCH 06/14] Clean up failed MCP connects in strict mode --- src/agents/mcp/manager.py | 3 ++- tests/mcp/test_mcp_server_manager.py | 32 ++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/agents/mcp/manager.py b/src/agents/mcp/manager.py index c579317926..87098df904 100644 --- a/src/agents/mcp/manager.py +++ b/src/agents/mcp/manager.py @@ -207,7 +207,8 @@ async def connect_all(self) -> list[MCPServer]: connected_servers = [ server for server in servers_to_connect if server not in self._failed_server_set ] - await self._cleanup_connected_servers(connected_servers) + servers_to_cleanup = self._unique_servers([*connected_servers, *self.failed_servers]) + await self._cleanup_connected_servers(servers_to_cleanup) self._active_servers = [] raise diff --git a/tests/mcp/test_mcp_server_manager.py b/tests/mcp/test_mcp_server_manager.py index e2ed78c865..75e8ac32eb 100644 --- a/tests/mcp/test_mcp_server_manager.py +++ b/tests/mcp/test_mcp_server_manager.py @@ -81,6 +81,16 @@ async def get_prompt( raise NotImplementedError +class FailingTaskBoundServer(TaskBoundServer): + @property + def name(self) -> str: + return "failing-task-bound" + + async def connect(self) -> None: + await super().connect() + raise RuntimeError("connect failed") + + @pytest.mark.asyncio async def test_manager_keeps_connect_and_cleanup_in_same_task() -> None: server = TaskBoundServer() @@ -176,3 +186,25 @@ async def test_manager_strict_connect_cleans_up_connected_servers() -> None: assert connected_server.cleaned is True assert manager.active_servers == [] + + +@pytest.mark.asyncio +async def test_manager_strict_connect_cleans_up_failed_server() -> None: + failing_server = FailingTaskBoundServer() + manager = MCPServerManager([failing_server], strict=True) + + with pytest.raises(RuntimeError, match="connect failed"): + await manager.connect_all() + + assert failing_server.cleaned is True + + +@pytest.mark.asyncio +async def test_manager_strict_connect_parallel_cleans_up_failed_server() -> None: + failing_server = FailingTaskBoundServer() + manager = MCPServerManager([failing_server], strict=True, connect_in_parallel=True) + + with pytest.raises(RuntimeError, match="connect failed"): + await manager.connect_all() + + assert failing_server.cleaned is True From e0ba8c3ad8261f94b626c01992fd3dfcdfb3582f Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 22 Jan 2026 17:05:06 +0900 Subject: [PATCH 07/14] Fix MCP server reconnect state handling --- src/agents/mcp/manager.py | 50 +++++++---- src/agents/mcp/server.py | 1 + tests/mcp/test_mcp_server_manager.py | 124 +++++++++++++++++++++++++++ 3 files changed, 159 insertions(+), 16 deletions(-) diff --git a/src/agents/mcp/manager.py b/src/agents/mcp/manager.py index 87098df904..9c8248462f 100644 --- a/src/agents/mcp/manager.py +++ b/src/agents/mcp/manager.py @@ -166,6 +166,7 @@ def __init__( self.failed_servers: list[MCPServer] = [] self._failed_server_set: set[MCPServer] = set() + self._connected_servers: set[MCPServer] = set() self.errors: dict[MCPServer, BaseException] = {} @property @@ -192,7 +193,7 @@ async def connect_all(self) -> list[MCPServer]: self._failed_server_set = set() self.errors = {} - servers_to_connect = list(self._all_servers) + servers_to_connect = self._servers_to_connect(self._all_servers) connected_servers: list[MCPServer] = [] try: if self.connect_in_parallel: @@ -204,11 +205,12 @@ async def connect_all(self) -> list[MCPServer]: connected_servers.append(server) except Exception: if self.connect_in_parallel: - connected_servers = [ - server for server in servers_to_connect if server not in self._failed_server_set - ] - servers_to_cleanup = self._unique_servers([*connected_servers, *self.failed_servers]) - await self._cleanup_connected_servers(servers_to_cleanup) + await self._cleanup_servers(servers_to_connect) + else: + servers_to_cleanup = self._unique_servers( + [*connected_servers, *self.failed_servers] + ) + await self._cleanup_servers(servers_to_cleanup) self._active_servers = [] raise @@ -221,23 +223,26 @@ async def reconnect(self, *, failed_only: bool = True) -> list[MCPServer]: Args: failed_only: If True, only retry servers that previously failed. - If False, retry all servers. + If False, cleanup and retry all servers. """ if failed_only: servers_to_retry = self._unique_servers(self.failed_servers) else: + await self.cleanup_all() servers_to_retry = list(self._all_servers) self.failed_servers = [] self._failed_server_set = set() self.errors = {} - if self.connect_in_parallel: - await self._connect_all_parallel(servers_to_retry) - else: - for server in servers_to_retry: - await self._attempt_connect(server) - - self._refresh_active_servers() + servers_to_retry = self._servers_to_connect(servers_to_retry) + try: + if self.connect_in_parallel: + await self._connect_all_parallel(servers_to_retry) + else: + for server in servers_to_retry: + await self._attempt_connect(server) + finally: + self._refresh_active_servers() return self._active_servers async def cleanup_all(self) -> None: @@ -266,6 +271,7 @@ async def _attempt_connect( raise_on_error = self.strict try: await self._run_connect(server) + self._connected_servers.add(server) if server in self.failed_servers: self._remove_failed_server(server) self.errors.pop(server, None) @@ -305,10 +311,12 @@ async def _cleanup_server(self, server: MCPServer) -> None: await worker.cleanup() if worker.is_done: self._workers.pop(server, None) + self._connected_servers.discard(server) return await self._run_with_timeout(server.cleanup, self.cleanup_timeout_seconds) + self._connected_servers.discard(server) - async def _cleanup_connected_servers(self, servers: Iterable[MCPServer]) -> None: + async def _cleanup_servers(self, servers: Iterable[MCPServer]) -> None: for server in reversed(list(servers)): try: await self._cleanup_server(server) @@ -326,7 +334,11 @@ async def _connect_all_parallel(self, servers: list[MCPServer]) -> None: asyncio.create_task(self._attempt_connect(server, raise_on_error=False)) for server in servers ] - await asyncio.gather(*tasks, return_exceptions=True) + results = await asyncio.gather(*tasks, return_exceptions=True) + if not self.suppress_cancelled_error: + for result in results: + if isinstance(result, asyncio.CancelledError): + raise result if self.strict and self.failed_servers: first_failure = self.failed_servers[0] error = self.errors.get(first_failure) @@ -352,6 +364,12 @@ def _remove_failed_server(self, server: MCPServer) -> None: failed_server for failed_server in self.failed_servers if failed_server != server ] + def _servers_to_connect(self, servers: Iterable[MCPServer]) -> list[MCPServer]: + unique = self._unique_servers(servers) + if not self._connected_servers: + return unique + return [server for server in unique if server not in self._connected_servers] + @staticmethod def _unique_servers(servers: Iterable[MCPServer]) -> list[MCPServer]: seen: set[MCPServer] = set() diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 1468cf7f8f..ad1371a244 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -453,6 +453,7 @@ async def cleanup(self): await self.exit_stack.aclose() except asyncio.CancelledError as e: logger.debug(f"Cleanup cancelled for MCP server '{self.name}': {e}") + raise except BaseExceptionGroup as eg: # Extract HTTP errors from ExceptionGroup raised during cleanup # This happens when background tasks fail (e.g., HTTP errors) diff --git a/tests/mcp/test_mcp_server_manager.py b/tests/mcp/test_mcp_server_manager.py index 75e8ac32eb..85a8e1f40d 100644 --- a/tests/mcp/test_mcp_server_manager.py +++ b/tests/mcp/test_mcp_server_manager.py @@ -81,6 +81,69 @@ async def get_prompt( raise NotImplementedError +class CleanupAwareServer(MCPServer): + def __init__(self) -> None: + super().__init__() + self.connect_calls = 0 + self.cleanup_calls = 0 + + @property + def name(self) -> str: + return "cleanup-aware" + + async def connect(self) -> None: + if self.connect_calls > self.cleanup_calls: + raise RuntimeError("connect called without cleanup") + self.connect_calls += 1 + + async def cleanup(self) -> None: + self.cleanup_calls += 1 + + async def list_tools( + self, run_context: RunContextWrapper[Any] | None = None, agent: Any | None = None + ) -> list[MCPTool]: + raise NotImplementedError + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: + raise NotImplementedError + + async def list_prompts(self) -> ListPromptsResult: + raise NotImplementedError + + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None + ) -> GetPromptResult: + raise NotImplementedError + + +class CancelledServer(MCPServer): + @property + def name(self) -> str: + return "cancelled" + + async def connect(self) -> None: + raise asyncio.CancelledError() + + async def cleanup(self) -> None: + return None + + async def list_tools( + self, run_context: RunContextWrapper[Any] | None = None, agent: Any | None = None + ) -> list[MCPTool]: + raise NotImplementedError + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: + raise NotImplementedError + + async def list_prompts(self) -> ListPromptsResult: + raise NotImplementedError + + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None + ) -> GetPromptResult: + raise NotImplementedError + + class FailingTaskBoundServer(TaskBoundServer): @property def name(self) -> str: @@ -175,6 +238,40 @@ async def test_manager_connect_all_retries_all_servers() -> None: await manager.cleanup_all() +@pytest.mark.asyncio +async def test_manager_connect_all_is_idempotent() -> None: + server = CleanupAwareServer() + + async with MCPServerManager([server]) as manager: + assert server.connect_calls == 1 + await manager.connect_all() + + +@pytest.mark.asyncio +async def test_manager_reconnect_all_avoids_duplicate_connections() -> None: + server = CleanupAwareServer() + + async with MCPServerManager([server]) as manager: + assert server.connect_calls == 1 + await manager.reconnect(failed_only=False) + + +@pytest.mark.asyncio +async def test_manager_strict_reconnect_refreshes_active_servers() -> None: + server_a = FlakyServer(failures=1) + server_b = FlakyServer(failures=2) + + async with MCPServerManager([server_a, server_b]) as manager: + assert manager.active_servers == [] + + manager.strict = True + with pytest.raises(RuntimeError, match="connect failed"): + await manager.reconnect() + + assert manager.active_servers == [server_a] + assert manager.failed_servers == [server_b] + + @pytest.mark.asyncio async def test_manager_strict_connect_cleans_up_connected_servers() -> None: connected_server = TaskBoundServer() @@ -208,3 +305,30 @@ async def test_manager_strict_connect_parallel_cleans_up_failed_server() -> None await manager.connect_all() assert failing_server.cleaned is True + + +@pytest.mark.asyncio +async def test_manager_strict_connect_parallel_cleans_up_workers() -> None: + connected_server = TaskBoundServer() + failing_server = FailingTaskBoundServer() + manager = MCPServerManager( + [connected_server, failing_server], strict=True, connect_in_parallel=True + ) + + with pytest.raises(RuntimeError, match="connect failed"): + await manager.connect_all() + + assert connected_server.cleaned is True + assert failing_server.cleaned is True + assert manager._workers == {} + + +@pytest.mark.asyncio +async def test_manager_parallel_propagates_cancelled_error_when_unsuppressed() -> None: + server = CancelledServer() + manager = MCPServerManager([server], connect_in_parallel=True, suppress_cancelled_error=False) + try: + with pytest.raises(asyncio.CancelledError): + await manager.connect_all() + finally: + await manager.cleanup_all() From 0097a9a442d763ee526357f33b1eede19c163596 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 22 Jan 2026 17:34:48 +0900 Subject: [PATCH 08/14] Preserve active servers on strict connect failures --- src/agents/mcp/manager.py | 29 +++++++++++++++++++------ tests/mcp/test_mcp_server_manager.py | 32 ++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/src/agents/mcp/manager.py b/src/agents/mcp/manager.py index 9c8248462f..4f1d46edc7 100644 --- a/src/agents/mcp/manager.py +++ b/src/agents/mcp/manager.py @@ -189,6 +189,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool | None: async def connect_all(self) -> list[MCPServer]: """Connect all servers in order and return the active list.""" + previous_connected_servers = set(self._connected_servers) + previous_active_servers = list(self._active_servers) self.failed_servers = [] self._failed_server_set = set() self.errors = {} @@ -211,7 +213,12 @@ async def connect_all(self) -> list[MCPServer]: [*connected_servers, *self.failed_servers] ) await self._cleanup_servers(servers_to_cleanup) - self._active_servers = [] + if self.drop_failed_servers: + self._active_servers = [ + server for server in self._all_servers if server in previous_connected_servers + ] + else: + self._active_servers = previous_active_servers raise self._refresh_active_servers() @@ -340,11 +347,21 @@ async def _connect_all_parallel(self, servers: list[MCPServer]) -> None: if isinstance(result, asyncio.CancelledError): raise result if self.strict and self.failed_servers: - first_failure = self.failed_servers[0] - error = self.errors.get(first_failure) - if error is not None: - raise error - raise RuntimeError(f"Failed to connect MCP server '{first_failure.name}'") + first_failure = None + if self.suppress_cancelled_error: + for server in self.failed_servers: + error = self.errors.get(server) + if error is None or isinstance(error, asyncio.CancelledError): + continue + first_failure = server + break + else: + first_failure = self.failed_servers[0] + if first_failure is not None: + error = self.errors.get(first_failure) + if error is not None: + raise error + raise RuntimeError(f"Failed to connect MCP server '{first_failure.name}'") def _get_worker(self, server: MCPServer) -> _ServerWorker: worker = self._workers.get(server) diff --git a/tests/mcp/test_mcp_server_manager.py b/tests/mcp/test_mcp_server_manager.py index 85a8e1f40d..9f337e9756 100644 --- a/tests/mcp/test_mcp_server_manager.py +++ b/tests/mcp/test_mcp_server_manager.py @@ -272,6 +272,26 @@ async def test_manager_strict_reconnect_refreshes_active_servers() -> None: assert manager.failed_servers == [server_b] +@pytest.mark.asyncio +async def test_manager_strict_connect_preserves_existing_active_servers() -> None: + connected_server = TaskBoundServer() + failing_server = FlakyServer(failures=2) + manager = MCPServerManager([connected_server, failing_server]) + try: + await manager.connect_all() + assert manager.active_servers == [connected_server] + assert manager.failed_servers == [failing_server] + + manager.strict = True + with pytest.raises(RuntimeError, match="connect failed"): + await manager.connect_all() + + assert manager.active_servers == [connected_server] + assert manager.failed_servers == [failing_server] + finally: + await manager.cleanup_all() + + @pytest.mark.asyncio async def test_manager_strict_connect_cleans_up_connected_servers() -> None: connected_server = TaskBoundServer() @@ -323,6 +343,18 @@ async def test_manager_strict_connect_parallel_cleans_up_workers() -> None: assert manager._workers == {} +@pytest.mark.asyncio +async def test_manager_parallel_suppresses_cancelled_error_in_strict_mode() -> None: + server = CancelledServer() + manager = MCPServerManager([server], connect_in_parallel=True, strict=True) + try: + await manager.connect_all() + assert manager.active_servers == [] + assert manager.failed_servers == [server] + finally: + await manager.cleanup_all() + + @pytest.mark.asyncio async def test_manager_parallel_propagates_cancelled_error_when_unsuppressed() -> None: server = CancelledServer() From cab79d173803de50dae39af3a81e5af86905ad37 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 22 Jan 2026 17:40:32 +0900 Subject: [PATCH 09/14] Clear MCP worker state on cleanup failure --- src/agents/mcp/manager.py | 16 ++++++++++------ tests/mcp/test_mcp_server_manager.py | 21 +++++++++++++++++++++ 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/src/agents/mcp/manager.py b/src/agents/mcp/manager.py index 4f1d46edc7..907c8294fc 100644 --- a/src/agents/mcp/manager.py +++ b/src/agents/mcp/manager.py @@ -315,13 +315,17 @@ async def _run_connect(self, server: MCPServer) -> None: async def _cleanup_server(self, server: MCPServer) -> None: if self.connect_in_parallel and server in self._workers: worker = self._workers[server] - await worker.cleanup() - if worker.is_done: - self._workers.pop(server, None) - self._connected_servers.discard(server) + try: + await worker.cleanup() + finally: + if worker.is_done: + self._workers.pop(server, None) + self._connected_servers.discard(server) return - await self._run_with_timeout(server.cleanup, self.cleanup_timeout_seconds) - self._connected_servers.discard(server) + try: + await self._run_with_timeout(server.cleanup, self.cleanup_timeout_seconds) + finally: + self._connected_servers.discard(server) async def _cleanup_servers(self, servers: Iterable[MCPServer]) -> None: for server in reversed(list(servers)): diff --git a/tests/mcp/test_mcp_server_manager.py b/tests/mcp/test_mcp_server_manager.py index 9f337e9756..1549b5f36c 100644 --- a/tests/mcp/test_mcp_server_manager.py +++ b/tests/mcp/test_mcp_server_manager.py @@ -154,6 +154,16 @@ async def connect(self) -> None: raise RuntimeError("connect failed") +class CleanupFailingServer(TaskBoundServer): + @property + def name(self) -> str: + return "cleanup-failing" + + async def cleanup(self) -> None: + await super().cleanup() + raise RuntimeError("cleanup failed") + + @pytest.mark.asyncio async def test_manager_keeps_connect_and_cleanup_in_same_task() -> None: server = TaskBoundServer() @@ -343,6 +353,17 @@ async def test_manager_strict_connect_parallel_cleans_up_workers() -> None: assert manager._workers == {} +@pytest.mark.asyncio +async def test_manager_parallel_cleanup_clears_worker_on_failure() -> None: + server = CleanupFailingServer() + manager = MCPServerManager([server], connect_in_parallel=True) + await manager.connect_all() + await manager.cleanup_all() + + assert server not in manager._workers + assert server not in manager._connected_servers + + @pytest.mark.asyncio async def test_manager_parallel_suppresses_cancelled_error_in_strict_mode() -> None: server = CancelledServer() From ad343e4bc821d66a81346436a5be7932d6f60968 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 22 Jan 2026 17:45:25 +0900 Subject: [PATCH 10/14] Drop stale MCP workers after cleanup errors --- src/agents/mcp/manager.py | 7 +++++-- tests/mcp/test_mcp_server_manager.py | 25 ++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/agents/mcp/manager.py b/src/agents/mcp/manager.py index 907c8294fc..2fd990b5c5 100644 --- a/src/agents/mcp/manager.py +++ b/src/agents/mcp/manager.py @@ -315,11 +315,14 @@ async def _run_connect(self, server: MCPServer) -> None: async def _cleanup_server(self, server: MCPServer) -> None: if self.connect_in_parallel and server in self._workers: worker = self._workers[server] + if worker.is_done: + self._workers.pop(server, None) + self._connected_servers.discard(server) + return try: await worker.cleanup() finally: - if worker.is_done: - self._workers.pop(server, None) + self._workers.pop(server, None) self._connected_servers.discard(server) return try: diff --git a/tests/mcp/test_mcp_server_manager.py b/tests/mcp/test_mcp_server_manager.py index 1549b5f36c..fab0d8cfdb 100644 --- a/tests/mcp/test_mcp_server_manager.py +++ b/tests/mcp/test_mcp_server_manager.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any +from typing import Any, cast import pytest from mcp.types import CallToolResult, GetPromptResult, ListPromptsResult, Tool as MCPTool @@ -364,6 +364,29 @@ async def test_manager_parallel_cleanup_clears_worker_on_failure() -> None: assert server not in manager._connected_servers +@pytest.mark.asyncio +async def test_manager_parallel_cleanup_drops_worker_after_error() -> None: + class HangingCleanupWorker: + def __init__(self) -> None: + self.cleanup_calls = 0 + + @property + def is_done(self) -> bool: + return False + + async def cleanup(self) -> None: + self.cleanup_calls += 1 + raise RuntimeError("cleanup failed") + + server = FlakyServer(failures=0) + manager = MCPServerManager([server], connect_in_parallel=True) + manager._workers[server] = cast(Any, HangingCleanupWorker()) + + await manager.cleanup_all() + + assert manager._workers == {} + + @pytest.mark.asyncio async def test_manager_parallel_suppresses_cancelled_error_in_strict_mode() -> None: server = CancelledServer() From 2899176df2e33d1fec495c6e09696610bcd970fc Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 22 Jan 2026 18:15:51 +0900 Subject: [PATCH 11/14] Ensure cleanup on connect cancellation --- src/agents/mcp/manager.py | 2 +- tests/mcp/test_mcp_server_manager.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/agents/mcp/manager.py b/src/agents/mcp/manager.py index 2fd990b5c5..d5ddf72e33 100644 --- a/src/agents/mcp/manager.py +++ b/src/agents/mcp/manager.py @@ -205,7 +205,7 @@ async def connect_all(self) -> list[MCPServer]: await self._attempt_connect(server) if server not in self._failed_server_set: connected_servers.append(server) - except Exception: + except BaseException: if self.connect_in_parallel: await self._cleanup_servers(servers_to_connect) else: diff --git a/tests/mcp/test_mcp_server_manager.py b/tests/mcp/test_mcp_server_manager.py index fab0d8cfdb..582e1f32df 100644 --- a/tests/mcp/test_mcp_server_manager.py +++ b/tests/mcp/test_mcp_server_manager.py @@ -408,3 +408,19 @@ async def test_manager_parallel_propagates_cancelled_error_when_unsuppressed() - await manager.connect_all() finally: await manager.cleanup_all() + + +@pytest.mark.asyncio +async def test_manager_cleanup_runs_on_cancelled_error_during_connect() -> None: + server = CleanupAwareServer() + cancelled_server = CancelledServer() + manager = MCPServerManager( + [server, cancelled_server], + suppress_cancelled_error=False, + ) + try: + with pytest.raises(asyncio.CancelledError): + await manager.connect_all() + assert server.cleanup_calls == 1 + finally: + await manager.cleanup_all() From e2c98362253a0ddd7d0d400e942a940c61c3efd9 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 22 Jan 2026 18:51:10 +0900 Subject: [PATCH 12/14] Propagate base exceptions from parallel connect --- src/agents/mcp/manager.py | 3 +++ tests/mcp/test_mcp_server_manager.py | 26 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/src/agents/mcp/manager.py b/src/agents/mcp/manager.py index d5ddf72e33..20722c5a24 100644 --- a/src/agents/mcp/manager.py +++ b/src/agents/mcp/manager.py @@ -349,6 +349,9 @@ async def _connect_all_parallel(self, servers: list[MCPServer]) -> None: for server in servers ] results = await asyncio.gather(*tasks, return_exceptions=True) + for result in results: + if isinstance(result, BaseException) and not isinstance(result, asyncio.CancelledError): + raise result if not self.suppress_cancelled_error: for result in results: if isinstance(result, asyncio.CancelledError): diff --git a/tests/mcp/test_mcp_server_manager.py b/tests/mcp/test_mcp_server_manager.py index 582e1f32df..f7e0cd4eb6 100644 --- a/tests/mcp/test_mcp_server_manager.py +++ b/tests/mcp/test_mcp_server_manager.py @@ -154,6 +154,20 @@ async def connect(self) -> None: raise RuntimeError("connect failed") +class FatalError(BaseException): + pass + + +class FatalTaskBoundServer(TaskBoundServer): + @property + def name(self) -> str: + return "fatal-task-bound" + + async def connect(self) -> None: + await super().connect() + raise FatalError("fatal connect failed") + + class CleanupFailingServer(TaskBoundServer): @property def name(self) -> str: @@ -410,6 +424,18 @@ async def test_manager_parallel_propagates_cancelled_error_when_unsuppressed() - await manager.cleanup_all() +@pytest.mark.asyncio +async def test_manager_parallel_propagates_base_exception() -> None: + server = FatalTaskBoundServer() + manager = MCPServerManager([server], connect_in_parallel=True) + + with pytest.raises(FatalError, match="fatal connect failed"): + await manager.connect_all() + + assert server.cleaned is True + assert manager._workers == {} + + @pytest.mark.asyncio async def test_manager_cleanup_runs_on_cancelled_error_during_connect() -> None: server = CleanupAwareServer() From f02ed967f61e487d68f77db0cd12e4f5d6b5825c Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 22 Jan 2026 19:03:41 +0900 Subject: [PATCH 13/14] Prefer cancellation over other errors --- src/agents/mcp/manager.py | 6 +++--- tests/mcp/test_mcp_server_manager.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/agents/mcp/manager.py b/src/agents/mcp/manager.py index 20722c5a24..96b2419d37 100644 --- a/src/agents/mcp/manager.py +++ b/src/agents/mcp/manager.py @@ -349,13 +349,13 @@ async def _connect_all_parallel(self, servers: list[MCPServer]) -> None: for server in servers ] results = await asyncio.gather(*tasks, return_exceptions=True) - for result in results: - if isinstance(result, BaseException) and not isinstance(result, asyncio.CancelledError): - raise result if not self.suppress_cancelled_error: for result in results: if isinstance(result, asyncio.CancelledError): raise result + for result in results: + if isinstance(result, BaseException) and not isinstance(result, asyncio.CancelledError): + raise result if self.strict and self.failed_servers: first_failure = None if self.suppress_cancelled_error: diff --git a/tests/mcp/test_mcp_server_manager.py b/tests/mcp/test_mcp_server_manager.py index f7e0cd4eb6..24b1a834c7 100644 --- a/tests/mcp/test_mcp_server_manager.py +++ b/tests/mcp/test_mcp_server_manager.py @@ -436,6 +436,22 @@ async def test_manager_parallel_propagates_base_exception() -> None: assert manager._workers == {} +@pytest.mark.asyncio +async def test_manager_parallel_prefers_cancelled_error_when_unsuppressed() -> None: + cancelled_server = CancelledServer() + fatal_server = FatalTaskBoundServer() + manager = MCPServerManager( + [fatal_server, cancelled_server], + connect_in_parallel=True, + suppress_cancelled_error=False, + ) + try: + with pytest.raises(asyncio.CancelledError): + await manager.connect_all() + finally: + await manager.cleanup_all() + + @pytest.mark.asyncio async def test_manager_cleanup_runs_on_cancelled_error_during_connect() -> None: server = CleanupAwareServer() From 08414794164902b12a9f62cb1284bccf8c098d64 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 22 Jan 2026 21:10:16 +0900 Subject: [PATCH 14/14] Record base exception failures in sequential connect --- src/agents/mcp/manager.py | 3 +++ tests/mcp/test_mcp_server_manager.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/src/agents/mcp/manager.py b/src/agents/mcp/manager.py index 96b2419d37..2c70d6f9dd 100644 --- a/src/agents/mcp/manager.py +++ b/src/agents/mcp/manager.py @@ -290,6 +290,9 @@ async def _attempt_connect( self._record_failure(server, exc, phase="connect") if raise_on_error: raise + except BaseException as exc: + self._record_failure(server, exc, phase="connect") + raise def _refresh_active_servers(self) -> None: if self.drop_failed_servers: diff --git a/tests/mcp/test_mcp_server_manager.py b/tests/mcp/test_mcp_server_manager.py index 24b1a834c7..b52aeb5311 100644 --- a/tests/mcp/test_mcp_server_manager.py +++ b/tests/mcp/test_mcp_server_manager.py @@ -424,6 +424,18 @@ async def test_manager_parallel_propagates_cancelled_error_when_unsuppressed() - await manager.cleanup_all() +@pytest.mark.asyncio +async def test_manager_sequential_propagates_base_exception() -> None: + server = FatalTaskBoundServer() + manager = MCPServerManager([server]) + + with pytest.raises(FatalError, match="fatal connect failed"): + await manager.connect_all() + + assert server.cleaned is True + assert manager.failed_servers == [server] + + @pytest.mark.asyncio async def test_manager_parallel_propagates_base_exception() -> None: server = FatalTaskBoundServer()