Skip to content

Commit 694cabc

Browse files
committed
Fix: Prevent session manager shutdown on individual session crash
1 parent 05b7156 commit 694cabc

2 files changed

Lines changed: 164 additions & 13 deletions

File tree

src/mcp/server/streamable_http_manager.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ class StreamableHTTPSessionManager:
5151
json_response: Whether to use JSON responses instead of SSE streams
5252
stateless: If True, creates a completely fresh transport for each request
5353
with no session tracking or state persistence between requests.
54-
5554
"""
5655

5756
def __init__(
@@ -171,12 +170,15 @@ async def run_stateless_server(
171170
async with http_transport.connect() as streams:
172171
read_stream, write_stream = streams
173172
task_status.started()
174-
await self.app.run(
175-
read_stream,
176-
write_stream,
177-
self.app.create_initialization_options(),
178-
stateless=True,
179-
)
173+
try:
174+
await self.app.run(
175+
read_stream,
176+
write_stream,
177+
self.app.create_initialization_options(),
178+
stateless=True,
179+
)
180+
except Exception as e:
181+
logger.error(f"Stateless session crashed: {e}", exc_info=True)
180182

181183
# Assert task group is not None for type checking
182184
assert self._task_group is not None
@@ -235,12 +237,33 @@ async def run_server(
235237
async with http_transport.connect() as streams:
236238
read_stream, write_stream = streams
237239
task_status.started()
238-
await self.app.run(
239-
read_stream,
240-
write_stream,
241-
self.app.create_initialization_options(),
242-
stateless=False, # Stateful mode
243-
)
240+
try:
241+
await self.app.run(
242+
read_stream,
243+
write_stream,
244+
self.app.create_initialization_options(),
245+
stateless=False, # Stateful mode
246+
)
247+
except Exception as e:
248+
logger.error(
249+
f"Session {http_transport.mcp_session_id} crashed: {e}",
250+
exc_info=True,
251+
)
252+
finally:
253+
# Cleanup logic
254+
if (
255+
http_transport.mcp_session_id
256+
and http_transport.mcp_session_id
257+
in self._server_instances
258+
):
259+
logger.info(
260+
"Cleaning up crashed/terminated session "
261+
f"{http_transport.mcp_session_id} from "
262+
"active instances."
263+
)
264+
del self._server_instances[
265+
http_transport.mcp_session_id
266+
]
244267

245268
# Assert task group is not None for type checking
246269
assert self._task_group is not None

tests/server/test_streamable_http_manager.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
"""Tests for StreamableHTTPSessionManager."""
22

3+
from unittest.mock import AsyncMock
4+
35
import anyio
46
import pytest
57

68
from mcp.server.lowlevel import Server
9+
from mcp.server.streamable_http import MCP_SESSION_ID_HEADER
710
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
811

912

@@ -79,3 +82,128 @@ async def send(message):
7982
assert "Task group is not initialized. Make sure to use run()." in str(
8083
excinfo.value
8184
)
85+
86+
87+
class TestException(Exception):
88+
__test__ = False # Prevent pytest from collecting this as a test class
89+
pass
90+
91+
92+
@pytest.fixture
93+
async def running_manager():
94+
app = Server("test-cleanup-server")
95+
# It's important that the app instance used by the manager is the one we can patch
96+
manager = StreamableHTTPSessionManager(app=app)
97+
async with manager.run():
98+
# Patch app.run here if it's simpler, or patch it within the test
99+
yield manager, app
100+
101+
102+
@pytest.mark.anyio
103+
async def test_stateful_session_cleanup_on_graceful_exit(running_manager):
104+
manager, app = running_manager
105+
106+
mock_mcp_run = AsyncMock(return_value=None)
107+
# This will be called by StreamableHTTPSessionManager's run_server -> self.app.run
108+
app.run = mock_mcp_run
109+
110+
sent_messages = []
111+
112+
async def mock_send(message):
113+
sent_messages.append(message)
114+
115+
scope = {"type": "http", "method": "POST", "path": "/mcp", "headers": []}
116+
117+
async def mock_receive():
118+
return {"type": "http.request", "body": b"", "more_body": False}
119+
120+
# Trigger session creation
121+
await manager.handle_request(scope, mock_receive, mock_send)
122+
123+
# Extract session ID from response headers
124+
session_id = None
125+
for msg in sent_messages:
126+
if msg["type"] == "http.response.start":
127+
for header_name, header_value in msg.get("headers", []):
128+
if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower():
129+
session_id = header_value.decode()
130+
break
131+
if session_id: # Break outer loop if session_id is found
132+
break
133+
134+
assert session_id is not None, "Session ID not found in response headers"
135+
136+
# Ensure MCPServer.run was called
137+
mock_mcp_run.assert_called_once()
138+
139+
# At this point, mock_mcp_run has completed, and the finally block in
140+
# StreamableHTTPSessionManager's run_server should have executed.
141+
142+
# To ensure the task spawned by handle_request finishes and cleanup occurs:
143+
# Give other tasks a chance to run. This is important for the finally block.
144+
await anyio.sleep(0.01)
145+
146+
assert (
147+
session_id not in manager._server_instances
148+
), "Session ID should be removed from _server_instances after graceful exit"
149+
assert (
150+
not manager._server_instances
151+
), "No sessions should be tracked after the only session exits gracefully"
152+
153+
154+
@pytest.mark.anyio
155+
async def test_stateful_session_cleanup_on_exception(running_manager):
156+
manager, app = running_manager
157+
158+
mock_mcp_run = AsyncMock(side_effect=TestException("Simulated crash"))
159+
app.run = mock_mcp_run
160+
161+
sent_messages = []
162+
163+
async def mock_send(message):
164+
sent_messages.append(message)
165+
# If an exception occurs, the transport might try to send an error response
166+
# For this test, we mostly care that the session is established enough
167+
# to get an ID
168+
if message["type"] == "http.response.start" and message["status"] >= 500:
169+
pass # Expected if TestException propagates that far up the transport
170+
171+
scope = {"type": "http", "method": "POST", "path": "/mcp", "headers": []}
172+
173+
async def mock_receive():
174+
return {"type": "http.request", "body": b"", "more_body": False}
175+
176+
# It's possible handle_request itself might raise an error if the TestException
177+
# isn't caught by the transport layer before propagating.
178+
# The key is that the session manager's internal task for MCPServer.run
179+
# encounters the exception.
180+
try:
181+
await manager.handle_request(scope, mock_receive, mock_send)
182+
except TestException:
183+
# This might be caught here if not handled by StreamableHTTPServerTransport's
184+
# error handling
185+
pass
186+
187+
session_id = None
188+
for msg in sent_messages:
189+
if msg["type"] == "http.response.start":
190+
for header_name, header_value in msg.get("headers", []):
191+
if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower():
192+
session_id = header_value.decode()
193+
break
194+
if session_id: # Break outer loop if session_id is found
195+
break
196+
197+
assert session_id is not None, "Session ID not found in response headers"
198+
199+
mock_mcp_run.assert_called_once()
200+
201+
# Give other tasks a chance to run to ensure the finally block executes
202+
await anyio.sleep(0.01)
203+
204+
assert (
205+
session_id not in manager._server_instances
206+
), "Session ID should be removed from _server_instances after an exception"
207+
assert (
208+
not manager._server_instances
209+
), "No sessions should be tracked after the only session crashes"

0 commit comments

Comments
 (0)