-
Notifications
You must be signed in to change notification settings - Fork 610
test(fastmcp): Use AsyncClient for SSE
#5400
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 54 commits
a8372b0
87a2e26
70edb6d
c238f84
f1362bd
34dba91
46659c5
ac8e6e4
2766b40
d3afb38
270636d
960d76c
2aecd41
8c9aa86
6579935
8aa6363
30e9ec3
aaf69c3
0ab862c
c675fb9
8def7e3
1380f56
0e37050
8749ad3
45c9554
5976135
317c002
9097b69
7703379
1133d9e
0bd14e3
ef33e1f
4715afa
2df9629
7db7680
d6c2fa5
eaf4230
7484fd8
7c9d602
0f47f06
6bff4e1
1bf6876
fd2cc42
771f60e
8f57fcd
cc4a19d
ef06a2d
955d525
6b81f86
a16e638
0103986
8d6744c
e731e99
e554441
99571ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,7 @@ | ||
| import json | ||
| import os | ||
| import asyncio | ||
| from urllib.parse import urlparse, parse_qs | ||
| import socket | ||
| import warnings | ||
| import brotli | ||
|
|
@@ -51,25 +53,40 @@ | |
| from typing import TYPE_CHECKING | ||
|
|
||
| if TYPE_CHECKING: | ||
| from typing import Optional | ||
| from typing import Any, Callable, MutableMapping, Optional | ||
| from collections.abc import Iterator | ||
|
|
||
| try: | ||
| from anyio import create_memory_object_stream, create_task_group | ||
| from anyio import create_memory_object_stream, create_task_group, EndOfStream | ||
| from mcp.types import ( | ||
| JSONRPCMessage, | ||
| JSONRPCNotification, | ||
| JSONRPCRequest, | ||
| ) | ||
| from mcp.shared.message import SessionMessage | ||
| from httpx import ( | ||
| ASGITransport, | ||
| Request as HttpxRequest, | ||
| Response as HttpxResponse, | ||
| AsyncByteStream, | ||
| AsyncClient, | ||
| ) | ||
| except ImportError: | ||
| create_memory_object_stream = None | ||
| create_task_group = None | ||
| EndOfStream = None | ||
|
|
||
| JSONRPCMessage = None | ||
| JSONRPCNotification = None | ||
| JSONRPCRequest = None | ||
| SessionMessage = None | ||
|
|
||
| ASGITransport = None | ||
| HttpxRequest = None | ||
| HttpxResponse = None | ||
| AsyncByteStream = None | ||
| AsyncClient = None | ||
|
|
||
|
|
||
| SENTRY_EVENT_SCHEMA = "./checkouts/data-schemas/relay/event.schema.json" | ||
|
|
||
|
|
@@ -787,6 +804,194 @@ def inner(events): | |
| return inner | ||
|
|
||
|
|
||
| @pytest.fixture() | ||
| def json_rpc_sse(is_structured_content: bool = True): | ||
| class StreamingASGITransport(ASGITransport): | ||
|
sentry[bot] marked this conversation as resolved.
Outdated
|
||
| """ | ||
| Simple transport whose only purpose is to keep GET request alive in SSE connections, allowing | ||
| tests involving SSE interactions to run in-process. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| app: "Callable", | ||
| keep_sse_alive: "asyncio.Event", | ||
| ) -> None: | ||
| self.keep_sse_alive = keep_sse_alive | ||
| super().__init__(app) | ||
|
|
||
| async def handle_async_request( | ||
| self, request: "HttpxRequest" | ||
| ) -> "HttpxResponse": | ||
| scope = { | ||
| "type": "http", | ||
| "method": request.method, | ||
| "headers": [(k.lower(), v) for (k, v) in request.headers.raw], | ||
| "path": request.url.path, | ||
| "query_string": request.url.query, | ||
| } | ||
|
|
||
| is_streaming_sse = scope["method"] == "GET" and scope["path"] == "/sse" | ||
| if not is_streaming_sse: | ||
| return await super().handle_async_request(request) | ||
|
|
||
| request_body = b"" | ||
| if request.content: | ||
| request_body = await request.aread() | ||
|
|
||
| body_sender, body_receiver = create_memory_object_stream[bytes](0) # type: ignore | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: The test fixture Suggested FixRevert the call at Prompt for AI Agent |
||
|
|
||
| async def receive() -> "dict[str, Any]": | ||
| if self.keep_sse_alive.is_set(): | ||
| return {"type": "http.disconnect"} | ||
|
|
||
| await self.keep_sse_alive.wait() # Keep alive :) | ||
| return { | ||
| "type": "http.request", | ||
| "body": request_body, | ||
| "more_body": False, | ||
| } | ||
|
|
||
| async def send(message: "MutableMapping[str, Any]") -> None: | ||
| if message["type"] == "http.response.body": | ||
| body = message.get("body", b"") | ||
| more_body = message.get("more_body", False) | ||
|
|
||
| if body == b"" and not more_body: | ||
| return | ||
|
|
||
| if body: | ||
| await body_sender.send(body) | ||
|
|
||
| if not more_body: | ||
| await body_sender.aclose() | ||
|
|
||
| async def run_app(): | ||
| await self.app(scope, receive, send) | ||
|
|
||
| class StreamingBodyStream(AsyncByteStream): # type: ignore | ||
| def __init__(self, receiver): | ||
| self.receiver = receiver | ||
|
|
||
| async def __aiter__(self): | ||
| try: | ||
| async for chunk in self.receiver: | ||
| yield chunk | ||
| except EndOfStream: # type: ignore | ||
| pass | ||
|
|
||
| stream = StreamingBodyStream(body_receiver) | ||
| response = HttpxResponse(status_code=200, headers=[], stream=stream) # type: ignore | ||
|
|
||
| asyncio.create_task(run_app()) | ||
| return response | ||
|
|
||
| def parse_sse_data_package(sse_chunk): | ||
| sse_text = sse_chunk.decode("utf-8") | ||
| json_str = sse_text.split("data: ")[1] | ||
| return json.loads(json_str) | ||
|
|
||
| async def inner( | ||
| app, method: str, params, request_id: str, keep_sse_alive: "asyncio.Event" | ||
| ): | ||
| context = {} | ||
|
|
||
| stream_complete = asyncio.Event() | ||
| endpoint_parsed = asyncio.Event() | ||
|
|
||
| # https://github.com/Kludex/starlette/issues/104#issuecomment-729087925 | ||
| async with AsyncClient( # type: ignore | ||
| transport=StreamingASGITransport(app=app, keep_sse_alive=keep_sse_alive), | ||
| base_url="http://test", | ||
| ) as client: | ||
|
|
||
| async def parse_stream(): | ||
| async with client.stream("GET", "/sse") as stream: | ||
| # Read directly from stream.stream instead of aiter_bytes() | ||
| async for chunk in stream.stream: | ||
| if b"event: endpoint" in chunk: | ||
| sse_text = chunk.decode("utf-8") | ||
| url = sse_text.split("data: ")[1] | ||
|
|
||
| parsed = urlparse(url) | ||
| query_params = parse_qs(parsed.query) | ||
| context["session_id"] = query_params["session_id"][0] | ||
| endpoint_parsed.set() | ||
| continue | ||
|
|
||
| if ( | ||
| is_structured_content | ||
| and b"event: message" in chunk | ||
| and b"structuredContent" in chunk | ||
| ): | ||
| context["response"] = parse_sse_data_package(chunk) | ||
| break | ||
| elif ( | ||
| "result" in parse_sse_data_package(chunk) | ||
| and "content" in parse_sse_data_package(chunk)["result"] | ||
| ): | ||
| context["response"] = parse_sse_data_package(chunk) | ||
| break | ||
|
alexander-alderman-webb marked this conversation as resolved.
alexander-alderman-webb marked this conversation as resolved.
alexander-alderman-webb marked this conversation as resolved.
|
||
|
|
||
| stream_complete.set() | ||
|
|
||
| task = asyncio.create_task(parse_stream()) | ||
| await endpoint_parsed.wait() | ||
|
|
||
| await client.post( | ||
| f"/messages/?session_id={context['session_id']}", | ||
| headers={ | ||
| "Content-Type": "application/json", | ||
| }, | ||
| json={ | ||
| "jsonrpc": "2.0", | ||
| "method": "initialize", | ||
| "params": { | ||
| "clientInfo": {"name": "test-client", "version": "1.0"}, | ||
| "protocolVersion": "2025-11-25", | ||
| "capabilities": {}, | ||
| }, | ||
| "id": request_id, | ||
| }, | ||
| ) | ||
|
|
||
| # Notification response is mandatory. | ||
| # https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle | ||
| await client.post( | ||
| f"/messages/?session_id={context['session_id']}", | ||
| headers={ | ||
| "Content-Type": "application/json", | ||
| "mcp-session-id": context["session_id"], | ||
| }, | ||
| json={ | ||
| "jsonrpc": "2.0", | ||
| "method": "notifications/initialized", | ||
| "params": {}, | ||
| }, | ||
| ) | ||
|
|
||
| await client.post( | ||
| f"/messages/?session_id={context['session_id']}", | ||
| headers={ | ||
| "Content-Type": "application/json", | ||
| "mcp-session-id": context["session_id"], | ||
| }, | ||
| json={ | ||
| "jsonrpc": "2.0", | ||
| "method": method, | ||
| "params": params, | ||
| "id": request_id, | ||
| }, | ||
| ) | ||
|
|
||
| await stream_complete.wait() | ||
| keep_sse_alive.set() | ||
|
|
||
| return task, context["session_id"], context["response"] | ||
|
alexander-alderman-webb marked this conversation as resolved.
|
||
|
|
||
| return inner | ||
|
|
||
|
|
||
| class MockServerRequestHandler(BaseHTTPRequestHandler): | ||
| def do_GET(self): # noqa: N802 | ||
| # Process an HTTP GET request and return a response. | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.