|
1 | | -import multiprocessing |
2 | | -import socket |
3 | | -from collections.abc import AsyncGenerator, Generator |
| 1 | +"""Tests for the WebSocket transport. |
| 2 | +
|
| 3 | +The smoke test (``test_ws_client_basic_connection``) runs the full WS stack |
| 4 | +end-to-end over a real TCP connection and is what provides coverage of |
| 5 | +``src/mcp/client/websocket.py``. |
| 6 | +
|
| 7 | +The remaining tests verify transport-agnostic MCP semantics (error |
| 8 | +propagation, client-side timeouts) and use the in-memory ``Client`` transport |
| 9 | +to avoid the cost and flakiness of real network servers. |
| 10 | +""" |
| 11 | + |
| 12 | +from collections.abc import Generator |
4 | 13 | from urllib.parse import urlparse |
5 | 14 |
|
6 | 15 | import anyio |
7 | 16 | import pytest |
8 | | -import uvicorn |
9 | 17 | from starlette.applications import Starlette |
10 | 18 | from starlette.routing import WebSocketRoute |
11 | 19 | from starlette.websockets import WebSocket |
12 | 20 |
|
13 | | -from mcp import MCPError |
| 21 | +from mcp import Client, MCPError |
14 | 22 | from mcp.client.session import ClientSession |
15 | 23 | from mcp.client.websocket import websocket_client |
16 | 24 | from mcp.server import Server, ServerRequestContext |
17 | 25 | from mcp.server.websocket import websocket_server |
18 | 26 | from mcp.types import ( |
19 | | - CallToolRequestParams, |
20 | | - CallToolResult, |
21 | 27 | EmptyResult, |
22 | 28 | InitializeResult, |
23 | | - ListToolsResult, |
24 | | - PaginatedRequestParams, |
25 | 29 | ReadResourceRequestParams, |
26 | 30 | ReadResourceResult, |
27 | | - TextContent, |
28 | 31 | TextResourceContents, |
29 | | - Tool, |
30 | 32 | ) |
31 | | -from tests.test_helpers import wait_for_server |
| 33 | +from tests.test_helpers import run_uvicorn_in_thread |
32 | 34 |
|
33 | 35 | SERVER_NAME = "test_server_for_WS" |
34 | 36 |
|
| 37 | +pytestmark = pytest.mark.anyio |
35 | 38 |
|
36 | | -@pytest.fixture |
37 | | -def server_port() -> int: |
38 | | - with socket.socket() as s: |
39 | | - s.bind(("127.0.0.1", 0)) |
40 | | - return s.getsockname()[1] |
41 | 39 |
|
| 40 | +# --- WebSocket transport smoke test (real TCP) ------------------------------- |
42 | 41 |
|
43 | | -@pytest.fixture |
44 | | -def server_url(server_port: int) -> str: |
45 | | - return f"ws://127.0.0.1:{server_port}" |
46 | 42 |
|
| 43 | +def make_server_app() -> Starlette: |
| 44 | + srv = Server(SERVER_NAME) |
47 | 45 |
|
48 | | -async def handle_read_resource( # pragma: no cover |
49 | | - ctx: ServerRequestContext, params: ReadResourceRequestParams |
50 | | -) -> ReadResourceResult: |
51 | | - parsed = urlparse(str(params.uri)) |
52 | | - if parsed.scheme == "foobar": |
53 | | - return ReadResourceResult( |
54 | | - contents=[TextResourceContents(uri=str(params.uri), text=f"Read {parsed.netloc}", mime_type="text/plain")] |
55 | | - ) |
56 | | - elif parsed.scheme == "slow": |
57 | | - await anyio.sleep(2.0) |
58 | | - return ReadResourceResult( |
59 | | - contents=[ |
60 | | - TextResourceContents( |
61 | | - uri=str(params.uri), text=f"Slow response from {parsed.netloc}", mime_type="text/plain" |
62 | | - ) |
63 | | - ] |
64 | | - ) |
65 | | - raise MCPError(code=404, message="OOPS! no resource with that URI was found") |
66 | | - |
67 | | - |
68 | | -async def handle_list_tools( # pragma: no cover |
69 | | - ctx: ServerRequestContext, params: PaginatedRequestParams | None |
70 | | -) -> ListToolsResult: |
71 | | - return ListToolsResult( |
72 | | - tools=[ |
73 | | - Tool( |
74 | | - name="test_tool", |
75 | | - description="A test tool", |
76 | | - input_schema={"type": "object", "properties": {}}, |
77 | | - ) |
78 | | - ] |
79 | | - ) |
80 | | - |
81 | | - |
82 | | -async def handle_call_tool( # pragma: no cover |
83 | | - ctx: ServerRequestContext, params: CallToolRequestParams |
84 | | -) -> CallToolResult: |
85 | | - return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) |
86 | | - |
87 | | - |
88 | | -def _create_server() -> Server: # pragma: no cover |
89 | | - return Server( |
90 | | - SERVER_NAME, |
91 | | - on_read_resource=handle_read_resource, |
92 | | - on_list_tools=handle_list_tools, |
93 | | - on_call_tool=handle_call_tool, |
94 | | - ) |
95 | | - |
96 | | - |
97 | | -# Test fixtures |
98 | | -def make_server_app() -> Starlette: # pragma: no cover |
99 | | - """Create test Starlette app with WebSocket transport""" |
100 | | - server = _create_server() |
101 | | - |
102 | | - async def handle_ws(websocket: WebSocket): |
| 46 | + async def handle_ws(websocket: WebSocket) -> None: |
103 | 47 | async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams: |
104 | | - await server.run(streams[0], streams[1], server.create_initialization_options()) |
105 | | - |
106 | | - app = Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)]) |
107 | | - return app |
108 | | - |
| 48 | + await srv.run(streams[0], streams[1], srv.create_initialization_options()) |
109 | 49 |
|
110 | | -def run_server(server_port: int) -> None: # pragma: no cover |
111 | | - app = make_server_app() |
112 | | - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) |
113 | | - print(f"starting server on {server_port}") |
114 | | - server.run() |
| 50 | + return Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)]) |
115 | 51 |
|
116 | 52 |
|
117 | | -@pytest.fixture() |
118 | | -def server(server_port: int) -> Generator[None, None, None]: |
119 | | - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) |
120 | | - print("starting process") |
121 | | - proc.start() |
122 | | - |
123 | | - # Wait for server to be running |
124 | | - print("waiting for server to start") |
125 | | - wait_for_server(server_port) |
126 | | - |
127 | | - yield |
128 | | - |
129 | | - print("killing server") |
130 | | - # Signal the server to stop |
131 | | - proc.kill() |
132 | | - proc.join(timeout=2) |
133 | | - if proc.is_alive(): # pragma: no cover |
134 | | - print("server process failed to terminate") |
| 53 | +@pytest.fixture |
| 54 | +def ws_server_url() -> Generator[str, None, None]: |
| 55 | + with run_uvicorn_in_thread(make_server_app()) as base_url: |
| 56 | + yield base_url.replace("http://", "ws://") + "/ws" |
135 | 57 |
|
136 | 58 |
|
137 | | -@pytest.fixture() |
138 | | -async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: |
139 | | - """Create and initialize a WebSocket client session""" |
140 | | - async with websocket_client(server_url + "/ws") as streams: |
| 59 | +async def test_ws_client_basic_connection(ws_server_url: str) -> None: |
| 60 | + async with websocket_client(ws_server_url) as streams: |
141 | 61 | async with ClientSession(*streams) as session: |
142 | | - # Test initialization |
143 | 62 | result = await session.initialize() |
144 | 63 | assert isinstance(result, InitializeResult) |
145 | 64 | assert result.server_info.name == SERVER_NAME |
146 | 65 |
|
147 | | - # Test ping |
148 | 66 | ping_result = await session.send_ping() |
149 | 67 | assert isinstance(ping_result, EmptyResult) |
150 | 68 |
|
151 | | - yield session |
152 | 69 |
|
| 70 | +# --- In-memory tests (transport-agnostic MCP semantics) ---------------------- |
153 | 71 |
|
154 | | -# Tests |
155 | | -@pytest.mark.anyio |
156 | | -async def test_ws_client_basic_connection(server: None, server_url: str) -> None: |
157 | | - """Test the WebSocket connection establishment""" |
158 | | - async with websocket_client(server_url + "/ws") as streams: |
159 | | - async with ClientSession(*streams) as session: |
160 | | - # Test initialization |
161 | | - result = await session.initialize() |
162 | | - assert isinstance(result, InitializeResult) |
163 | | - assert result.server_info.name == SERVER_NAME |
164 | 72 |
|
165 | | - # Test ping |
166 | | - ping_result = await session.send_ping() |
167 | | - assert isinstance(ping_result, EmptyResult) |
| 73 | +async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: |
| 74 | + parsed = urlparse(str(params.uri)) |
| 75 | + if parsed.scheme == "foobar": |
| 76 | + return ReadResourceResult( |
| 77 | + contents=[TextResourceContents(uri=str(params.uri), text=f"Read {parsed.netloc}", mime_type="text/plain")] |
| 78 | + ) |
| 79 | + elif parsed.scheme == "slow": |
| 80 | + # Block indefinitely so the client-side fail_after() fires; the pending |
| 81 | + # server task is cancelled when the Client context manager exits. |
| 82 | + await anyio.sleep_forever() |
| 83 | + raise MCPError(code=404, message="OOPS! no resource with that URI was found") |
| 84 | + |
| 85 | + |
| 86 | +@pytest.fixture |
| 87 | +def server() -> Server: |
| 88 | + return Server(SERVER_NAME, on_read_resource=handle_read_resource) |
168 | 89 |
|
169 | 90 |
|
170 | | -@pytest.mark.anyio |
171 | | -async def test_ws_client_happy_request_and_response( |
172 | | - initialized_ws_client_session: ClientSession, |
173 | | -) -> None: |
174 | | - """Test a successful request and response via WebSocket""" |
175 | | - result = await initialized_ws_client_session.read_resource("foobar://example") |
176 | | - assert isinstance(result, ReadResourceResult) |
177 | | - assert isinstance(result.contents, list) |
178 | | - assert len(result.contents) > 0 |
179 | | - assert isinstance(result.contents[0], TextResourceContents) |
180 | | - assert result.contents[0].text == "Read example" |
181 | | - |
182 | | - |
183 | | -@pytest.mark.anyio |
184 | | -async def test_ws_client_exception_handling( |
185 | | - initialized_ws_client_session: ClientSession, |
186 | | -) -> None: |
187 | | - """Test exception handling in WebSocket communication""" |
188 | | - with pytest.raises(MCPError) as exc_info: |
189 | | - await initialized_ws_client_session.read_resource("unknown://example") |
190 | | - assert exc_info.value.error.code == 404 |
191 | | - |
192 | | - |
193 | | -@pytest.mark.anyio |
194 | | -async def test_ws_client_timeout( |
195 | | - initialized_ws_client_session: ClientSession, |
196 | | -) -> None: |
197 | | - """Test timeout handling in WebSocket communication""" |
198 | | - # Set a very short timeout to trigger a timeout exception |
199 | | - with pytest.raises(TimeoutError): |
200 | | - with anyio.fail_after(0.1): # 100ms timeout |
201 | | - await initialized_ws_client_session.read_resource("slow://example") |
202 | | - |
203 | | - # Now test that we can still use the session after a timeout |
204 | | - with anyio.fail_after(5): # Longer timeout to allow completion |
205 | | - result = await initialized_ws_client_session.read_resource("foobar://example") |
| 91 | +async def test_ws_client_happy_request_and_response(server: Server) -> None: |
| 92 | + async with Client(server) as client: |
| 93 | + result = await client.read_resource("foobar://example") |
206 | 94 | assert isinstance(result, ReadResourceResult) |
207 | 95 | assert isinstance(result.contents, list) |
208 | 96 | assert len(result.contents) > 0 |
209 | 97 | assert isinstance(result.contents[0], TextResourceContents) |
210 | 98 | assert result.contents[0].text == "Read example" |
| 99 | + |
| 100 | + |
| 101 | +async def test_ws_client_exception_handling(server: Server) -> None: |
| 102 | + async with Client(server) as client: |
| 103 | + with pytest.raises(MCPError) as exc_info: |
| 104 | + await client.read_resource("unknown://example") |
| 105 | + assert exc_info.value.error.code == 404 |
| 106 | + |
| 107 | + |
| 108 | +async def test_ws_client_timeout(server: Server) -> None: |
| 109 | + async with Client(server) as client: |
| 110 | + with pytest.raises(TimeoutError): |
| 111 | + with anyio.fail_after(0.1): |
| 112 | + await client.read_resource("slow://example") |
| 113 | + |
| 114 | + # Session remains usable after a client-side timeout abandons a request. |
| 115 | + with anyio.fail_after(5): |
| 116 | + result = await client.read_resource("foobar://example") |
| 117 | + assert isinstance(result, ReadResourceResult) |
| 118 | + assert isinstance(result.contents, list) |
| 119 | + assert len(result.contents) > 0 |
| 120 | + assert isinstance(result.contents[0], TextResourceContents) |
| 121 | + assert result.contents[0].text == "Read example" |
0 commit comments