Skip to content

Commit 1566e62

Browse files
test(fastmcp): Use AsyncClient for SSE (#5400)
1 parent 6a24fc9 commit 1566e62

File tree

4 files changed

+263
-221
lines changed

4 files changed

+263
-221
lines changed

tests/conftest.py

Lines changed: 203 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import json
22
import os
3+
import asyncio
4+
from urllib.parse import urlparse, parse_qs
35
import socket
46
import warnings
57
import brotli
@@ -51,25 +53,40 @@
5153
from typing import TYPE_CHECKING
5254

5355
if TYPE_CHECKING:
54-
from typing import Optional
56+
from typing import Any, Callable, MutableMapping, Optional
5557
from collections.abc import Iterator
5658

5759
try:
58-
from anyio import create_memory_object_stream, create_task_group
60+
from anyio import create_memory_object_stream, create_task_group, EndOfStream
5961
from mcp.types import (
6062
JSONRPCMessage,
6163
JSONRPCNotification,
6264
JSONRPCRequest,
6365
)
6466
from mcp.shared.message import SessionMessage
67+
from httpx import (
68+
ASGITransport,
69+
Request as HttpxRequest,
70+
Response as HttpxResponse,
71+
AsyncByteStream,
72+
AsyncClient,
73+
)
6574
except ImportError:
6675
create_memory_object_stream = None
6776
create_task_group = None
77+
EndOfStream = None
78+
6879
JSONRPCMessage = None
6980
JSONRPCNotification = None
7081
JSONRPCRequest = None
7182
SessionMessage = None
7283

84+
ASGITransport = None
85+
HttpxRequest = None
86+
HttpxResponse = None
87+
AsyncByteStream = None
88+
AsyncClient = None
89+
7390

7491
SENTRY_EVENT_SCHEMA = "./checkouts/data-schemas/relay/event.schema.json"
7592

@@ -787,6 +804,190 @@ def inner(events):
787804
return inner
788805

789806

807+
@pytest.fixture()
808+
def json_rpc_sse():
809+
class StreamingASGITransport(ASGITransport):
810+
"""
811+
Simple transport whose only purpose is to keep GET request alive in SSE connections, allowing
812+
tests involving SSE interactions to run in-process.
813+
"""
814+
815+
def __init__(
816+
self,
817+
app: "Callable",
818+
keep_sse_alive: "asyncio.Event",
819+
) -> None:
820+
self.keep_sse_alive = keep_sse_alive
821+
super().__init__(app)
822+
823+
async def handle_async_request(
824+
self, request: "HttpxRequest"
825+
) -> "HttpxResponse":
826+
scope = {
827+
"type": "http",
828+
"method": request.method,
829+
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
830+
"path": request.url.path,
831+
"query_string": request.url.query,
832+
}
833+
834+
is_streaming_sse = scope["method"] == "GET" and scope["path"] == "/sse"
835+
if not is_streaming_sse:
836+
return await super().handle_async_request(request)
837+
838+
request_body = b""
839+
if request.content:
840+
request_body = await request.aread()
841+
842+
body_sender, body_receiver = create_memory_object_stream[bytes](0) # type: ignore
843+
844+
async def receive() -> "dict[str, Any]":
845+
if self.keep_sse_alive.is_set():
846+
return {"type": "http.disconnect"}
847+
848+
await self.keep_sse_alive.wait() # Keep alive :)
849+
return {
850+
"type": "http.request",
851+
"body": request_body,
852+
"more_body": False,
853+
}
854+
855+
async def send(message: "MutableMapping[str, Any]") -> None:
856+
if message["type"] == "http.response.body":
857+
body = message.get("body", b"")
858+
more_body = message.get("more_body", False)
859+
860+
if body == b"" and not more_body:
861+
return
862+
863+
if body:
864+
await body_sender.send(body)
865+
866+
if not more_body:
867+
await body_sender.aclose()
868+
869+
async def run_app():
870+
await self.app(scope, receive, send)
871+
872+
class StreamingBodyStream(AsyncByteStream): # type: ignore
873+
def __init__(self, receiver):
874+
self.receiver = receiver
875+
876+
async def __aiter__(self):
877+
try:
878+
async for chunk in self.receiver:
879+
yield chunk
880+
except EndOfStream: # type: ignore
881+
pass
882+
883+
stream = StreamingBodyStream(body_receiver)
884+
response = HttpxResponse(status_code=200, headers=[], stream=stream) # type: ignore
885+
886+
asyncio.create_task(run_app())
887+
return response
888+
889+
def parse_sse_data_package(sse_chunk):
890+
sse_text = sse_chunk.decode("utf-8")
891+
json_str = sse_text.split("data: ")[1]
892+
return json.loads(json_str)
893+
894+
async def inner(
895+
app, method: str, params, request_id: str, keep_sse_alive: "asyncio.Event"
896+
):
897+
context = {}
898+
899+
stream_complete = asyncio.Event()
900+
endpoint_parsed = asyncio.Event()
901+
902+
# https://github.com/Kludex/starlette/issues/104#issuecomment-729087925
903+
async with AsyncClient( # type: ignore
904+
transport=StreamingASGITransport(app=app, keep_sse_alive=keep_sse_alive),
905+
base_url="http://test",
906+
) as client:
907+
908+
async def parse_stream():
909+
async with client.stream("GET", "/sse") as stream:
910+
# Read directly from stream.stream instead of aiter_bytes()
911+
async for chunk in stream.stream:
912+
if b"event: endpoint" in chunk:
913+
sse_text = chunk.decode("utf-8")
914+
url = sse_text.split("data: ")[1]
915+
916+
parsed = urlparse(url)
917+
query_params = parse_qs(parsed.query)
918+
context["session_id"] = query_params["session_id"][0]
919+
endpoint_parsed.set()
920+
continue
921+
922+
if b"event: message" in chunk and b"structuredContent" in chunk:
923+
context["response"] = parse_sse_data_package(chunk)
924+
break
925+
elif (
926+
"result" in parse_sse_data_package(chunk)
927+
and "content" in parse_sse_data_package(chunk)["result"]
928+
):
929+
context["response"] = parse_sse_data_package(chunk)
930+
break
931+
932+
stream_complete.set()
933+
934+
task = asyncio.create_task(parse_stream())
935+
await endpoint_parsed.wait()
936+
937+
await client.post(
938+
f"/messages/?session_id={context['session_id']}",
939+
headers={
940+
"Content-Type": "application/json",
941+
},
942+
json={
943+
"jsonrpc": "2.0",
944+
"method": "initialize",
945+
"params": {
946+
"clientInfo": {"name": "test-client", "version": "1.0"},
947+
"protocolVersion": "2025-11-25",
948+
"capabilities": {},
949+
},
950+
"id": request_id,
951+
},
952+
)
953+
954+
# Notification response is mandatory.
955+
# https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle
956+
await client.post(
957+
f"/messages/?session_id={context['session_id']}",
958+
headers={
959+
"Content-Type": "application/json",
960+
"mcp-session-id": context["session_id"],
961+
},
962+
json={
963+
"jsonrpc": "2.0",
964+
"method": "notifications/initialized",
965+
"params": {},
966+
},
967+
)
968+
969+
await client.post(
970+
f"/messages/?session_id={context['session_id']}",
971+
headers={
972+
"Content-Type": "application/json",
973+
"mcp-session-id": context["session_id"],
974+
},
975+
json={
976+
"jsonrpc": "2.0",
977+
"method": method,
978+
"params": params,
979+
"id": request_id,
980+
},
981+
)
982+
983+
await stream_complete.wait()
984+
keep_sse_alive.set()
985+
986+
return task, context["session_id"], context["response"]
987+
988+
return inner
989+
990+
790991
class MockServerRequestHandler(BaseHTTPRequestHandler):
791992
def do_GET(self): # noqa: N802
792993
# Process an HTTP GET request and return a response.

tests/integrations/fastmcp/test_fastmcp.py

Lines changed: 59 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
accurate testing of the integration's behavior in real MCP Server scenarios.
2222
"""
2323

24+
import anyio
2425
import asyncio
2526
import json
2627
import pytest
@@ -39,9 +40,11 @@ async def __call__(self, *args, **kwargs):
3940
from sentry_sdk.consts import SPANDATA, OP
4041
from sentry_sdk.integrations.mcp import MCPIntegration
4142

43+
from mcp.server.sse import SseServerTransport
4244
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
4345

44-
from starlette.routing import Mount
46+
from starlette.responses import Response
47+
from starlette.routing import Mount, Route
4548
from starlette.applications import Starlette
4649

4750
# Try to import both FastMCP implementations
@@ -254,34 +257,6 @@ def reset_request_ctx():
254257
pass
255258

256259

257-
class MockRequestContext:
258-
"""Mock MCP request context"""
259-
260-
def __init__(self, request_id=None, session_id=None, transport="stdio"):
261-
self.request_id = request_id
262-
if transport in ("http", "sse"):
263-
self.request = MockHTTPRequest(session_id, transport)
264-
else:
265-
self.request = None
266-
267-
268-
class MockHTTPRequest:
269-
"""Mock HTTP request for SSE/StreamableHTTP transport"""
270-
271-
def __init__(self, session_id=None, transport="http"):
272-
self.headers = {}
273-
self.query_params = {}
274-
275-
if transport == "sse":
276-
# SSE transport uses query parameter
277-
if session_id:
278-
self.query_params["session_id"] = session_id
279-
else:
280-
# StreamableHTTP transport uses header
281-
if session_id:
282-
self.headers["mcp-session-id"] = session_id
283-
284-
285260
# =============================================================================
286261
# Tool Handler Tests - Verifying Sentry Integration
287262
# =============================================================================
@@ -956,8 +931,11 @@ def test_tool_no_ctx(x: int) -> dict:
956931
# =============================================================================
957932

958933

934+
@pytest.mark.asyncio
959935
@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
960-
def test_fastmcp_sse_transport(sentry_init, capture_events, FastMCP):
936+
async def test_fastmcp_sse_transport(
937+
sentry_init, capture_events, FastMCP, json_rpc_sse
938+
):
961939
"""Test that FastMCP correctly detects SSE transport"""
962940
sentry_init(
963941
integrations=[MCPIntegration()],
@@ -966,25 +944,66 @@ def test_fastmcp_sse_transport(sentry_init, capture_events, FastMCP):
966944
events = capture_events()
967945

968946
mcp = FastMCP("Test Server")
947+
sse = SseServerTransport("/messages/")
969948

970-
# Set up mock request context with SSE transport
971-
if request_ctx is not None:
972-
mock_ctx = MockRequestContext(
973-
request_id="req-sse", session_id="session-sse-123", transport="sse"
974-
)
975-
request_ctx.set(mock_ctx)
949+
sse_connection_closed = asyncio.Event()
950+
951+
async def handle_sse(request):
952+
async with sse.connect_sse(
953+
request.scope, request.receive, request._send
954+
) as streams:
955+
async with anyio.create_task_group() as tg:
956+
957+
async def run_server():
958+
await mcp._mcp_server.run(
959+
streams[0],
960+
streams[1],
961+
mcp._mcp_server.create_initialization_options(),
962+
)
963+
964+
tg.start_soon(run_server)
965+
966+
sse_connection_closed.set()
967+
return Response()
968+
969+
app = Starlette(
970+
routes=[
971+
Route("/sse", endpoint=handle_sse, methods=["GET"]),
972+
Mount("/messages/", app=sse.handle_post_message),
973+
],
974+
)
976975

977976
@mcp.tool()
978977
def sse_tool(value: str) -> dict:
979978
"""Tool for SSE transport test"""
980979
return {"message": f"Received: {value}"}
981980

982-
with start_transaction(name="fastmcp tx"):
983-
result = call_tool_through_mcp(mcp, "sse_tool", {"value": "hello"})
981+
keep_sse_alive = asyncio.Event()
982+
app_task, _, result = await json_rpc_sse(
983+
app,
984+
method="tools/call",
985+
params={
986+
"name": "sse_tool",
987+
"arguments": {"value": "hello"},
988+
},
989+
request_id="req-sse",
990+
keep_sse_alive=keep_sse_alive,
991+
)
984992

985-
assert result == {"message": "Received: hello"}
993+
await sse_connection_closed.wait()
994+
await app_task
986995

987-
(tx,) = events
996+
assert json.loads(result["result"]["content"][0]["text"]) == {
997+
"message": "Received: hello"
998+
}
999+
1000+
transactions = [
1001+
event
1002+
for event in events
1003+
if event["type"] == "transaction" and event["transaction"] == "/sse"
1004+
]
1005+
assert len(transactions) == 1
1006+
tx = transactions[0]
9881007

9891008
# Find MCP spans
9901009
mcp_spans = [s for s in tx["spans"] if s["op"] == OP.MCP_SERVER]

0 commit comments

Comments
 (0)