diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 363d6995b4..abbb5e087f 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -4,12 +4,13 @@ import asyncio import inspect import sys -from collections.abc import Awaitable +from collections.abc import AsyncGenerator, Awaitable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from datetime import timedelta from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, cast +import anyio import httpx if sys.version_info < (3, 11): @@ -19,7 +20,11 @@ from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client from mcp.client.session import MessageHandlerFnT from mcp.client.sse import sse_client -from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client +from mcp.client.streamable_http import ( + GetSessionIdCallback, + StreamableHTTPTransport, + streamablehttp_client, +) from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage from mcp.types import ( @@ -71,6 +76,101 @@ class RequireApprovalObject(TypedDict, total=False): T = TypeVar("T") +def _create_default_streamable_http_client( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, +) -> httpx.AsyncClient: + kwargs: dict[str, Any] = {"follow_redirects": True} + if timeout is not None: + kwargs["timeout"] = timeout + if headers is not None: + kwargs["headers"] = headers + if auth is not None: + kwargs["auth"] = auth + return httpx.AsyncClient(**kwargs) + + +class _InitializedNotificationTolerantStreamableHTTPTransport(StreamableHTTPTransport): + async def _handle_post_request(self, ctx: Any) -> None: + message = ctx.session_message.message + if not self._is_initialized_notification(message): + await super()._handle_post_request(ctx) + return + + try: + await super()._handle_post_request(ctx) + except httpx.HTTPError: + logger.warning( + "Ignoring initialized notification HTTP failure", + exc_info=True, + ) + return + + +@asynccontextmanager +async def _streamablehttp_client_with_transport( + url: str, + *, + headers: dict[str, str] | None = None, + timeout: float | timedelta = 30, + sse_read_timeout: float | timedelta = 60 * 5, + terminate_on_close: bool = True, + httpx_client_factory: HttpClientFactory = _create_default_streamable_http_client, + auth: httpx.Auth | None = None, + transport_factory: Callable[[str], StreamableHTTPTransport] = StreamableHTTPTransport, +) -> AsyncGenerator[MCPStreamTransport, None]: + timeout_seconds = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + sse_read_timeout_seconds = ( + sse_read_timeout.total_seconds() + if isinstance(sse_read_timeout, timedelta) + else sse_read_timeout + ) + + client = httpx_client_factory( + headers=headers, + timeout=httpx.Timeout(timeout_seconds, read=sse_read_timeout_seconds), + auth=auth, + ) + transport = transport_factory(url) + read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception]( + 0 + ) + write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + + async with client: + async with anyio.create_task_group() as tg: + try: + logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") + + def start_get_stream() -> None: + tg.start_soon(transport.handle_get_stream, client, read_stream_writer) + + tg.start_soon( + transport.post_writer, + client, + write_stream_reader, + read_stream_writer, + write_stream, + start_get_stream, + tg, + ) + + try: + yield ( + read_stream, + write_stream, + transport.get_session_id, + ) + finally: + if transport.session_id and terminate_on_close: + await transport.terminate_session(client) + tg.cancel_scope.cancel() + finally: + await read_stream_writer.aclose() + await write_stream.aclose() + + class _SharedSessionRequestNeedsIsolation(Exception): """Raised when a shared-session request should be retried on an isolated session.""" @@ -1160,6 +1260,14 @@ class MCPServerStreamableHttpParams(TypedDict): transport. """ + ignore_initialized_notification_failure: NotRequired[bool] + """Whether to ignore failures when sending the best-effort + ``notifications/initialized`` POST. + + Defaults to ``False``. When set to ``True``, initialized-notification failures are + logged and ignored so subsequent requests on the same transport can continue. + """ + class MCPServerStreamableHttp(_MCPServerWithClientSession): """MCP server implementation that uses the Streamable HTTP transport. See the [spec] @@ -1250,8 +1358,16 @@ def create_streams( "sse_read_timeout": self.params.get("sse_read_timeout", 60 * 5), "terminate_on_close": self.params.get("terminate_on_close", True), } - if "httpx_client_factory" in self.params: - kwargs["httpx_client_factory"] = self.params["httpx_client_factory"] + httpx_client_factory = self.params.get("httpx_client_factory") + if self.params.get("ignore_initialized_notification_failure", False): + return _streamablehttp_client_with_transport( + **kwargs, + httpx_client_factory=httpx_client_factory or _create_default_streamable_http_client, + auth=self.params.get("auth"), + transport_factory=_InitializedNotificationTolerantStreamableHTTPTransport, + ) + if httpx_client_factory is not None: + kwargs["httpx_client_factory"] = httpx_client_factory if "auth" in self.params: kwargs["auth"] = self.params["auth"] return streamablehttp_client(**kwargs) diff --git a/tests/mcp/test_streamable_http_client_factory.py b/tests/mcp/test_streamable_http_client_factory.py index cf931a3011..068407a2fd 100644 --- a/tests/mcp/test_streamable_http_client_factory.py +++ b/tests/mcp/test_streamable_http_client_factory.py @@ -2,12 +2,21 @@ from __future__ import annotations +import base64 from unittest.mock import MagicMock, patch import httpx import pytest +from anyio import create_memory_object_stream +from mcp.shared.message import SessionMessage +from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest from agents.mcp import MCPServerStreamableHttp +from agents.mcp.server import ( + _create_default_streamable_http_client, + _InitializedNotificationTolerantStreamableHTTPTransport, + _streamablehttp_client_with_transport, +) class TestMCPServerStreamableHttpClientFactory: @@ -247,3 +256,187 @@ def comprehensive_factory( terminate_on_close=False, httpx_client_factory=comprehensive_factory, ) + + +@pytest.mark.asyncio +async def test_initialized_notification_failure_returns_synthetic_success(): + async def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(503, request=request) + + transport = _InitializedNotificationTolerantStreamableHTTPTransport("https://example.test/mcp") + read_stream_writer, _ = create_memory_object_stream[SessionMessage | Exception](0) + client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + try: + ctx = MagicMock() + ctx.client = client + ctx.read_stream_writer = read_stream_writer + ctx.session_message = SessionMessage( + JSONRPCMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + params={}, + ) + ) + ) + + await transport._handle_post_request(ctx) + finally: + await client.aclose() + await read_stream_writer.aclose() + + +@pytest.mark.asyncio +async def test_initialized_notification_transport_exception_returns_synthetic_success(): + async def handler(request: httpx.Request) -> httpx.Response: + raise httpx.ConnectError("boom", request=request) + + transport = _InitializedNotificationTolerantStreamableHTTPTransport("https://example.test/mcp") + read_stream_writer, _ = create_memory_object_stream[SessionMessage | Exception](0) + client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + try: + ctx = MagicMock() + ctx.client = client + ctx.read_stream_writer = read_stream_writer + ctx.session_message = SessionMessage( + JSONRPCMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + params={}, + ) + ) + ) + + await transport._handle_post_request(ctx) + finally: + await client.aclose() + await read_stream_writer.aclose() + + +@pytest.mark.asyncio +async def test_streamable_http_server_passes_ignore_initialized_notification_failure(): + with patch("agents.mcp.server._streamablehttp_client_with_transport") as mock_client: + mock_client.return_value = MagicMock() + + server = MCPServerStreamableHttp( + params={ + "url": "http://localhost:8000/mcp", + "ignore_initialized_notification_failure": True, + } + ) + + server.create_streams() + + kwargs = mock_client.call_args.kwargs + assert kwargs["url"] == "http://localhost:8000/mcp" + assert kwargs["headers"] is None + assert kwargs["timeout"] == 5 + assert kwargs["sse_read_timeout"] == 300 + assert kwargs["terminate_on_close"] is True + assert ( + kwargs["transport_factory"] is _InitializedNotificationTolerantStreamableHTTPTransport + ) + + +@pytest.mark.asyncio +async def test_transport_preserves_non_initialized_failures(): + async def handler(request: httpx.Request) -> httpx.Response: + raise httpx.ConnectError("boom", request=request) + + transport = _InitializedNotificationTolerantStreamableHTTPTransport("https://example.test/mcp") + read_stream_writer, _ = create_memory_object_stream[SessionMessage | Exception](0) + client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + try: + ctx = MagicMock() + ctx.client = client + ctx.read_stream_writer = read_stream_writer + ctx.session_message = SessionMessage( + JSONRPCMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="tools/list", + params={}, + ) + ) + ) + + with pytest.raises(httpx.ConnectError): + await transport._handle_post_request(ctx) + finally: + await client.aclose() + await read_stream_writer.aclose() + + +@pytest.mark.asyncio +async def test_stream_client_preserves_custom_factory_headers_timeout_and_auth(): + seen: dict[str, object] = {} + + class RecordingAuth(httpx.Auth): + def auth_flow(self, request: httpx.Request): + request.headers["Authorization"] = f"Basic {base64.b64encode(b'user:pass').decode()}" + yield request + + async def handler(request: httpx.Request) -> httpx.Response: + seen["request_headers"] = dict(request.headers) + return httpx.Response(200, request=request) + + def base_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + seen["factory_headers"] = headers + seen["factory_timeout"] = timeout + seen["factory_auth"] = auth + return httpx.AsyncClient( + headers=headers, + timeout=timeout, + auth=auth, + transport=httpx.MockTransport(handler), + ) + + timeout = httpx.Timeout(12.0) + auth = RecordingAuth() + async with _streamablehttp_client_with_transport( + "https://example.test/mcp", + headers={"X-Test": "value"}, + timeout=12.0, + sse_read_timeout=30.0, + httpx_client_factory=base_factory, + auth=auth, + transport_factory=_InitializedNotificationTolerantStreamableHTTPTransport, + ): + pass + + assert seen["factory_headers"] == {"X-Test": "value"} + seen_timeout = seen["factory_timeout"] + assert isinstance(seen_timeout, httpx.Timeout) + assert seen_timeout.connect == timeout.connect + assert seen_timeout.read == 30.0 + assert seen_timeout.write == timeout.write + assert seen_timeout.pool == timeout.pool + assert seen["factory_auth"] is auth + + +@pytest.mark.asyncio +async def test_default_streamable_http_client_matches_expected_defaults(): + timeout = httpx.Timeout(12.0) + auth = httpx.BasicAuth("user", "pass") + + client = _create_default_streamable_http_client( + headers={"X-Test": "value"}, + timeout=timeout, + auth=auth, + ) + try: + assert client.headers["X-Test"] == "value" + assert client.timeout.connect == timeout.connect + assert client.timeout.read == timeout.read + assert client.timeout.write == timeout.write + assert client.timeout.pool == timeout.pool + assert client.auth is auth + assert client.follow_redirects is True + finally: + await client.aclose()