-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Expand file tree
/
Copy pathtest_cancel_handling.py
More file actions
250 lines (207 loc) · 9.6 KB
/
test_cancel_handling.py
File metadata and controls
250 lines (207 loc) · 9.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
"""Test that cancelled requests don't cause double responses."""
import anyio
import pytest
from mcp import Client
from mcp.server import Server, ServerRequestContext
from mcp.shared.exceptions import MCPError
from mcp.shared.message import SessionMessage
from mcp.types import (
LATEST_PROTOCOL_VERSION,
CallToolRequest,
CallToolRequestParams,
CallToolResult,
CancelledNotification,
CancelledNotificationParams,
ClientCapabilities,
Implementation,
InitializeRequestParams,
JSONRPCNotification,
JSONRPCRequest,
ListToolsResult,
PaginatedRequestParams,
TextContent,
Tool,
)
@pytest.mark.anyio
async def test_server_remains_functional_after_cancel():
    """Verify server can handle new requests after a cancellation."""
    invocations = 0
    first_call_started = anyio.Event()
    slow_request_id = None

    async def list_tools_handler(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult:
        # Advertise the single tool the test exercises.
        return ListToolsResult(
            tools=[
                Tool(
                    name="test_tool",
                    description="Tool for testing",
                    input_schema={},
                )
            ]
        )

    async def call_tool_handler(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
        nonlocal invocations, slow_request_id
        if params.name == "test_tool":
            invocations += 1
            if invocations == 1:
                # Remember the request id so the client can cancel this call,
                # signal that the handler is running, then stall.
                slow_request_id = ctx.request_id
                first_call_started.set()
                await anyio.sleep(5)
            return CallToolResult(content=[TextContent(type="text", text=f"Call number: {invocations}")])
        raise ValueError(f"Unknown tool: {params.name}")  # pragma: no cover

    server = Server("test-server", on_list_tools=list_tools_handler, on_call_tool=call_tool_handler)

    async with Client(server) as client:
        async def issue_doomed_request():
            # This request is cancelled mid-flight; MCPError is the expected
            # outcome, anything else fails the test.
            try:
                await client.session.send_request(
                    CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})),
                    CallToolResult,
                )
                pytest.fail("First request should have been cancelled")  # pragma: no cover
            except MCPError:
                pass

        async with anyio.create_task_group() as tg:
            tg.start_soon(issue_doomed_request)
            # Wait until the handler has started and published its request id.
            await first_call_started.wait()
            assert slow_request_id is not None
            await client.session.send_notification(
                CancelledNotification(
                    params=CancelledNotificationParams(request_id=slow_request_id, reason="Testing server recovery"),
                )
            )
            # After the cancellation, a fresh request must still succeed.
            result = await client.call_tool("test_tool", {})
            assert len(result.content) == 1
            block = result.content[0]
            # Type narrowing for pyright.
            assert block.type == "text"
            assert isinstance(block, TextContent)
            assert block.text == "Call number: 2"
            assert invocations == 2
@pytest.mark.anyio
async def test_server_cancels_in_flight_handlers_on_transport_close():
    """When the transport closes mid-request, server.run() must cancel in-flight
    handlers rather than join on them.

    Without the cancel, the task group waits for the handler, which then tries
    to respond through a write stream that _receive_loop already closed,
    raising ClosedResourceError and crashing server.run() with exit code 1.

    This drives server.run() with raw memory streams because InMemoryTransport
    wraps it in its own finally-cancel (_memory.py) which masks the bug.
    """
    # Events observed by the test body to sequence against the server task.
    handler_started = anyio.Event()
    handler_cancelled = anyio.Event()
    server_run_returned = anyio.Event()

    async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
        # Park forever; the only way out is the cancellation under test,
        # which the finally clause records.
        handler_started.set()
        try:
            await anyio.sleep_forever()
        finally:
            handler_cancelled.set()
        # unreachable: sleep_forever only exits via cancellation
        raise AssertionError  # pragma: no cover

    server = Server("test", on_call_tool=handle_call_tool)

    # Raw memory streams standing in for the transport: to_server/server_read
    # is the server's inbound side, server_write/from_server the outbound.
    to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
    server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)

    async def run_server():
        # Flag normal return so the test can distinguish a clean exit
        # from a crash or a hang.
        await server.run(server_read, server_write, server.create_initialization_options())
        server_run_returned.set()

    # Hand-built JSON-RPC messages for the initialize handshake and one tool call.
    init_req = JSONRPCRequest(
        jsonrpc="2.0",
        id=1,
        method="initialize",
        params=InitializeRequestParams(
            protocol_version=LATEST_PROTOCOL_VERSION,
            capabilities=ClientCapabilities(),
            client_info=Implementation(name="test", version="1.0"),
        ).model_dump(by_alias=True, mode="json", exclude_none=True),
    )
    initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
    call_req = JSONRPCRequest(
        jsonrpc="2.0",
        id=2,
        method="tools/call",
        params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
    )

    # fail_after guards against the pre-fix behavior, where run_server never
    # returns and the test would hang.
    with anyio.fail_after(5):
        async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
            tg.start_soon(run_server)
            await to_server.send(SessionMessage(init_req))
            await from_server.receive()  # init response
            await to_server.send(SessionMessage(initialized))
            await to_server.send(SessionMessage(call_req))
            await handler_started.wait()
            # Close the server's input stream — this is what stdin EOF does.
            # server.run()'s incoming_messages loop ends, finally-cancel fires,
            # handler gets CancelledError, server.run() returns.
            await to_server.aclose()
            await server_run_returned.wait()
            assert handler_cancelled.is_set()
@pytest.mark.anyio
async def test_server_handles_transport_close_with_pending_server_to_client_requests():
    """When the transport closes while handlers are blocked on server→client
    requests (sampling, roots, elicitation), server.run() must still exit cleanly.

    Two bugs covered:
    1. _receive_loop's finally iterates _response_streams with await checkpoints
       inside; the woken handler's send_request finally pops from that dict
       before the next __next__() — RuntimeError: dictionary changed size.
    2. The woken handler's MCPError is caught in _handle_request, which falls
       through to respond() against a write stream _receive_loop already closed.
    """
    handlers_started = 0
    both_started = anyio.Event()
    server_run_returned = anyio.Event()

    async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
        # Count concurrent handler entries; the second one unblocks the test body.
        nonlocal handlers_started
        handlers_started += 1
        if handlers_started == 2:
            both_started.set()
        # Blocks on send_request waiting for a client response that never comes.
        # _receive_loop's finally will wake this with CONNECTION_CLOSED.
        await ctx.session.list_roots()
        raise AssertionError  # pragma: no cover

    server = Server("test", on_call_tool=handle_call_tool)

    # Raw memory streams standing in for the transport (see the sibling test
    # above for why InMemoryTransport cannot be used here).
    to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
    server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)

    async def run_server():
        # Flag a clean return so the test can tell exit from hang/crash.
        await server.run(server_read, server_write, server.create_initialization_options())
        server_run_returned.set()

    # Hand-built JSON-RPC initialize handshake.
    init_req = JSONRPCRequest(
        jsonrpc="2.0",
        id=1,
        method="initialize",
        params=InitializeRequestParams(
            protocol_version=LATEST_PROTOCOL_VERSION,
            capabilities=ClientCapabilities(),
            client_info=Implementation(name="test", version="1.0"),
        ).model_dump(by_alias=True, mode="json", exclude_none=True),
    )
    initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")

    # fail_after guards against the pre-fix behavior, where run_server never
    # returns and the test would hang.
    with anyio.fail_after(5):
        async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
            tg.start_soon(run_server)
            await to_server.send(SessionMessage(init_req))
            await from_server.receive()  # init response
            await to_server.send(SessionMessage(initialized))
            # Two tool calls → two handlers → two _response_streams entries.
            for rid in (2, 3):
                call_req = JSONRPCRequest(
                    jsonrpc="2.0",
                    id=rid,
                    method="tools/call",
                    params=CallToolRequestParams(name="t", arguments={}).model_dump(by_alias=True, mode="json"),
                )
                await to_server.send(SessionMessage(call_req))
            await both_started.wait()
            # Drain the two roots/list requests so send_request's _write_stream.send()
            # completes and both handlers are parked at response_stream_reader.receive().
            await from_server.receive()
            await from_server.receive()
            await to_server.aclose()
            # Without the fixes: RuntimeError (dict mutation) or ClosedResourceError
            # (respond after write-stream close) escapes run_server and this hangs.
            await server_run_returned.wait()