-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathtest_streamable_http_manager.py
More file actions
262 lines (194 loc) · 8.78 KB
/
test_streamable_http_manager.py
File metadata and controls
262 lines (194 loc) · 8.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
"""Tests for StreamableHTTPSessionManager."""
from unittest.mock import AsyncMock, patch
import anyio
import pytest
from mcp.server import streamable_http_manager
from mcp.server.lowlevel import Server
from mcp.server.streamable_http import MCP_SESSION_ID_HEADER
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
@pytest.mark.anyio
async def test_run_can_only_be_called_once():
    """A manager instance must refuse a second ``run()`` after the first one exits."""
    manager = StreamableHTTPSessionManager(app=Server("test-server"))

    # The first lifecycle is allowed to start and finish normally.
    async with manager.run():
        pass

    # Re-entering run() on the same instance is rejected.
    with pytest.raises(RuntimeError) as exc_info:
        async with manager.run():
            pass

    expected = "StreamableHTTPSessionManager .run() can only be called once per instance"
    assert expected in str(exc_info.value)
@pytest.mark.anyio
async def test_run_prevents_concurrent_calls():
    """Two overlapping ``run()`` attempts: exactly one succeeds, the other raises."""
    server = Server("test-server")
    manager = StreamableHTTPSessionManager(app=server)
    failures = []

    async def attempt_run():
        try:
            async with manager.run():
                # Keep the manager open briefly so the two attempts overlap.
                await anyio.sleep(0.1)
        except RuntimeError as exc:
            failures.append(exc)

    # Launch both attempts in one task group so they run concurrently.
    async with anyio.create_task_group() as task_group:
        for _ in range(2):
            task_group.start_soon(attempt_run)

    assert len(failures) == 1
    assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(failures[0])
@pytest.mark.anyio
async def test_handle_request_without_run_raises_error():
    """``handle_request`` must raise when ``run()`` has never been entered."""
    manager = StreamableHTTPSessionManager(app=Server("test-server"))

    # Minimal ASGI (scope, receive, send) triple for a POST request.
    scope = {"type": "http", "method": "POST", "path": "/test"}

    async def receive():
        return {"type": "http.request", "body": b""}

    async def send(message):
        pass

    # Without run() the manager's task group was never created.
    with pytest.raises(RuntimeError) as exc_info:
        await manager.handle_request(scope, receive, send)
    assert "Task group is not initialized. Make sure to use run()." in str(exc_info.value)
class TestException(Exception):
    """Dedicated exception type tests raise from a mocked ``app.run`` to simulate a crash."""

    __test__ = False  # Prevent pytest from collecting this as a test class
@pytest.fixture
async def running_manager():
    """Yield a ``(manager, app)`` pair whose manager is already inside ``run()``.

    The yielded ``app`` is the very ``Server`` instance the manager was built
    with. Tests can therefore replace ``app.run`` and the manager will use the
    replacement.
    """
    server = Server("test-cleanup-server")
    session_manager = StreamableHTTPSessionManager(app=server)
    async with session_manager.run():
        yield session_manager, server
def _extract_session_id(sent_messages):
    """Return the MCP session ID announced in the captured response headers.

    Scans the recorded ASGI messages in order. It returns the decoded value of
    the first header whose name matches ``MCP_SESSION_ID_HEADER``
    (case-insensitively) on an ``http.response.start`` message. It returns
    ``None`` if no such header was sent.
    """
    wanted = MCP_SESSION_ID_HEADER.lower()
    for message in sent_messages:
        if message["type"] != "http.response.start":
            continue
        for header_name, header_value in message.get("headers", []):
            if header_name.decode().lower() == wanted:
                return header_value.decode()
    return None


@pytest.mark.anyio
async def test_stateful_session_cleanup_on_graceful_exit(running_manager):
    """A stateful session is dropped from the manager once ``app.run`` returns normally."""
    manager, app = running_manager

    # This will be called by StreamableHTTPSessionManager's run_server -> self.app.run
    mock_mcp_run = AsyncMock(return_value=None)
    app.run = mock_mcp_run

    sent_messages = []

    async def mock_send(message):
        sent_messages.append(message)

    scope = {
        "type": "http",
        "method": "POST",
        "path": "/mcp",
        "headers": [(b"content-type", b"application/json")],
    }

    async def mock_receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    # Trigger session creation
    await manager.handle_request(scope, mock_receive, mock_send)

    session_id = _extract_session_id(sent_messages)
    assert session_id is not None, "Session ID not found in response headers"

    # Ensure MCPServer.run was called
    mock_mcp_run.assert_called_once()

    # Yield to the event loop so the spawned server task can finish and run
    # the cleanup in StreamableHTTPSessionManager's run_server finally block.
    await anyio.sleep(0.01)

    assert session_id not in manager._server_instances, (
        "Session ID should be removed from _server_instances after graceful exit"
    )
    assert not manager._server_instances, "No sessions should be tracked after the only session exits gracefully"


@pytest.mark.anyio
async def test_stateful_session_cleanup_on_exception(running_manager):
    """A stateful session is dropped from the manager even when ``app.run`` raises."""
    manager, app = running_manager
    mock_mcp_run = AsyncMock(side_effect=TestException("Simulated crash"))
    app.run = mock_mcp_run

    sent_messages = []

    async def mock_send(message):
        # If the simulated crash propagates up the transport, an error
        # response (status >= 500) may be sent. It is recorded like any other
        # message, because this test only needs the session ID header.
        sent_messages.append(message)

    scope = {
        "type": "http",
        "method": "POST",
        "path": "/mcp",
        "headers": [(b"content-type", b"application/json")],
    }

    async def mock_receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    # Trigger session creation
    await manager.handle_request(scope, mock_receive, mock_send)

    session_id = _extract_session_id(sent_messages)
    assert session_id is not None, "Session ID not found in response headers"
    mock_mcp_run.assert_called_once()

    # Give other tasks a chance to run to ensure the finally block executes
    await anyio.sleep(0.01)

    assert session_id not in manager._server_instances, (
        "Session ID should be removed from _server_instances after an exception"
    )
    assert not manager._server_instances, "No sessions should be tracked after the only session crashes"
@pytest.mark.anyio
async def test_stateless_requests_memory_cleanup():
    """Test that stateless requests actually clean up resources using real transports."""
    app = Server("test-stateless-real-cleanup")
    manager = StreamableHTTPSessionManager(app=app, stateless=True)

    # Every transport the manager constructs during the test, in creation order.
    created_transports = []
    real_transport_cls = streamable_http_manager.StreamableHTTPServerTransport

    def record_transport(*args, **kwargs):
        # Build a genuine transport, but keep a handle so its state can be inspected.
        instance = real_transport_cls(*args, **kwargs)
        created_transports.append(instance)
        return instance

    sent_messages = []

    async def mock_send(message):
        sent_messages.append(message)

    scope = {
        "type": "http",
        "method": "POST",
        "path": "/mcp",
        "headers": [
            (b"content-type", b"application/json"),
            (b"accept", b"application/json, text/event-stream"),
        ],
    }

    async def mock_receive():
        # An empty body is sent to trigger an early return.
        return {"type": "http.request", "body": b"", "more_body": False}

    with patch.object(streamable_http_manager, "StreamableHTTPServerTransport", side_effect=record_transport):
        async with manager.run():
            # Make the server loop complete immediately.
            app.run = AsyncMock(return_value=None)

            await manager.handle_request(scope, mock_receive, mock_send)

            assert len(created_transports) == 1, "Should have created one transport"
            (transport,) = created_transports

            # Key check: the per-request transport has been terminated...
            assert transport._terminated, "Transport should be terminated after stateless request"
            # ...and holds no leftover request streams.
            assert len(transport._request_streams) == 0, "Transport should have no active request streams"