diff --git a/examples/mcp/manager_example/README.md b/examples/mcp/manager_example/README.md new file mode 100644 index 0000000000..e465c3f8de --- /dev/null +++ b/examples/mcp/manager_example/README.md @@ -0,0 +1,71 @@ +# MCP Manager Example (FastAPI) + +This example shows how to use `MCPServerManager` to keep MCP server lifecycle +management in a single task inside a FastAPI app with the Streamable HTTP +transport. + +## Run the MCP server (Streamable HTTP) + +``` +uv run python examples/mcp/manager_example/mcp_server.py +``` + +The server listens at `http://localhost:8000/mcp` by default. + +You can override the host/port with: + +``` +export STREAMABLE_HTTP_HOST=127.0.0.1 +export STREAMABLE_HTTP_PORT=8000 +``` + +This example also configures an inactive MCP server at +`http://localhost:8001/mcp` to demonstrate how the manager drops failed +servers. You can override it with: + +``` +export INACTIVE_MCP_SERVER_URL=http://localhost:8001/mcp +``` + +## Run the FastAPI app + +``` +uv run python examples/mcp/manager_example/app.py +``` + +The app listens at `http://127.0.0.1:9001`. + +## Toggle MCP manager usage + +By default, the app uses `MCPServerManager`. To disable it: + +``` +export USE_MCP_MANAGER=0 +``` + +## Try the endpoints + +``` +curl http://127.0.0.1:9001/health +curl http://127.0.0.1:9001/tools +curl -X POST http://127.0.0.1:9001/add \ + -H 'Content-Type: application/json' \ + -d '{"a": 2, "b": 3}' +``` + +Reconnect failed MCP servers (manager must be enabled): + +``` +curl -X POST http://127.0.0.1:9001/reconnect \ + -H 'Content-Type: application/json' \ + -d '{"failed_only": true}' +``` + +To use `/run`, set `OPENAI_API_KEY`: + +``` +export OPENAI_API_KEY=... +curl -X POST http://127.0.0.1:9001/run \ + -H 'Content-Type: application/json' \ + -d '{"input": "Add 4 and 9."}' +``` diff --git a/examples/mcp/manager_example/app.py b/examples/mcp/manager_example/app.py new file mode 100644 index 0000000000..cae0eb7501 --- /dev/null +++ b/examples/mcp/manager_example/app.py @@ -0,0 +1,130 @@ +import os +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel + +from agents import Agent, Runner +from agents.mcp import MCPServer, MCPServerManager, MCPServerStreamableHttp +from agents.model_settings import ModelSettings + +MCP_SERVER_URL = os.getenv("MCP_SERVER_URL", "http://localhost:8000/mcp") +INACTIVE_MCP_SERVER_URL = os.getenv("INACTIVE_MCP_SERVER_URL", "http://localhost:8001/mcp") +APP_HOST = "127.0.0.1" +APP_PORT = 9001 +USE_MCP_MANAGER = os.getenv("USE_MCP_MANAGER", "1") != "0" + + +class AddRequest(BaseModel): + a: int + b: int + + +class RunRequest(BaseModel): + input: str + + +class ReconnectRequest(BaseModel): + failed_only: bool = True + + +@asynccontextmanager +async def lifespan(app: FastAPI): + server = MCPServerStreamableHttp({"url": MCP_SERVER_URL}) + inactive_server = MCPServerStreamableHttp({"url": INACTIVE_MCP_SERVER_URL}) + servers = [server, inactive_server] + if USE_MCP_MANAGER: + async with MCPServerManager( + servers=servers, + connect_in_parallel=True, + ) as manager: + app.state.mcp_manager = manager + app.state.mcp_servers = servers + yield + return + + await server.connect() + app.state.mcp_servers = servers + app.state.active_servers = [server] + try: + yield + finally: + await server.cleanup() + + +app = FastAPI(lifespan=lifespan) + + +@app.get("/health") +async def health() -> dict[str, object]: + if USE_MCP_MANAGER: + manager: MCPServerManager = app.state.mcp_manager + return { + "connected_servers": [server.name for server in manager.active_servers], + "failed_servers": [server.name for server in manager.failed_servers], + } + + active_servers = _get_active_servers() + return { + "connected_servers": [server.name for server in active_servers], + "failed_servers": [], + } + + +@app.get("/tools") +async def list_tools() -> dict[str, object]: + active_servers = _get_active_servers() + if not active_servers: + return {"tools": []} + tools = await active_servers[0].list_tools() + return {"tools": [tool.name for tool in tools]} + + +@app.post("/add") +async def add(req: AddRequest) -> dict[str, object]: + active_servers = _get_active_servers() + if not active_servers: + raise HTTPException(status_code=503, detail="No MCP servers available") + result = await active_servers[0].call_tool("add", {"a": req.a, "b": req.b}) + return {"result": result.model_dump(mode="json")} + + +@app.post("/run") +async def run_agent(req: RunRequest) -> dict[str, object]: + if not os.getenv("OPENAI_API_KEY"): + raise HTTPException(status_code=400, detail="OPENAI_API_KEY is required") + + servers = _get_active_servers() + if not servers: + raise HTTPException(status_code=503, detail="No MCP servers available") + + agent = Agent( + name="FastAPI Agent", + instructions="Use the MCP tools when needed.", + mcp_servers=servers, + model_settings=ModelSettings(tool_choice="auto"), + ) + result = await Runner.run(starting_agent=agent, input=req.input) + return {"output": result.final_output} + + +@app.post("/reconnect") +async def reconnect(req: ReconnectRequest) -> dict[str, object]: + if not USE_MCP_MANAGER: + raise HTTPException(status_code=400, detail="MCPServerManager is disabled") + manager: MCPServerManager = app.state.mcp_manager + servers = await manager.reconnect(failed_only=req.failed_only) + return {"connected_servers": [server.name for server in servers]} + + +def _get_active_servers() -> list[MCPServer]: + if USE_MCP_MANAGER: + manager: MCPServerManager = app.state.mcp_manager + return list(manager.active_servers) + return list(app.state.active_servers) + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host=APP_HOST, port=APP_PORT) diff --git a/examples/mcp/manager_example/mcp_server.py b/examples/mcp/manager_example/mcp_server.py new file mode 100644 index 0000000000..a67c224994 --- /dev/null +++ b/examples/mcp/manager_example/mcp_server.py @@ -0,0 +1,26 @@ +import os + +from mcp.server.fastmcp import FastMCP + +STREAMABLE_HTTP_HOST = os.getenv("STREAMABLE_HTTP_HOST", "127.0.0.1") +STREAMABLE_HTTP_PORT = int(os.getenv("STREAMABLE_HTTP_PORT", "8000")) + +mcp = FastMCP( + "FastAPI Example Server", + host=STREAMABLE_HTTP_HOST, + port=STREAMABLE_HTTP_PORT, +) + + +@mcp.tool() +def add(a: int, b: int) -> int: + return a + b + + +@mcp.tool() +def echo(message: str) -> str: + return f"echo: {message}" + + +if __name__ == "__main__": + mcp.run(transport="streamable-http") diff --git a/src/agents/agent.py b/src/agents/agent.py index d8c7d19e20..7beed4a8c3 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -118,7 +118,8 @@ class AgentBase(Generic[TContext]): NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no - longer needed. + longer needed. Consider using `MCPServerManager` from `agents.mcp` to keep connect/cleanup + in the same task. """ mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig()) diff --git a/src/agents/mcp/__init__.py b/src/agents/mcp/__init__.py index da5a68b16a..ed64a03bdc 100644 --- a/src/agents/mcp/__init__.py +++ b/src/agents/mcp/__init__.py @@ -1,4 +1,5 @@ try: + from .manager import MCPServerManager from .server import ( MCPServer, MCPServerSse, @@ -28,6 +29,7 @@ "MCPServerStdioParams", "MCPServerStreamableHttp", "MCPServerStreamableHttpParams", + "MCPServerManager", "MCPUtil", "ToolFilter", "ToolFilterCallable", diff --git a/src/agents/mcp/manager.py b/src/agents/mcp/manager.py new file mode 100644 index 0000000000..2c70d6f9dd --- /dev/null +++ b/src/agents/mcp/manager.py @@ -0,0 +1,411 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable, Iterable +from contextlib import AbstractAsyncContextManager +from dataclasses import dataclass +from typing import Any + +from ..logger import logger +from .server import MCPServer + + +@dataclass +class _ServerCommand: + action: str + timeout_seconds: float | None + future: asyncio.Future[None] + + +class _ServerWorker: + def __init__( + self, + server: MCPServer, + connect_timeout_seconds: float | None, + cleanup_timeout_seconds: float | None, + ) -> None: + self._server = server + self._connect_timeout_seconds = connect_timeout_seconds + self._cleanup_timeout_seconds = cleanup_timeout_seconds + self._queue: asyncio.Queue[_ServerCommand] = asyncio.Queue() + self._task = asyncio.create_task(self._run()) + + @property + def is_done(self) -> bool: + return self._task.done() + + async def connect(self) -> None: + await self._submit("connect", self._connect_timeout_seconds) + + async def cleanup(self) -> None: + await self._submit("cleanup", self._cleanup_timeout_seconds) + + async def _submit(self, action: str, timeout_seconds: float | None) -> None: + loop = asyncio.get_running_loop() + future: asyncio.Future[None] = loop.create_future() + await self._queue.put( + _ServerCommand(action=action, timeout_seconds=timeout_seconds, future=future) + ) + await future + + async def _run(self) -> None: + while True: + command = await self._queue.get() + should_exit = command.action == "cleanup" + try: + if command.action == "connect": + await _run_with_timeout_in_task(self._server.connect, command.timeout_seconds) + elif command.action == "cleanup": + await _run_with_timeout_in_task(self._server.cleanup, command.timeout_seconds) + else: + raise ValueError(f"Unknown command: {command.action}") + if not command.future.cancelled(): + command.future.set_result(None) + except BaseException as exc: + if not command.future.cancelled(): + command.future.set_exception(exc) + if should_exit: + return + + +async def _run_with_timeout_in_task( + func: Callable[[], Awaitable[Any]], timeout_seconds: float | None +) -> None: + # Use an in-task timeout to preserve task affinity for MCP cleanup. + # asyncio.wait_for creates a new Task on Python < 3.11, which breaks + # libraries that require connect/cleanup in the same task (e.g. AnyIO cancel scopes). + if timeout_seconds is None: + await func() + return + timeout_context = getattr(asyncio, "timeout", None) + if timeout_context is not None: + async with timeout_context(timeout_seconds): + await func() + return + task = asyncio.current_task() + if task is None: + await asyncio.wait_for(func(), timeout=timeout_seconds) + return + timed_out = False + loop = asyncio.get_running_loop() + + def _cancel() -> None: + nonlocal timed_out + timed_out = True + task.cancel() + + handle = loop.call_later(timeout_seconds, _cancel) + try: + await func() + except asyncio.CancelledError as exc: + if timed_out: + raise asyncio.TimeoutError() from exc + raise + finally: + handle.cancel() + + +class MCPServerManager(AbstractAsyncContextManager["MCPServerManager"]): + """Manage MCP server lifecycles and expose only connected servers. + + Use this helper to keep MCP connect/cleanup on the same task and avoid + run failures when a server is unavailable. The manager will attempt to + connect each server and then expose the connected subset via + `active_servers`. + + Basic usage: + async with MCPServerManager([server_a, server_b]) as manager: + agent = Agent( + name="Assistant", + instructions="...", + mcp_servers=manager.active_servers, + ) + + FastAPI lifespan example: + @asynccontextmanager + async def lifespan(app: FastAPI): + async with MCPServerManager([server_a, server_b]) as manager: + app.state.mcp_manager = manager + yield + + app = FastAPI(lifespan=lifespan) + + Important behaviors: + - `active_servers` only includes servers that connected successfully. + `failed_servers` holds the failures and `errors` maps servers to errors. + - `drop_failed_servers=True` removes failed servers from `active_servers` + (recommended). If False, `active_servers` will still include all servers. + - `strict=True` raises on the first connection failure. If False, failures + are recorded and the run can proceed with the remaining servers. + - `reconnect(failed_only=True)` retries failed servers and refreshes + `active_servers`. + - `connect_in_parallel=True` uses a dedicated worker task per server to + allow concurrent connects while preserving task affinity for cleanup. + """ + + def __init__( + self, + servers: Iterable[MCPServer], + *, + connect_timeout_seconds: float | None = 10.0, + cleanup_timeout_seconds: float | None = 10.0, + drop_failed_servers: bool = True, + strict: bool = False, + suppress_cancelled_error: bool = True, + connect_in_parallel: bool = False, + ) -> None: + self._all_servers = list(servers) + self._active_servers = list(servers) + self.connect_timeout_seconds = connect_timeout_seconds + self.cleanup_timeout_seconds = cleanup_timeout_seconds + self.drop_failed_servers = drop_failed_servers + self.strict = strict + self.suppress_cancelled_error = suppress_cancelled_error + self.connect_in_parallel = connect_in_parallel + self._workers: dict[MCPServer, _ServerWorker] = {} + + self.failed_servers: list[MCPServer] = [] + self._failed_server_set: set[MCPServer] = set() + self._connected_servers: set[MCPServer] = set() + self.errors: dict[MCPServer, BaseException] = {} + + @property + def active_servers(self) -> list[MCPServer]: + """Return the active MCP servers after connection attempts.""" + return list(self._active_servers) + + @property + def all_servers(self) -> list[MCPServer]: + """Return all MCP servers managed by this instance.""" + return list(self._all_servers) + + async def __aenter__(self) -> MCPServerManager: + await self.connect_all() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool | None: + await self.cleanup_all() + return None + + async def connect_all(self) -> list[MCPServer]: + """Connect all servers in order and return the active list.""" + previous_connected_servers = set(self._connected_servers) + previous_active_servers = list(self._active_servers) + self.failed_servers = [] + self._failed_server_set = set() + self.errors = {} + + servers_to_connect = self._servers_to_connect(self._all_servers) + connected_servers: list[MCPServer] = [] + try: + if self.connect_in_parallel: + await self._connect_all_parallel(servers_to_connect) + else: + for server in servers_to_connect: + await self._attempt_connect(server) + if server not in self._failed_server_set: + connected_servers.append(server) + except BaseException: + if self.connect_in_parallel: + await self._cleanup_servers(servers_to_connect) + else: + servers_to_cleanup = self._unique_servers( + [*connected_servers, *self.failed_servers] + ) + await self._cleanup_servers(servers_to_cleanup) + if self.drop_failed_servers: + self._active_servers = [ + server for server in self._all_servers if server in previous_connected_servers + ] + else: + self._active_servers = previous_active_servers + raise + + self._refresh_active_servers() + + return self._active_servers + + async def reconnect(self, *, failed_only: bool = True) -> list[MCPServer]: + """Reconnect servers and return the active list. + + Args: + failed_only: If True, only retry servers that previously failed. + If False, cleanup and retry all servers. + """ + if failed_only: + servers_to_retry = self._unique_servers(self.failed_servers) + else: + await self.cleanup_all() + servers_to_retry = list(self._all_servers) + self.failed_servers = [] + self._failed_server_set = set() + self.errors = {} + + servers_to_retry = self._servers_to_connect(servers_to_retry) + try: + if self.connect_in_parallel: + await self._connect_all_parallel(servers_to_retry) + else: + for server in servers_to_retry: + await self._attempt_connect(server) + finally: + self._refresh_active_servers() + return self._active_servers + + async def cleanup_all(self) -> None: + """Cleanup all servers in reverse order.""" + for server in reversed(self._all_servers): + try: + await self._cleanup_server(server) + except asyncio.CancelledError as exc: + if not self.suppress_cancelled_error: + raise + logger.debug(f"Cleanup cancelled for MCP server '{server.name}': {exc}") + self.errors[server] = exc + except Exception as exc: + logger.exception(f"Failed to cleanup MCP server '{server.name}': {exc}") + self.errors[server] = exc + + async def _run_with_timeout( + self, func: Callable[[], Awaitable[Any]], timeout_seconds: float | None + ) -> None: + await _run_with_timeout_in_task(func, timeout_seconds) + + async def _attempt_connect( + self, server: MCPServer, *, raise_on_error: bool | None = None + ) -> None: + if raise_on_error is None: + raise_on_error = self.strict + try: + await self._run_connect(server) + self._connected_servers.add(server) + if server in self.failed_servers: + self._remove_failed_server(server) + self.errors.pop(server, None) + except asyncio.CancelledError as exc: + if not self.suppress_cancelled_error: + raise + self._record_failure(server, exc, phase="connect") + except Exception as exc: + self._record_failure(server, exc, phase="connect") + if raise_on_error: + raise + except BaseException as exc: + self._record_failure(server, exc, phase="connect") + raise + + def _refresh_active_servers(self) -> None: + if self.drop_failed_servers: + failed = set(self._failed_server_set) + self._active_servers = [server for server in self._all_servers if server not in failed] + else: + self._active_servers = list(self._all_servers) + + def _record_failure(self, server: MCPServer, exc: BaseException, phase: str) -> None: + logger.exception(f"Failed to {phase} MCP server '{server.name}': {exc}") + if server not in self._failed_server_set: + self.failed_servers.append(server) + self._failed_server_set.add(server) + self.errors[server] = exc + + async def _run_connect(self, server: MCPServer) -> None: + if self.connect_in_parallel: + worker = self._get_worker(server) + await worker.connect() + else: + await self._run_with_timeout(server.connect, self.connect_timeout_seconds) + + async def _cleanup_server(self, server: MCPServer) -> None: + if self.connect_in_parallel and server in self._workers: + worker = self._workers[server] + if worker.is_done: + self._workers.pop(server, None) + self._connected_servers.discard(server) + return + try: + await worker.cleanup() + finally: + self._workers.pop(server, None) + self._connected_servers.discard(server) + return + try: + await self._run_with_timeout(server.cleanup, self.cleanup_timeout_seconds) + finally: + self._connected_servers.discard(server) + + async def _cleanup_servers(self, servers: Iterable[MCPServer]) -> None: + for server in reversed(list(servers)): + try: + await self._cleanup_server(server) + except asyncio.CancelledError as exc: + if not self.suppress_cancelled_error: + raise + logger.debug(f"Cleanup cancelled for MCP server '{server.name}': {exc}") + self.errors[server] = exc + except Exception as exc: + logger.exception(f"Failed to cleanup MCP server '{server.name}': {exc}") + self.errors[server] = exc + + async def _connect_all_parallel(self, servers: list[MCPServer]) -> None: + tasks = [ + asyncio.create_task(self._attempt_connect(server, raise_on_error=False)) + for server in servers + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + if not self.suppress_cancelled_error: + for result in results: + if isinstance(result, asyncio.CancelledError): + raise result + for result in results: + if isinstance(result, BaseException) and not isinstance(result, asyncio.CancelledError): + raise result + if self.strict and self.failed_servers: + first_failure = None + if self.suppress_cancelled_error: + for server in self.failed_servers: + error = self.errors.get(server) + if error is None or isinstance(error, asyncio.CancelledError): + continue + first_failure = server + break + else: + first_failure = self.failed_servers[0] + if first_failure is not None: + error = self.errors.get(first_failure) + if error is not None: + raise error + raise RuntimeError(f"Failed to connect MCP server '{first_failure.name}'") + + def _get_worker(self, server: MCPServer) -> _ServerWorker: + worker = self._workers.get(server) + if worker is None or worker.is_done: + worker = _ServerWorker( + server=server, + connect_timeout_seconds=self.connect_timeout_seconds, + cleanup_timeout_seconds=self.cleanup_timeout_seconds, + ) + self._workers[server] = worker + return worker + + def _remove_failed_server(self, server: MCPServer) -> None: + if server in self._failed_server_set: + self._failed_server_set.remove(server) + self.failed_servers = [ + failed_server for failed_server in self.failed_servers if failed_server != server + ] + + def _servers_to_connect(self, servers: Iterable[MCPServer]) -> list[MCPServer]: + unique = self._unique_servers(servers) + if not self._connected_servers: + return unique + return [server for server in unique if server not in self._connected_servers] + + @staticmethod + def _unique_servers(servers: Iterable[MCPServer]) -> list[MCPServer]: + seen: set[MCPServer] = set() + unique: list[MCPServer] = [] + for server in servers: + if server not in seen: + seen.add(server) + unique.append(server) + return unique diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 015b5b6f76..ad1371a244 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -451,6 +451,9 @@ async def cleanup(self): try: await self.exit_stack.aclose() + except asyncio.CancelledError as e: + logger.debug(f"Cleanup cancelled for MCP server '{self.name}': {e}") + raise except BaseExceptionGroup as eg: # Extract HTTP errors from ExceptionGroup raised during cleanup # This happens when background tasks fail (e.g., HTTP errors) diff --git a/tests/mcp/test_mcp_server_manager.py b/tests/mcp/test_mcp_server_manager.py new file mode 100644 index 0000000000..b52aeb5311 --- /dev/null +++ b/tests/mcp/test_mcp_server_manager.py @@ -0,0 +1,480 @@ +import asyncio +from typing import Any, cast + +import pytest +from mcp.types import CallToolResult, GetPromptResult, ListPromptsResult, Tool as MCPTool + +from agents.mcp import MCPServer, MCPServerManager +from agents.run_context import RunContextWrapper + + +class TaskBoundServer(MCPServer): + def __init__(self) -> None: + super().__init__() + self._connect_task: asyncio.Task[object] | None = None + self.cleaned = False + + @property + def name(self) -> str: + return "task-bound" + + async def connect(self) -> None: + self._connect_task = asyncio.current_task() + + async def cleanup(self) -> None: + if self._connect_task is None: + raise RuntimeError("Server was not connected") + if asyncio.current_task() is not self._connect_task: + raise RuntimeError("Attempted to exit cancel scope in a different task") + self.cleaned = True + + async def list_tools( + self, run_context: RunContextWrapper[Any] | None = None, agent: Any | None = None + ) -> list[MCPTool]: + raise NotImplementedError + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: + raise NotImplementedError + + async def list_prompts(self) -> ListPromptsResult: + raise NotImplementedError + + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None + ) -> GetPromptResult: + raise NotImplementedError + + +class FlakyServer(MCPServer): + def __init__(self, failures: int) -> None: + super().__init__() + self.failures_remaining = failures + self.connect_calls = 0 + + @property + def name(self) -> str: + return "flaky" + + async def connect(self) -> None: + self.connect_calls += 1 + if self.failures_remaining > 0: + self.failures_remaining -= 1 + raise RuntimeError("connect failed") + + async def cleanup(self) -> None: + return None + + async def list_tools( + self, run_context: RunContextWrapper[Any] | None = None, agent: Any | None = None + ) -> list[MCPTool]: + raise NotImplementedError + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: + raise NotImplementedError + + async def list_prompts(self) -> ListPromptsResult: + raise NotImplementedError + + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None + ) -> GetPromptResult: + raise NotImplementedError + + +class CleanupAwareServer(MCPServer): + def __init__(self) -> None: + super().__init__() + self.connect_calls = 0 + self.cleanup_calls = 0 + + @property + def name(self) -> str: + return "cleanup-aware" + + async def connect(self) -> None: + if self.connect_calls > self.cleanup_calls: + raise RuntimeError("connect called without cleanup") + self.connect_calls += 1 + + async def cleanup(self) -> None: + self.cleanup_calls += 1 + + async def list_tools( + self, run_context: RunContextWrapper[Any] | None = None, agent: Any | None = None + ) -> list[MCPTool]: + raise NotImplementedError + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: + raise NotImplementedError + + async def list_prompts(self) -> ListPromptsResult: + raise NotImplementedError + + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None + ) -> GetPromptResult: + raise NotImplementedError + + +class CancelledServer(MCPServer): + @property + def name(self) -> str: + return "cancelled" + + async def connect(self) -> None: + raise asyncio.CancelledError() + + async def cleanup(self) -> None: + return None + + async def list_tools( + self, run_context: RunContextWrapper[Any] | None = None, agent: Any | None = None + ) -> list[MCPTool]: + raise NotImplementedError + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: + raise NotImplementedError + + async def list_prompts(self) -> ListPromptsResult: + raise NotImplementedError + + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None + ) -> GetPromptResult: + raise NotImplementedError + + +class FailingTaskBoundServer(TaskBoundServer): + @property + def name(self) -> str: + return "failing-task-bound" + + async def connect(self) -> None: + await super().connect() + raise RuntimeError("connect failed") + + +class FatalError(BaseException): + pass + + +class FatalTaskBoundServer(TaskBoundServer): + @property + def name(self) -> str: + return "fatal-task-bound" + + async def connect(self) -> None: + await super().connect() + raise FatalError("fatal connect failed") + + +class CleanupFailingServer(TaskBoundServer): + @property + def name(self) -> str: + return "cleanup-failing" + + async def cleanup(self) -> None: + await super().cleanup() + raise RuntimeError("cleanup failed") + + +@pytest.mark.asyncio +async def test_manager_keeps_connect_and_cleanup_in_same_task() -> None: + server = TaskBoundServer() + + async with MCPServerManager([server]) as manager: + assert manager.active_servers == [server] + + assert server.cleaned is True + + +@pytest.mark.asyncio +async def test_manager_connects_in_worker_tasks_when_parallel() -> None: + server = TaskBoundServer() + + async with MCPServerManager([server], connect_in_parallel=True) as manager: + assert manager.active_servers == [server] + assert server._connect_task is not None + assert server._connect_task is not asyncio.current_task() + + assert server.cleaned is True + + +@pytest.mark.asyncio +async def test_cross_task_cleanup_raises_without_manager() -> None: + server = TaskBoundServer() + + connect_task = asyncio.create_task(server.connect()) + await connect_task + + with pytest.raises(RuntimeError, match="cancel scope"): + await server.cleanup() + + +@pytest.mark.asyncio +async def test_manager_reconnect_failed_only() -> None: + server = FlakyServer(failures=1) + + async with MCPServerManager([server]) as manager: + assert manager.active_servers == [] + assert manager.failed_servers == [server] + + await manager.reconnect() + assert manager.active_servers == [server] + assert manager.failed_servers == [] + + +@pytest.mark.asyncio +async def test_manager_reconnect_deduplicates_failures() -> None: + server = FlakyServer(failures=2) + + async with MCPServerManager([server], connect_in_parallel=True) as manager: + assert manager.active_servers == [] + assert manager.failed_servers == [server] + assert server.connect_calls == 1 + + await manager.reconnect() + assert manager.active_servers == [] + assert manager.failed_servers == [server] + assert server.connect_calls == 2 + + await manager.reconnect() + assert manager.active_servers == [server] + assert manager.failed_servers == [] + assert server.connect_calls == 3 + + +@pytest.mark.asyncio +async def test_manager_connect_all_retries_all_servers() -> None: + server = FlakyServer(failures=1) + manager = MCPServerManager([server]) + try: + await manager.connect_all() + assert manager.active_servers == [] + assert manager.failed_servers == [server] + assert server.connect_calls == 1 + + await manager.connect_all() + assert manager.active_servers == [server] + assert manager.failed_servers == [] + assert server.connect_calls == 2 + finally: + await manager.cleanup_all() + + +@pytest.mark.asyncio +async def test_manager_connect_all_is_idempotent() -> None: + server = CleanupAwareServer() + + async with MCPServerManager([server]) as manager: + assert server.connect_calls == 1 + await manager.connect_all() + + +@pytest.mark.asyncio +async def test_manager_reconnect_all_avoids_duplicate_connections() -> None: + server = CleanupAwareServer() + + async with MCPServerManager([server]) as manager: + assert server.connect_calls == 1 + await manager.reconnect(failed_only=False) + + +@pytest.mark.asyncio +async def test_manager_strict_reconnect_refreshes_active_servers() -> None: + server_a = FlakyServer(failures=1) + server_b = FlakyServer(failures=2) + + async with MCPServerManager([server_a, server_b]) as manager: + assert manager.active_servers == [] + + manager.strict = True + with pytest.raises(RuntimeError, match="connect failed"): + await manager.reconnect() + + assert manager.active_servers == [server_a] + assert manager.failed_servers == [server_b] + + +@pytest.mark.asyncio +async def test_manager_strict_connect_preserves_existing_active_servers() -> None: + connected_server = TaskBoundServer() + failing_server = FlakyServer(failures=2) + manager = MCPServerManager([connected_server, failing_server]) + try: + await manager.connect_all() + assert manager.active_servers == [connected_server] + assert manager.failed_servers == [failing_server] + + manager.strict = True + with pytest.raises(RuntimeError, match="connect failed"): + await manager.connect_all() + + assert manager.active_servers == [connected_server] + assert manager.failed_servers == [failing_server] + finally: + await manager.cleanup_all() + + +@pytest.mark.asyncio +async def test_manager_strict_connect_cleans_up_connected_servers() -> None: + connected_server = TaskBoundServer() + failing_server = FlakyServer(failures=1) + manager = MCPServerManager([connected_server, failing_server], strict=True) + + with pytest.raises(RuntimeError, match="connect failed"): + await manager.connect_all() + + assert connected_server.cleaned is True + assert manager.active_servers == [] + + +@pytest.mark.asyncio +async def test_manager_strict_connect_cleans_up_failed_server() -> None: + failing_server = FailingTaskBoundServer() + manager = MCPServerManager([failing_server], strict=True) + + with pytest.raises(RuntimeError, match="connect failed"): + await manager.connect_all() + + assert failing_server.cleaned is True + + +@pytest.mark.asyncio +async def test_manager_strict_connect_parallel_cleans_up_failed_server() -> None: + failing_server = FailingTaskBoundServer() + manager = MCPServerManager([failing_server], strict=True, connect_in_parallel=True) + + with pytest.raises(RuntimeError, match="connect failed"): + await manager.connect_all() + + assert failing_server.cleaned is True + + +@pytest.mark.asyncio +async def test_manager_strict_connect_parallel_cleans_up_workers() -> None: + connected_server = TaskBoundServer() + failing_server = FailingTaskBoundServer() + manager = MCPServerManager( + [connected_server, failing_server], strict=True, connect_in_parallel=True + ) + + with pytest.raises(RuntimeError, match="connect failed"): + await manager.connect_all() + + assert connected_server.cleaned is True + assert failing_server.cleaned is True + assert manager._workers == {} + + +@pytest.mark.asyncio +async def test_manager_parallel_cleanup_clears_worker_on_failure() -> None: + server = CleanupFailingServer() + manager = MCPServerManager([server], connect_in_parallel=True) + await manager.connect_all() + await manager.cleanup_all() + + assert server not in manager._workers + assert server not in manager._connected_servers + + +@pytest.mark.asyncio +async def test_manager_parallel_cleanup_drops_worker_after_error() -> None: + class HangingCleanupWorker: + def __init__(self) -> None: + self.cleanup_calls = 0 + + @property + def is_done(self) -> bool: + return False + + async def cleanup(self) -> None: + self.cleanup_calls += 1 + raise RuntimeError("cleanup failed") + + server = FlakyServer(failures=0) + manager = MCPServerManager([server], connect_in_parallel=True) + manager._workers[server] = cast(Any, HangingCleanupWorker()) + + await manager.cleanup_all() + + assert manager._workers == {} + + +@pytest.mark.asyncio +async def test_manager_parallel_suppresses_cancelled_error_in_strict_mode() -> None: + server = CancelledServer() + manager = MCPServerManager([server], connect_in_parallel=True, strict=True) + try: + await manager.connect_all() + assert manager.active_servers == [] + assert manager.failed_servers == [server] + finally: + await manager.cleanup_all() + + +@pytest.mark.asyncio +async def test_manager_parallel_propagates_cancelled_error_when_unsuppressed() -> None: + server = CancelledServer() + manager = MCPServerManager([server], connect_in_parallel=True, suppress_cancelled_error=False) + try: + with pytest.raises(asyncio.CancelledError): + await manager.connect_all() + finally: + await manager.cleanup_all() + + +@pytest.mark.asyncio +async def test_manager_sequential_propagates_base_exception() -> None: + server = FatalTaskBoundServer() + manager = MCPServerManager([server]) + + with pytest.raises(FatalError, match="fatal connect failed"): + await manager.connect_all() + + assert server.cleaned is True + assert manager.failed_servers == [server] + + +@pytest.mark.asyncio +async def test_manager_parallel_propagates_base_exception() -> None: + server = FatalTaskBoundServer() + manager = MCPServerManager([server], connect_in_parallel=True) + + with pytest.raises(FatalError, match="fatal connect failed"): + await manager.connect_all() + + assert server.cleaned is True + assert manager._workers == {} + + +@pytest.mark.asyncio +async def test_manager_parallel_prefers_cancelled_error_when_unsuppressed() -> None: + cancelled_server = CancelledServer() + fatal_server = FatalTaskBoundServer() + manager = MCPServerManager( + [fatal_server, cancelled_server], + connect_in_parallel=True, + suppress_cancelled_error=False, + ) + try: + with pytest.raises(asyncio.CancelledError): + await manager.connect_all() + finally: + await manager.cleanup_all() + + +@pytest.mark.asyncio +async def test_manager_cleanup_runs_on_cancelled_error_during_connect() -> None: + server = CleanupAwareServer() + cancelled_server = CancelledServer() + manager = MCPServerManager( + [server, cancelled_server], + suppress_cancelled_error=False, + ) + try: + with pytest.raises(asyncio.CancelledError): + await manager.connect_all() + assert server.cleanup_calls == 1 + finally: + await manager.cleanup_all()