Skip to content

Commit 9e277d1

Browse files
committed
test: convert WebSocket tests to in-memory transport
test_ws.py used the subprocess + TCP port pattern that races under pytest-xdist: a worker allocates a port with socket.bind(0), releases it, then spawns a uvicorn subprocess hoping to rebind. Between release and rebind, another worker can claim the port, causing the WS client to connect to an unrelated server (observed: HTTP 403 Forbidden on the WebSocket upgrade). Three of the four tests here verify transport-agnostic MCP semantics (read_resource happy path, MCPError propagation, session recovery after client-side timeout). These now use the in-memory Client transport — no network, no subprocess, no race. The fourth test (test_ws_client_basic_connection) is kept as a smoke test running the real WS stack end-to-end. It uses a new run_uvicorn_in_thread helper that binds port=0 atomically and reads the actual port back from the server's socket — the OS holds the port from bind to shutdown, eliminating the race window entirely. This test alone provides 100% coverage of src/mcp/client/websocket.py. Also removed dead handler code (list_tools/call_tool were never exercised) and the no-longer-needed pragma: no cover annotations on the read_resource handler (it now runs in-process).
1 parent 62eb08e commit 9e277d1

File tree

2 files changed

+136
-159
lines changed

2 files changed

+136
-159
lines changed

tests/shared/test_ws.py

Lines changed: 70 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -1,210 +1,121 @@
1-
import multiprocessing
2-
import socket
3-
from collections.abc import AsyncGenerator, Generator
1+
"""Tests for the WebSocket transport.
2+
3+
The smoke test (``test_ws_client_basic_connection``) runs the full WS stack
4+
end-to-end over a real TCP connection and is what provides coverage of
5+
``src/mcp/client/websocket.py``.
6+
7+
The remaining tests verify transport-agnostic MCP semantics (error
8+
propagation, client-side timeouts) and use the in-memory ``Client`` transport
9+
to avoid the cost and flakiness of real network servers.
10+
"""
11+
12+
from collections.abc import Generator
413
from urllib.parse import urlparse
514

615
import anyio
716
import pytest
8-
import uvicorn
917
from starlette.applications import Starlette
1018
from starlette.routing import WebSocketRoute
1119
from starlette.websockets import WebSocket
1220

13-
from mcp import MCPError
21+
from mcp import Client, MCPError
1422
from mcp.client.session import ClientSession
1523
from mcp.client.websocket import websocket_client
1624
from mcp.server import Server, ServerRequestContext
1725
from mcp.server.websocket import websocket_server
1826
from mcp.types import (
19-
CallToolRequestParams,
20-
CallToolResult,
2127
EmptyResult,
2228
InitializeResult,
23-
ListToolsResult,
24-
PaginatedRequestParams,
2529
ReadResourceRequestParams,
2630
ReadResourceResult,
27-
TextContent,
2831
TextResourceContents,
29-
Tool,
3032
)
31-
from tests.test_helpers import wait_for_server
33+
from tests.test_helpers import run_uvicorn_in_thread
3234

3335
SERVER_NAME = "test_server_for_WS"
3436

37+
pytestmark = pytest.mark.anyio
3538

36-
@pytest.fixture
37-
def server_port() -> int:
38-
with socket.socket() as s:
39-
s.bind(("127.0.0.1", 0))
40-
return s.getsockname()[1]
4139

40+
# --- WebSocket transport smoke test (real TCP) -------------------------------
4241

43-
@pytest.fixture
44-
def server_url(server_port: int) -> str:
45-
return f"ws://127.0.0.1:{server_port}"
4642

43+
def make_server_app() -> Starlette:
44+
srv = Server(SERVER_NAME)
4745

48-
async def handle_read_resource( # pragma: no cover
49-
ctx: ServerRequestContext, params: ReadResourceRequestParams
50-
) -> ReadResourceResult:
51-
parsed = urlparse(str(params.uri))
52-
if parsed.scheme == "foobar":
53-
return ReadResourceResult(
54-
contents=[TextResourceContents(uri=str(params.uri), text=f"Read {parsed.netloc}", mime_type="text/plain")]
55-
)
56-
elif parsed.scheme == "slow":
57-
await anyio.sleep(2.0)
58-
return ReadResourceResult(
59-
contents=[
60-
TextResourceContents(
61-
uri=str(params.uri), text=f"Slow response from {parsed.netloc}", mime_type="text/plain"
62-
)
63-
]
64-
)
65-
raise MCPError(code=404, message="OOPS! no resource with that URI was found")
66-
67-
68-
async def handle_list_tools( # pragma: no cover
69-
ctx: ServerRequestContext, params: PaginatedRequestParams | None
70-
) -> ListToolsResult:
71-
return ListToolsResult(
72-
tools=[
73-
Tool(
74-
name="test_tool",
75-
description="A test tool",
76-
input_schema={"type": "object", "properties": {}},
77-
)
78-
]
79-
)
80-
81-
82-
async def handle_call_tool( # pragma: no cover
83-
ctx: ServerRequestContext, params: CallToolRequestParams
84-
) -> CallToolResult:
85-
return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")])
86-
87-
88-
def _create_server() -> Server: # pragma: no cover
89-
return Server(
90-
SERVER_NAME,
91-
on_read_resource=handle_read_resource,
92-
on_list_tools=handle_list_tools,
93-
on_call_tool=handle_call_tool,
94-
)
95-
96-
97-
# Test fixtures
98-
def make_server_app() -> Starlette: # pragma: no cover
99-
"""Create test Starlette app with WebSocket transport"""
100-
server = _create_server()
101-
102-
async def handle_ws(websocket: WebSocket):
46+
async def handle_ws(websocket: WebSocket) -> None:
10347
async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams:
104-
await server.run(streams[0], streams[1], server.create_initialization_options())
105-
106-
app = Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)])
107-
return app
108-
48+
await srv.run(streams[0], streams[1], srv.create_initialization_options())
10949

110-
def run_server(server_port: int) -> None: # pragma: no cover
111-
app = make_server_app()
112-
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
113-
print(f"starting server on {server_port}")
114-
server.run()
50+
return Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)])
11551

11652

117-
@pytest.fixture()
118-
def server(server_port: int) -> Generator[None, None, None]:
119-
proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
120-
print("starting process")
121-
proc.start()
122-
123-
# Wait for server to be running
124-
print("waiting for server to start")
125-
wait_for_server(server_port)
126-
127-
yield
128-
129-
print("killing server")
130-
# Signal the server to stop
131-
proc.kill()
132-
proc.join(timeout=2)
133-
if proc.is_alive(): # pragma: no cover
134-
print("server process failed to terminate")
53+
@pytest.fixture
54+
def ws_server_url() -> Generator[str, None, None]:
55+
with run_uvicorn_in_thread(make_server_app()) as base_url:
56+
yield base_url.replace("http://", "ws://") + "/ws"
13557

13658

137-
@pytest.fixture()
138-
async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]:
139-
"""Create and initialize a WebSocket client session"""
140-
async with websocket_client(server_url + "/ws") as streams:
59+
async def test_ws_client_basic_connection(ws_server_url: str) -> None:
60+
async with websocket_client(ws_server_url) as streams:
14161
async with ClientSession(*streams) as session:
142-
# Test initialization
14362
result = await session.initialize()
14463
assert isinstance(result, InitializeResult)
14564
assert result.server_info.name == SERVER_NAME
14665

147-
# Test ping
14866
ping_result = await session.send_ping()
14967
assert isinstance(ping_result, EmptyResult)
15068

151-
yield session
15269

70+
# --- In-memory tests (transport-agnostic MCP semantics) ----------------------
15371

154-
# Tests
155-
@pytest.mark.anyio
156-
async def test_ws_client_basic_connection(server: None, server_url: str) -> None:
157-
"""Test the WebSocket connection establishment"""
158-
async with websocket_client(server_url + "/ws") as streams:
159-
async with ClientSession(*streams) as session:
160-
# Test initialization
161-
result = await session.initialize()
162-
assert isinstance(result, InitializeResult)
163-
assert result.server_info.name == SERVER_NAME
16472

165-
# Test ping
166-
ping_result = await session.send_ping()
167-
assert isinstance(ping_result, EmptyResult)
73+
async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult:
74+
parsed = urlparse(str(params.uri))
75+
if parsed.scheme == "foobar":
76+
return ReadResourceResult(
77+
contents=[TextResourceContents(uri=str(params.uri), text=f"Read {parsed.netloc}", mime_type="text/plain")]
78+
)
79+
elif parsed.scheme == "slow":
80+
# Block indefinitely so the client-side fail_after() fires; the pending
81+
# server task is cancelled when the Client context manager exits.
82+
await anyio.sleep_forever()
83+
raise MCPError(code=404, message="OOPS! no resource with that URI was found")
84+
85+
86+
@pytest.fixture
87+
def server() -> Server:
88+
return Server(SERVER_NAME, on_read_resource=handle_read_resource)
16889

16990

170-
@pytest.mark.anyio
171-
async def test_ws_client_happy_request_and_response(
172-
initialized_ws_client_session: ClientSession,
173-
) -> None:
174-
"""Test a successful request and response via WebSocket"""
175-
result = await initialized_ws_client_session.read_resource("foobar://example")
176-
assert isinstance(result, ReadResourceResult)
177-
assert isinstance(result.contents, list)
178-
assert len(result.contents) > 0
179-
assert isinstance(result.contents[0], TextResourceContents)
180-
assert result.contents[0].text == "Read example"
181-
182-
183-
@pytest.mark.anyio
184-
async def test_ws_client_exception_handling(
185-
initialized_ws_client_session: ClientSession,
186-
) -> None:
187-
"""Test exception handling in WebSocket communication"""
188-
with pytest.raises(MCPError) as exc_info:
189-
await initialized_ws_client_session.read_resource("unknown://example")
190-
assert exc_info.value.error.code == 404
191-
192-
193-
@pytest.mark.anyio
194-
async def test_ws_client_timeout(
195-
initialized_ws_client_session: ClientSession,
196-
) -> None:
197-
"""Test timeout handling in WebSocket communication"""
198-
# Set a very short timeout to trigger a timeout exception
199-
with pytest.raises(TimeoutError):
200-
with anyio.fail_after(0.1): # 100ms timeout
201-
await initialized_ws_client_session.read_resource("slow://example")
202-
203-
# Now test that we can still use the session after a timeout
204-
with anyio.fail_after(5): # Longer timeout to allow completion
205-
result = await initialized_ws_client_session.read_resource("foobar://example")
91+
async def test_ws_client_happy_request_and_response(server: Server) -> None:
92+
async with Client(server) as client:
93+
result = await client.read_resource("foobar://example")
20694
assert isinstance(result, ReadResourceResult)
20795
assert isinstance(result.contents, list)
20896
assert len(result.contents) > 0
20997
assert isinstance(result.contents[0], TextResourceContents)
21098
assert result.contents[0].text == "Read example"
99+
100+
101+
async def test_ws_client_exception_handling(server: Server) -> None:
102+
async with Client(server) as client:
103+
with pytest.raises(MCPError) as exc_info:
104+
await client.read_resource("unknown://example")
105+
assert exc_info.value.error.code == 404
106+
107+
108+
async def test_ws_client_timeout(server: Server) -> None:
109+
async with Client(server) as client:
110+
with pytest.raises(TimeoutError):
111+
with anyio.fail_after(0.1):
112+
await client.read_resource("slow://example")
113+
114+
# Session remains usable after a client-side timeout abandons a request.
115+
with anyio.fail_after(5):
116+
result = await client.read_resource("foobar://example")
117+
assert isinstance(result, ReadResourceResult)
118+
assert isinstance(result.contents, list)
119+
assert len(result.contents) > 0
120+
assert isinstance(result.contents[0], TextResourceContents)
121+
assert result.contents[0].text == "Read example"

tests/test_helpers.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,73 @@
11
"""Common test utilities for MCP server tests."""
22

33
import socket
4+
import threading
45
import time
6+
from collections.abc import Generator
7+
from contextlib import contextmanager
8+
from typing import Any
9+
10+
import uvicorn
11+
12+
# How long to wait for the uvicorn server thread to reach `started`.
13+
# Generous to absorb CI scheduling delays — actual startup is typically <100ms.
14+
_SERVER_START_TIMEOUT_S = 20.0
15+
_SERVER_SHUTDOWN_TIMEOUT_S = 5.0
16+
17+
18+
@contextmanager
19+
def run_uvicorn_in_thread(app: Any, **config_kwargs: Any) -> Generator[str, None, None]:
20+
"""Run a uvicorn server in a background thread with an ephemeral port.
21+
22+
This eliminates the TOCTOU race that occurs when a test picks a free port
23+
with ``socket.bind((host, 0))``, releases it, then starts a server hoping
24+
to rebind the same port — between release and rebind, another pytest-xdist
25+
worker may claim it, causing connection errors or cross-test contamination.
26+
27+
With ``port=0``, the OS atomically assigns a free port at bind time; the
28+
server holds it from that moment until shutdown. We read the actual port
29+
back from uvicorn's bound socket after startup completes.
30+
31+
Args:
32+
app: ASGI application to serve.
33+
**config_kwargs: Additional keyword arguments for :class:`uvicorn.Config`
34+
(e.g. ``log_level``, ``limit_concurrency``). ``host`` defaults to
35+
``127.0.0.1`` and ``port`` is forced to 0.
36+
37+
Yields:
38+
The base URL of the running server, e.g. ``http://127.0.0.1:54321``.
39+
40+
Raises:
41+
TimeoutError: If the server does not start within 20 seconds.
42+
RuntimeError: If the server thread dies during startup.
43+
"""
44+
config_kwargs.setdefault("host", "127.0.0.1")
45+
config_kwargs.setdefault("log_level", "error")
46+
config = uvicorn.Config(app=app, port=0, **config_kwargs)
47+
server = uvicorn.Server(config=config)
48+
49+
thread = threading.Thread(target=server.run, daemon=True)
50+
thread.start()
51+
52+
# uvicorn sets `server.started = True` at the end of `Server.startup()`,
53+
# after sockets are bound and the lifespan startup phase has completed.
54+
start = time.monotonic()
55+
while not server.started:
56+
if time.monotonic() - start > _SERVER_START_TIMEOUT_S: # pragma: no cover
57+
raise TimeoutError(f"uvicorn server failed to start within {_SERVER_START_TIMEOUT_S}s")
58+
if not thread.is_alive(): # pragma: no cover
59+
raise RuntimeError("uvicorn server thread exited during startup")
60+
time.sleep(0.001)
61+
62+
# server.servers[0] is the asyncio.Server; its bound socket has the real port
63+
port = server.servers[0].sockets[0].getsockname()[1]
64+
host = config.host
65+
66+
try:
67+
yield f"http://{host}:{port}"
68+
finally:
69+
server.should_exit = True
70+
thread.join(timeout=_SERVER_SHUTDOWN_TIMEOUT_S)
571

672

773
def wait_for_server(port: int, timeout: float = 20.0) -> None:

0 commit comments

Comments
 (0)