-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Fix leaked anyio streams in streamable_http #1991
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
1606c96
bbcfbf6
de5d624
b11e04f
1f6e8ec
cf9e43e
a0afd1d
508d8e7
095b802
824a5f4
9dc0742
774380d
ef275fd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -328,19 +328,19 @@ def _create_json_response( | |
| headers=response_headers, | ||
| ) | ||
|
|
||
| def _get_session_id(self, request: Request) -> str | None: # pragma: no cover | ||
| def _get_session_id(self, request: Request) -> str | None: | ||
| """Extract the session ID from request headers.""" | ||
| return request.headers.get(MCP_SESSION_ID_HEADER) | ||
|
|
||
| def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: # pragma: no cover | ||
| def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: | ||
| """Create event data dictionary from an EventMessage.""" | ||
| event_data = { | ||
| "event": "message", | ||
| "data": event_message.message.model_dump_json(by_alias=True, exclude_none=True), | ||
| } | ||
|
|
||
| # If an event ID was provided, include it | ||
| if event_message.event_id: | ||
| if event_message.event_id: # pragma: no cover | ||
| event_data["id"] = event_message.event_id | ||
|
|
||
| return event_data | ||
|
|
@@ -381,9 +381,9 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No | |
|
|
||
| if request.method == "POST": | ||
| await self._handle_post_request(scope, request, receive, send) | ||
| elif request.method == "GET": # pragma: no cover | ||
| elif request.method == "GET": | ||
| await self._handle_get_request(request, send) | ||
| elif request.method == "DELETE": # pragma: no cover | ||
| elif request.method == "DELETE": | ||
| await self._handle_delete_request(request, send) | ||
| else: # pragma: no cover | ||
| await self._handle_unsupported_request(request, send) | ||
|
|
@@ -470,14 +470,14 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re | |
| # Check if this is an initialization request | ||
| is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize" | ||
|
|
||
| if is_initialization_request: # pragma: no cover | ||
| if is_initialization_request: | ||
| # Check if the server already has an established session | ||
| if self.mcp_session_id: | ||
| # Check if request has a session ID | ||
| request_session_id = self._get_session_id(request) | ||
|
|
||
| # If request has a session ID but doesn't match, return 404 | ||
| if request_session_id and request_session_id != self.mcp_session_id: | ||
| if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover | ||
| response = self._create_error_response( | ||
| "Not Found: Invalid or expired session ID", | ||
| HTTPStatus.NOT_FOUND, | ||
|
|
@@ -488,7 +488,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re | |
| return | ||
|
|
||
| # For notifications and responses only, return 202 Accepted | ||
| if not isinstance(message, JSONRPCRequest): # pragma: no cover | ||
| if not isinstance(message, JSONRPCRequest): | ||
| # Create response object and send it | ||
| response = self._create_json_response( | ||
| None, | ||
|
|
@@ -561,14 +561,14 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re | |
| await response(scope, receive, send) | ||
| finally: | ||
| await self._clean_up_memory_streams(request_id) | ||
| else: # pragma: no cover | ||
| else: | ||
| # Create SSE stream | ||
| sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) | ||
|
|
||
| # Store writer reference so close_sse_stream() can close it | ||
| self._sse_stream_writers[request_id] = sse_stream_writer | ||
|
|
||
| async def sse_writer(): | ||
| async def sse_writer(): # pragma: lax no cover | ||
| # Get the request ID from the incoming request message | ||
| try: | ||
| async with sse_stream_writer, request_stream_reader: | ||
|
|
@@ -617,11 +617,12 @@ async def sse_writer(): | |
| # Then send the message to be processed by the server | ||
| session_message = self._create_session_message(message, request, request_id, protocol_version) | ||
| await writer.send(session_message) | ||
| except Exception: | ||
| except Exception: # pragma: no cover | ||
| logger.exception("SSE response error") | ||
| await sse_stream_writer.aclose() | ||
| await sse_stream_reader.aclose() | ||
| await self._clean_up_memory_streams(request_id) | ||
| finally: | ||
| await sse_stream_reader.aclose() | ||
|
|
||
| except Exception as err: # pragma: no cover | ||
| logger.exception("Error handling POST request") | ||
|
|
@@ -635,33 +636,33 @@ async def sse_writer(): | |
| await writer.send(Exception(err)) | ||
| return | ||
|
|
||
| async def _handle_get_request(self, request: Request, send: Send) -> None: # pragma: no cover | ||
| async def _handle_get_request(self, request: Request, send: Send) -> None: | ||
| """Handle GET request to establish SSE. | ||
|
|
||
| This allows the server to communicate to the client without the client | ||
| first sending data via HTTP POST. The server can send JSON-RPC requests | ||
| and notifications on this stream. | ||
| """ | ||
| writer = self._read_stream_writer | ||
| if writer is None: | ||
| if writer is None: # pragma: no cover | ||
| raise ValueError("No read stream writer available. Ensure connect() is called first.") | ||
|
|
||
| # Validate Accept header - must include text/event-stream | ||
| _, has_sse = self._check_accept_headers(request) | ||
|
|
||
| if not has_sse: | ||
| if not has_sse: # pragma: no cover | ||
| response = self._create_error_response( | ||
| "Not Acceptable: Client must accept text/event-stream", | ||
| HTTPStatus.NOT_ACCEPTABLE, | ||
| ) | ||
| await response(request.scope, request.receive, send) | ||
| return | ||
|
|
||
| if not await self._validate_request_headers(request, send): | ||
| if not await self._validate_request_headers(request, send): # pragma: no cover | ||
| return | ||
|
|
||
| # Handle resumability: check for Last-Event-ID header | ||
| if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): | ||
| if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover | ||
| await self._replay_events(last_event_id, request, send) | ||
| return | ||
|
|
||
|
|
@@ -675,7 +676,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # pr | |
| headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id | ||
|
|
||
| # Check if we already have an active GET stream | ||
| if GET_STREAM_KEY in self._request_streams: | ||
| if GET_STREAM_KEY in self._request_streams: # pragma: no cover | ||
| response = self._create_error_response( | ||
| "Conflict: Only one SSE stream is allowed per session", | ||
| HTTPStatus.CONFLICT, | ||
|
|
@@ -695,7 +696,7 @@ async def standalone_sse_writer(): | |
|
|
||
| async with sse_stream_writer, standalone_stream_reader: | ||
| # Process messages from the standalone stream | ||
| async for event_message in standalone_stream_reader: | ||
| async for event_message in standalone_stream_reader: # pragma: lax no cover | ||
| # For the standalone stream, we handle: | ||
| # - JSONRPCNotification (server sends notifications to client) | ||
| # - JSONRPCRequest (server sends requests to client) | ||
|
|
@@ -704,7 +705,7 @@ async def standalone_sse_writer(): | |
| # Send the message via SSE | ||
| event_data = self._create_event_data(event_message) | ||
| await sse_stream_writer.send(event_data) | ||
| except Exception: | ||
| except Exception: # pragma: no cover | ||
| logger.exception("Error in standalone SSE writer") | ||
| finally: | ||
| logger.debug("Closing standalone SSE writer") | ||
|
|
@@ -720,16 +721,17 @@ async def standalone_sse_writer(): | |
| try: | ||
| # This will send headers immediately and establish the SSE connection | ||
| await response(request.scope, request.receive, send) | ||
| except Exception: | ||
| except Exception: # pragma: lax no cover | ||
| logger.exception("Error in standalone SSE response") | ||
| await self._clean_up_memory_streams(GET_STREAM_KEY) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this moves up?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This stays in the |
||
| finally: | ||
| await sse_stream_writer.aclose() | ||
| await sse_stream_reader.aclose() | ||
| await self._clean_up_memory_streams(GET_STREAM_KEY) | ||
|
|
||
| async def _handle_delete_request(self, request: Request, send: Send) -> None: # pragma: no cover | ||
| async def _handle_delete_request(self, request: Request, send: Send) -> None: | ||
| """Handle DELETE requests for explicit session termination.""" | ||
| # Validate session ID | ||
| if not self.mcp_session_id: | ||
| if not self.mcp_session_id: # pragma: no cover | ||
| # If no session ID set, return Method Not Allowed | ||
| response = self._create_error_response( | ||
| "Method Not Allowed: Session termination not supported", | ||
|
|
@@ -738,7 +740,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: # | |
| await response(request.scope, request.receive, send) | ||
| return | ||
|
|
||
| if not await self._validate_request_headers(request, send): | ||
| if not await self._validate_request_headers(request, send): # pragma: no cover | ||
| return | ||
|
|
||
| await self.terminate() | ||
|
|
@@ -796,24 +798,24 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non | |
| ) | ||
| await response(request.scope, request.receive, send) | ||
|
|
||
| async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: no cover | ||
| async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: lax no cover | ||
| if not await self._validate_session(request, send): | ||
| return False | ||
| if not await self._validate_protocol_version(request, send): | ||
| return False | ||
| return True | ||
|
|
||
| async def _validate_session(self, request: Request, send: Send) -> bool: # pragma: no cover | ||
| async def _validate_session(self, request: Request, send: Send) -> bool: | ||
| """Validate the session ID in the request.""" | ||
| if not self.mcp_session_id: | ||
| if not self.mcp_session_id: # pragma: no cover | ||
| # If we're not using session IDs, return True | ||
| return True | ||
|
|
||
| # Get the session ID from the request headers | ||
| request_session_id = self._get_session_id(request) | ||
|
|
||
| # If no session ID provided but required, return error | ||
| if not request_session_id: | ||
| if not request_session_id: # pragma: no cover | ||
| response = self._create_error_response( | ||
| "Bad Request: Missing session ID", | ||
| HTTPStatus.BAD_REQUEST, | ||
|
|
@@ -822,7 +824,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # prag | |
| return False | ||
|
|
||
| # If session ID doesn't match, return error | ||
| if request_session_id != self.mcp_session_id: | ||
| if request_session_id != self.mcp_session_id: # pragma: no cover | ||
| response = self._create_error_response( | ||
| "Not Found: Invalid or expired session ID", | ||
| HTTPStatus.NOT_FOUND, | ||
|
|
@@ -832,17 +834,17 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # prag | |
|
|
||
| return True | ||
|
|
||
| async def _validate_protocol_version(self, request: Request, send: Send) -> bool: # pragma: no cover | ||
| async def _validate_protocol_version(self, request: Request, send: Send) -> bool: | ||
| """Validate the protocol version header in the request.""" | ||
| # Get the protocol version from the request headers | ||
| protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) | ||
|
|
||
| # If no protocol version provided, assume default version | ||
| if protocol_version is None: | ||
| if protocol_version is None: # pragma: no cover | ||
| protocol_version = DEFAULT_NEGOTIATED_VERSION | ||
|
|
||
| # Check if the protocol version is supported | ||
| if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: | ||
| if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover | ||
| supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) | ||
| response = self._create_error_response( | ||
| f"Bad Request: Unsupported protocol version: {protocol_version}. " | ||
|
|
@@ -1004,10 +1006,7 @@ async def message_router(): | |
| try: | ||
| # Send both the message and the event ID | ||
| await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id)) | ||
| except ( # pragma: no cover | ||
| anyio.BrokenResourceError, | ||
| anyio.ClosedResourceError, | ||
| ): | ||
| except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover | ||
| # Stream might be closed, remove from registry | ||
| self._request_streams.pop(request_stream_id, None) | ||
| else: # pragma: no cover | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,9 +5,12 @@ | |
| from unittest.mock import AsyncMock, patch | ||
|
|
||
| import anyio | ||
| import httpx | ||
| import pytest | ||
| from starlette.types import Message | ||
|
|
||
| from mcp import Client, types | ||
| from mcp.client.streamable_http import streamable_http_client | ||
| from mcp.server import streamable_http_manager | ||
| from mcp.server.lowlevel import Server | ||
| from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport | ||
|
|
@@ -313,3 +316,35 @@ async def mock_receive(): | |
| assert error_data["id"] == "server-error" | ||
| assert error_data["error"]["code"] == INVALID_REQUEST | ||
| assert error_data["error"]["message"] == "Session not found" | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def reset_sse_app_status(): | ||
| # Needed for tests with sse-starlette < 3 | ||
| # https://github.com/sysid/sse-starlette/issues/59 | ||
| # https://github.com/sysid/sse-starlette/blob/v3.2.0/README.md#testing | ||
|
|
||
| from sse_starlette.sse import AppStatus | ||
|
|
||
| AppStatus.should_exit_event = None # pyright: ignore[reportAttributeAccessIssue] | ||
| yield | ||
| AppStatus.should_exit_event = None # pyright: ignore[reportAttributeAccessIssue] | ||
|
||
|
|
||
|
|
||
| @pytest.mark.anyio | ||
| async def test_e2e_streamable_http_server_cleanup(reset_sse_app_status: None): | ||
| host = "testserver" | ||
| app = Server("test-server") | ||
|
|
||
| @app.list_tools() | ||
| async def list_tools(req: types.ListToolsRequest) -> types.ListToolsResult: | ||
| return types.ListToolsResult(tools=[]) | ||
|
|
||
| mcp_app = app.streamable_http_app(host=host) | ||
| async with ( | ||
| mcp_app.router.lifespan_context(mcp_app), | ||
| httpx.ASGITransport(mcp_app) as transport, | ||
| httpx.AsyncClient(transport=transport) as http_client, | ||
| Client(streamable_http_client(f"http://{host}/mcp", http_client=http_client)) as client, | ||
| ): | ||
| await client.list_tools() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does the reader moves to the finally but not the writer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On L569 it gets stored
Should I be closing it here too, or does it need to outlive this function?