Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 120 additions & 4 deletions src/agents/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import asyncio
import inspect
import sys
from collections.abc import Awaitable
from collections.abc import AsyncGenerator, Awaitable
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, cast

import anyio
import httpx

if sys.version_info < (3, 11):
Expand All @@ -19,7 +20,11 @@
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
from mcp.client.session import MessageHandlerFnT
from mcp.client.sse import sse_client
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
from mcp.client.streamable_http import (
GetSessionIdCallback,
StreamableHTTPTransport,
streamablehttp_client,
)
from mcp.shared.exceptions import McpError
from mcp.shared.message import SessionMessage
from mcp.types import (
Expand Down Expand Up @@ -71,6 +76,101 @@ class RequireApprovalObject(TypedDict, total=False):
T = TypeVar("T")


def _create_default_streamable_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
kwargs: dict[str, Any] = {"follow_redirects": True}
if timeout is not None:
kwargs["timeout"] = timeout
if headers is not None:
kwargs["headers"] = headers
if auth is not None:
kwargs["auth"] = auth
return httpx.AsyncClient(**kwargs)


class _InitializedNotificationTolerantStreamableHTTPTransport(StreamableHTTPTransport):
async def _handle_post_request(self, ctx: Any) -> None:
message = ctx.session_message.message
if not self._is_initialized_notification(message):
await super()._handle_post_request(ctx)
return

try:
await super()._handle_post_request(ctx)
except httpx.HTTPError:
logger.warning(
"Ignoring initialized notification HTTP failure",
exc_info=True,
)
return


@asynccontextmanager
async def _streamablehttp_client_with_transport(
url: str,
*,
headers: dict[str, str] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
terminate_on_close: bool = True,
httpx_client_factory: HttpClientFactory = _create_default_streamable_http_client,
auth: httpx.Auth | None = None,
transport_factory: Callable[[str], StreamableHTTPTransport] = StreamableHTTPTransport,
) -> AsyncGenerator[MCPStreamTransport, None]:
timeout_seconds = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
sse_read_timeout_seconds = (
sse_read_timeout.total_seconds()
if isinstance(sse_read_timeout, timedelta)
else sse_read_timeout
)

client = httpx_client_factory(
headers=headers,
timeout=httpx.Timeout(timeout_seconds, read=sse_read_timeout_seconds),
auth=auth,
)
transport = transport_factory(url)
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](
0
)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

async with client:
async with anyio.create_task_group() as tg:
try:
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")

def start_get_stream() -> None:
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)

tg.start_soon(
transport.post_writer,
client,
write_stream_reader,
read_stream_writer,
write_stream,
start_get_stream,
tg,
)

try:
yield (
read_stream,
write_stream,
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()


class _SharedSessionRequestNeedsIsolation(Exception):
"""Raised when a shared-session request should be retried on an isolated session."""

Expand Down Expand Up @@ -1160,6 +1260,14 @@ class MCPServerStreamableHttpParams(TypedDict):
transport.
"""

ignore_initialized_notification_failure: NotRequired[bool]
"""Whether to ignore failures when sending the best-effort
``notifications/initialized`` POST.

Defaults to ``False``. When set to ``True``, initialized-notification failures are
logged and ignored so subsequent requests on the same transport can continue.
"""


class MCPServerStreamableHttp(_MCPServerWithClientSession):
"""MCP server implementation that uses the Streamable HTTP transport. See the [spec]
Expand Down Expand Up @@ -1250,8 +1358,16 @@ def create_streams(
"sse_read_timeout": self.params.get("sse_read_timeout", 60 * 5),
"terminate_on_close": self.params.get("terminate_on_close", True),
}
if "httpx_client_factory" in self.params:
kwargs["httpx_client_factory"] = self.params["httpx_client_factory"]
httpx_client_factory = self.params.get("httpx_client_factory")
if self.params.get("ignore_initialized_notification_failure", False):
return _streamablehttp_client_with_transport(
**kwargs,
httpx_client_factory=httpx_client_factory or _create_default_streamable_http_client,
auth=self.params.get("auth"),
transport_factory=_InitializedNotificationTolerantStreamableHTTPTransport,
)
if httpx_client_factory is not None:
kwargs["httpx_client_factory"] = httpx_client_factory
if "auth" in self.params:
kwargs["auth"] = self.params["auth"]
return streamablehttp_client(**kwargs)
Expand Down
193 changes: 193 additions & 0 deletions tests/mcp/test_streamable_http_client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,21 @@

from __future__ import annotations

import base64
from unittest.mock import MagicMock, patch

import httpx
import pytest
from anyio import create_memory_object_stream
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest

from agents.mcp import MCPServerStreamableHttp
from agents.mcp.server import (
_create_default_streamable_http_client,
_InitializedNotificationTolerantStreamableHTTPTransport,
_streamablehttp_client_with_transport,
)


class TestMCPServerStreamableHttpClientFactory:
Expand Down Expand Up @@ -247,3 +256,187 @@ def comprehensive_factory(
terminate_on_close=False,
httpx_client_factory=comprehensive_factory,
)


@pytest.mark.asyncio
async def test_initialized_notification_failure_returns_synthetic_success():
async def handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(503, request=request)

transport = _InitializedNotificationTolerantStreamableHTTPTransport("https://example.test/mcp")
read_stream_writer, _ = create_memory_object_stream[SessionMessage | Exception](0)
client = httpx.AsyncClient(transport=httpx.MockTransport(handler))
try:
ctx = MagicMock()
ctx.client = client
ctx.read_stream_writer = read_stream_writer
ctx.session_message = SessionMessage(
JSONRPCMessage(
JSONRPCNotification(
jsonrpc="2.0",
method="notifications/initialized",
params={},
)
)
)

await transport._handle_post_request(ctx)
finally:
await client.aclose()
await read_stream_writer.aclose()


@pytest.mark.asyncio
async def test_initialized_notification_transport_exception_returns_synthetic_success():
async def handler(request: httpx.Request) -> httpx.Response:
raise httpx.ConnectError("boom", request=request)

transport = _InitializedNotificationTolerantStreamableHTTPTransport("https://example.test/mcp")
read_stream_writer, _ = create_memory_object_stream[SessionMessage | Exception](0)
client = httpx.AsyncClient(transport=httpx.MockTransport(handler))
try:
ctx = MagicMock()
ctx.client = client
ctx.read_stream_writer = read_stream_writer
ctx.session_message = SessionMessage(
JSONRPCMessage(
JSONRPCNotification(
jsonrpc="2.0",
method="notifications/initialized",
params={},
)
)
)

await transport._handle_post_request(ctx)
finally:
await client.aclose()
await read_stream_writer.aclose()


@pytest.mark.asyncio
async def test_streamable_http_server_passes_ignore_initialized_notification_failure():
with patch("agents.mcp.server._streamablehttp_client_with_transport") as mock_client:
mock_client.return_value = MagicMock()

server = MCPServerStreamableHttp(
params={
"url": "http://localhost:8000/mcp",
"ignore_initialized_notification_failure": True,
}
)

server.create_streams()

kwargs = mock_client.call_args.kwargs
assert kwargs["url"] == "http://localhost:8000/mcp"
assert kwargs["headers"] is None
assert kwargs["timeout"] == 5
assert kwargs["sse_read_timeout"] == 300
assert kwargs["terminate_on_close"] is True
assert (
kwargs["transport_factory"] is _InitializedNotificationTolerantStreamableHTTPTransport
)


@pytest.mark.asyncio
async def test_transport_preserves_non_initialized_failures():
async def handler(request: httpx.Request) -> httpx.Response:
raise httpx.ConnectError("boom", request=request)

transport = _InitializedNotificationTolerantStreamableHTTPTransport("https://example.test/mcp")
read_stream_writer, _ = create_memory_object_stream[SessionMessage | Exception](0)
client = httpx.AsyncClient(transport=httpx.MockTransport(handler))
try:
ctx = MagicMock()
ctx.client = client
ctx.read_stream_writer = read_stream_writer
ctx.session_message = SessionMessage(
JSONRPCMessage(
JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="tools/list",
params={},
)
)
)

with pytest.raises(httpx.ConnectError):
await transport._handle_post_request(ctx)
finally:
await client.aclose()
await read_stream_writer.aclose()


@pytest.mark.asyncio
async def test_stream_client_preserves_custom_factory_headers_timeout_and_auth():
seen: dict[str, object] = {}

class RecordingAuth(httpx.Auth):
def auth_flow(self, request: httpx.Request):
request.headers["Authorization"] = f"Basic {base64.b64encode(b'user:pass').decode()}"
yield request

async def handler(request: httpx.Request) -> httpx.Response:
seen["request_headers"] = dict(request.headers)
return httpx.Response(200, request=request)

def base_factory(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
seen["factory_headers"] = headers
seen["factory_timeout"] = timeout
seen["factory_auth"] = auth
return httpx.AsyncClient(
headers=headers,
timeout=timeout,
auth=auth,
transport=httpx.MockTransport(handler),
)

timeout = httpx.Timeout(12.0)
auth = RecordingAuth()
async with _streamablehttp_client_with_transport(
"https://example.test/mcp",
headers={"X-Test": "value"},
timeout=12.0,
sse_read_timeout=30.0,
httpx_client_factory=base_factory,
auth=auth,
transport_factory=_InitializedNotificationTolerantStreamableHTTPTransport,
):
pass

assert seen["factory_headers"] == {"X-Test": "value"}
seen_timeout = seen["factory_timeout"]
assert isinstance(seen_timeout, httpx.Timeout)
assert seen_timeout.connect == timeout.connect
assert seen_timeout.read == 30.0
assert seen_timeout.write == timeout.write
assert seen_timeout.pool == timeout.pool
assert seen["factory_auth"] is auth


@pytest.mark.asyncio
async def test_default_streamable_http_client_matches_expected_defaults():
timeout = httpx.Timeout(12.0)
auth = httpx.BasicAuth("user", "pass")

client = _create_default_streamable_http_client(
headers={"X-Test": "value"},
timeout=timeout,
auth=auth,
)
try:
assert client.headers["X-Test"] == "value"
assert client.timeout.connect == timeout.connect
assert client.timeout.read == timeout.read
assert client.timeout.write == timeout.write
assert client.timeout.pool == timeout.pool
assert client.auth is auth
assert client.follow_redirects is True
finally:
await client.aclose()
Loading