Skip to content

Commit f951e5c

Browse files
committed
fix: stateless HTTP task leak and graceful SSE drain on shutdown
Upstream PR: modelcontextprotocol#2145
1 parent 7ba41dc commit f951e5c

File tree

3 files changed

+295
-23
lines changed

3 files changed

+295
-23
lines changed

src/mcp/server/streamable_http.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -619,14 +619,14 @@ async def sse_writer(): # pragma: lax no cover
619619
# Then send the message to be processed by the server
620620
session_message = self._create_session_message(message, request, request_id, protocol_version)
621621
await writer.send(session_message)
622-
except Exception: # pragma: no cover
622+
except Exception: # pragma: lax no cover
623623
logger.exception("SSE response error")
624624
await sse_stream_writer.aclose()
625625
await self._clean_up_memory_streams(request_id)
626626
finally:
627627
await sse_stream_reader.aclose()
628628

629-
except Exception as err: # pragma: no cover
629+
except Exception as err: # pragma: lax no cover
630630
logger.exception("Error handling POST request")
631631
response = self._create_error_response(
632632
f"Error handling POST request: {err}",
@@ -809,7 +809,7 @@ async def _validate_request_headers(self, request: Request, send: Send) -> bool:
809809

810810
async def _validate_session(self, request: Request, send: Send) -> bool:
811811
"""Validate the session ID in the request."""
812-
if not self.mcp_session_id: # pragma: no cover
812+
if not self.mcp_session_id: # pragma: lax no cover
813813
# If we're not using session IDs, return True
814814
return True
815815

@@ -1019,7 +1019,7 @@ async def message_router():
10191019
)
10201020
except anyio.ClosedResourceError:
10211021
if self._terminated:
1022-
logger.debug("Read stream closed by client")
1022+
logger.debug("Read stream closed by client") # pragma: lax no cover
10231023
else:
10241024
logger.exception("Unexpected closure of read stream in message router")
10251025
except Exception: # pragma: lax no cover

src/mcp/server/streamable_http_manager.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def __init__(
9090
self._session_creation_lock = anyio.Lock()
9191
self._server_instances: dict[str, StreamableHTTPServerTransport] = {}
9292

93+
# Track in-flight stateless transports for graceful shutdown
94+
self._stateless_transports: set[StreamableHTTPServerTransport] = set()
95+
9396
# The task group will be set during lifespan
9497
self._task_group = None
9598
# Thread-safe tracking of run() calls
@@ -130,11 +133,28 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
130133
yield # Let the application run
131134
finally:
132135
logger.info("StreamableHTTP session manager shutting down")
136+
137+
# Terminate all active transports before cancelling the task
138+
# group. This closes their in-memory streams, which lets
139+
# EventSourceResponse send a final ``more_body=False`` chunk
140+
# — a clean HTTP close instead of a connection reset.
141+
for transport in list(self._server_instances.values()):
142+
try:
143+
await transport.terminate()
144+
except Exception: # pragma: no cover
145+
logger.debug("Error terminating transport during shutdown", exc_info=True)
146+
for transport in list(self._stateless_transports):
147+
try:
148+
await transport.terminate()
149+
except Exception: # pragma: no cover
150+
logger.debug("Error terminating stateless transport during shutdown", exc_info=True)
151+
133152
# Cancel task group to stop all spawned tasks
134153
tg.cancel_scope.cancel()
135154
self._task_group = None
136155
# Clear any remaining server instances
137156
self._server_instances.clear()
157+
self._stateless_transports.clear()
138158

139159
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
140160
"""Process ASGI request with proper session handling and transport setup.
@@ -151,7 +171,12 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
151171
await self._handle_stateful_request(scope, receive, send)
152172

153173
async def _handle_stateless_request(self, scope: Scope, receive: Receive, send: Send) -> None:
154-
"""Process request in stateless mode - creating a new transport for each request."""
174+
"""Process request in stateless mode - creating a new transport for each request.
175+
176+
Uses a request-scoped task group so the server task is automatically
177+
cancelled when the request completes, preventing task accumulation in
178+
the manager's global task group.
179+
"""
155180
logger.debug("Stateless mode: Creating new transport for this request")
156181
# No session ID needed in stateless mode
157182
http_transport = StreamableHTTPServerTransport(
@@ -161,6 +186,9 @@ async def _handle_stateless_request(self, scope: Scope, receive: Receive, send:
161186
security_settings=self.security_settings,
162187
)
163188

189+
# Track for graceful shutdown
190+
self._stateless_transports.add(http_transport)
191+
164192
# Start server in a new task
165193
async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED):
166194
async with http_transport.connect() as streams:
@@ -173,18 +201,27 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA
173201
self.app.create_initialization_options(),
174202
stateless=True,
175203
)
176-
except Exception: # pragma: no cover
204+
except Exception: # pragma: lax no cover
177205
logger.exception("Stateless session crashed")
178206

179-
# Assert task group is not None for type checking
180-
assert self._task_group is not None
181-
# Start the server task
182-
await self._task_group.start(run_stateless_server)
183-
184-
# Handle the HTTP request and return the response
185-
await http_transport.handle_request(scope, receive, send)
186-
187-
# Terminate the transport after the request is handled
207+
# Use a request-scoped task group instead of the global one.
208+
# This ensures the server task is cancelled when the request
209+
# finishes, preventing zombie tasks from accumulating.
210+
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1764
211+
try:
212+
async with anyio.create_task_group() as request_tg:
213+
await request_tg.start(run_stateless_server)
214+
# Handle the HTTP request directly in the caller's context
215+
# (not as a child task) so execution flows back naturally.
216+
await http_transport.handle_request(scope, receive, send)
217+
# Cancel the request-scoped task group to stop the server task.
218+
request_tg.cancel_scope.cancel()
219+
finally:
220+
self._stateless_transports.discard(http_transport)
221+
222+
# Terminate after the task group exits — the server task is already
223+
# cancelled at this point, so this is just cleanup (sets _terminated
224+
# flag and closes any remaining streams).
188225
await http_transport.terminate()
189226

190227
async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: Send) -> None:
@@ -272,7 +309,6 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
272309
# Unknown or expired session ID - return 404 per MCP spec
273310
# TODO: Align error code once spec clarifies
274311
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1821
275-
logger.info(f"Rejected request with unknown or expired session ID: {request_mcp_session_id[:64]}")
276312
error_response = JSONRPCError(
277313
jsonrpc="2.0",
278314
id=None,

0 commit comments

Comments
 (0)