Skip to content

Commit 2aecd41

Browse files
test(fastmcp): Use AsyncClient for SSE
1 parent 960d76c commit 2aecd41

4 files changed

Lines changed: 275 additions & 196 deletions

File tree

tests/conftest.py

Lines changed: 199 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import json
22
import os
3+
import asyncio
4+
from urllib.parse import urlparse, parse_qs
35
import socket
46
import warnings
57
import brotli
@@ -51,24 +53,33 @@
5153
from typing import TYPE_CHECKING
5254

5355
if TYPE_CHECKING:
54-
from typing import Optional
56+
from typing import Any, Callable, MutableMapping, Optional
5557
from collections.abc import Iterator
5658

5759
try:
58-
from anyio import create_memory_object_stream, create_task_group
60+
from anyio import create_memory_object_stream, create_task_group, EndOfStream
5961
from mcp.types import (
6062
JSONRPCMessage,
6163
JSONRPCNotification,
6264
JSONRPCRequest,
6365
)
6466
from mcp.shared.message import SessionMessage
67+
from httpx import ASGITransport, Request, Response, AsyncByteStream, AsyncClient
6568
except ImportError:
6669
create_memory_object_stream = None
6770
create_task_group = None
71+
EndOfStream = None
72+
6873
JSONRPCMessage = None
6974
JSONRPCRequest = None
7075
SessionMessage = None
7176

77+
ASGITransport = None
78+
Request = None
79+
Response = None
80+
AsyncByteStream = None
81+
AsyncClient = None
82+
7283

7384
SENTRY_EVENT_SCHEMA = "./checkouts/data-schemas/relay/event.schema.json"
7485

@@ -786,6 +797,192 @@ def inner(events):
786797
return inner
787798

788799

800+
@pytest.fixture()
801+
def json_rpc_sse(is_structured_content: bool = True):
802+
class StreamingASGITransport(ASGITransport):
803+
"""
804+
Simple transport whose only purpose is to keep GET request alive in SSE connections, allowing
805+
tests involving SSE interactions to run in-process.
806+
"""
807+
808+
def __init__(
809+
self,
810+
app: "Callable",
811+
keep_sse_alive: "asyncio.Event",
812+
) -> None:
813+
self.keep_sse_alive = keep_sse_alive
814+
super().__init__(app)
815+
816+
async def handle_async_request(self, request: "Request") -> "Response":
817+
scope = {
818+
"type": "http",
819+
"method": request.method,
820+
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
821+
"path": request.url.path,
822+
"query_string": request.url.query,
823+
}
824+
825+
is_streaming_sse = scope["method"] == "GET" and scope["path"] == "/sse"
826+
if not is_streaming_sse:
827+
return await super().handle_async_request(request)
828+
829+
request_body = b""
830+
if request.content:
831+
request_body = await request.aread()
832+
833+
body_sender, body_receiver = create_memory_object_stream[bytes](0) # type: ignore
834+
835+
async def receive() -> "dict[str, Any]":
836+
if self.keep_sse_alive.is_set():
837+
return {"type": "http.disconnect"}
838+
839+
await self.keep_sse_alive.wait() # Keep alive :)
840+
return {
841+
"type": "http.request",
842+
"body": request_body,
843+
"more_body": False,
844+
}
845+
846+
async def send(message: "MutableMapping[str, Any]") -> None:
847+
if message["type"] == "http.response.body":
848+
body = message.get("body", b"")
849+
more_body = message.get("more_body", False)
850+
851+
if body == b"" and not more_body:
852+
return
853+
854+
if body:
855+
await body_sender.send(body)
856+
857+
if not more_body:
858+
await body_sender.aclose()
859+
860+
async def run_app():
861+
await self.app(scope, receive, send)
862+
863+
class StreamingBodyStream(AsyncByteStream): # type: ignore
864+
def __init__(self, receiver, task):
865+
self.receiver = receiver
866+
self.task = task
867+
868+
async def __aiter__(self):
869+
try:
870+
async for chunk in self.receiver:
871+
yield chunk
872+
except EndOfStream: # type: ignore
873+
pass
874+
875+
stream = StreamingBodyStream(body_receiver, asyncio.create_task(run_app()))
876+
response = Response(status_code=200, headers=[], stream=stream) # type: ignore
877+
878+
return response
879+
880+
def parse_sse_data_package(sse_chunk):
881+
sse_text = sse_chunk.decode("utf-8")
882+
json_str = sse_text.split("data: ")[1]
883+
return json.loads(json_str)
884+
885+
async def inner(
886+
app, method: str, params, request_id: str, keep_sse_alive: "asyncio.Event"
887+
):
888+
context = {}
889+
890+
stream_complete = asyncio.Event()
891+
endpoint_parsed = asyncio.Event()
892+
893+
# https://github.com/Kludex/starlette/issues/104#issuecomment-729087925
894+
async with AsyncClient( # type: ignore
895+
transport=StreamingASGITransport(app=app, keep_sse_alive=keep_sse_alive),
896+
base_url="http://test",
897+
) as client:
898+
899+
async def parse_stream():
900+
async with client.stream("GET", "/sse") as stream:
901+
# Read directly from stream.stream instead of aiter_bytes()
902+
async for chunk in stream.stream:
903+
if b"event: endpoint" in chunk:
904+
sse_text = chunk.decode("utf-8")
905+
url = sse_text.split("data: ")[1]
906+
907+
parsed = urlparse(url)
908+
query_params = parse_qs(parsed.query)
909+
context["session_id"] = query_params["session_id"][0]
910+
endpoint_parsed.set()
911+
continue
912+
913+
if (
914+
is_structured_content
915+
and b"event: message" in chunk
916+
and b"structuredContent" in chunk
917+
):
918+
context["response"] = parse_sse_data_package(chunk)
919+
stream_complete.set()
920+
break
921+
elif (
922+
"result" in parse_sse_data_package(chunk)
923+
and "content" in parse_sse_data_package(chunk)["result"]
924+
):
925+
context["response"] = parse_sse_data_package(chunk)
926+
stream_complete.set()
927+
break
928+
929+
task = asyncio.create_task(parse_stream())
930+
await endpoint_parsed.wait()
931+
932+
await client.post(
933+
f"/messages/?session_id={context['session_id']}",
934+
headers={
935+
"Content-Type": "application/json",
936+
},
937+
json={
938+
"jsonrpc": "2.0",
939+
"method": "initialize",
940+
"params": {
941+
"clientInfo": {"name": "test-client", "version": "1.0"},
942+
"protocolVersion": "2025-11-25",
943+
"capabilities": {},
944+
},
945+
"id": request_id,
946+
},
947+
)
948+
949+
# Notification response is mandatory.
950+
# https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle
951+
await client.post(
952+
f"/messages/?session_id={context['session_id']}",
953+
headers={
954+
"Content-Type": "application/json",
955+
"mcp-session-id": context["session_id"],
956+
},
957+
json={
958+
"jsonrpc": "2.0",
959+
"method": "notifications/initialized",
960+
"params": {},
961+
},
962+
)
963+
964+
await client.post(
965+
f"/messages/?session_id={context['session_id']}",
966+
headers={
967+
"Content-Type": "application/json",
968+
"mcp-session-id": context["session_id"],
969+
},
970+
json={
971+
"jsonrpc": "2.0",
972+
"method": method,
973+
"params": params,
974+
"id": request_id,
975+
},
976+
)
977+
978+
await stream_complete.wait()
979+
keep_sse_alive.set()
980+
981+
return task, context["session_id"], context["response"]
982+
983+
return inner
984+
985+
789986
class MockServerRequestHandler(BaseHTTPRequestHandler):
790987
def do_GET(self): # noqa: N802
791988
# Process an HTTP GET request and return a response.

tests/integrations/fastmcp/test_fastmcp.py

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
accurate testing of the integration's behavior in real MCP Server scenarios.
2222
"""
2323

24+
import anyio
2425
import asyncio
2526
import json
2627
import pytest
@@ -39,9 +40,12 @@ async def __call__(self, *args, **kwargs):
3940
from sentry_sdk.consts import SPANDATA, OP
4041
from sentry_sdk.integrations.mcp import MCPIntegration
4142

43+
from mcp.server.lowlevel import Server
44+
from mcp.server.sse import SseServerTransport
4245
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
4346

44-
from starlette.routing import Mount
47+
from starlette.routing import Mount, Route
48+
from starlette.responses import Response
4549
from starlette.applications import Starlette
4650

4751
# Try to import both FastMCP implementations
@@ -1029,8 +1033,11 @@ def test_tool_no_ctx(x: int) -> dict:
10291033
# =============================================================================
10301034

10311035

1036+
@pytest.mark.asyncio
10321037
@pytest.mark.parametrize("FastMCP", fastmcp_implementations, ids=fastmcp_ids)
1033-
def test_fastmcp_sse_transport(sentry_init, capture_events, FastMCP):
1038+
async def test_fastmcp_sse_transport(
1039+
sentry_init, capture_events, FastMCP, json_rpc_sse
1040+
):
10341041
"""Test that FastMCP correctly detects SSE transport"""
10351042
sentry_init(
10361043
integrations=[MCPIntegration()],
@@ -1039,25 +1046,81 @@ def test_fastmcp_sse_transport(sentry_init, capture_events, FastMCP):
10391046
events = capture_events()
10401047

10411048
mcp = FastMCP("Test Server")
1049+
sse = SseServerTransport("/messages/")
10421050

1043-
# Set up mock request context with SSE transport
1044-
if request_ctx is not None:
1045-
mock_ctx = MockRequestContext(
1046-
request_id="req-sse", session_id="session-sse-123", transport="sse"
1047-
)
1048-
request_ctx.set(mock_ctx)
1051+
sse_connection_closed = asyncio.Event()
1052+
1053+
async def handle_sse(request):
1054+
async with sse.connect_sse(
1055+
request.scope, request.receive, request._send
1056+
) as streams:
1057+
async with anyio.create_task_group() as tg:
1058+
1059+
async def run_server():
1060+
await mcp._mcp_server.run(
1061+
streams[0],
1062+
streams[1],
1063+
mcp._mcp_server.create_initialization_options(),
1064+
)
1065+
1066+
tg.start_soon(run_server)
1067+
1068+
sse_connection_closed.set()
1069+
return Response()
1070+
1071+
app = Starlette(
1072+
routes=[
1073+
Route("/sse", endpoint=handle_sse, methods=["GET"]),
1074+
Mount("/messages/", app=sse.handle_post_message),
1075+
],
1076+
)
10491077

10501078
@mcp.tool()
10511079
def sse_tool(value: str) -> dict:
10521080
"""Tool for SSE transport test"""
10531081
return {"message": f"Received: {value}"}
10541082

1055-
with start_transaction(name="fastmcp tx"):
1056-
result = call_tool_through_mcp(mcp, "sse_tool", {"value": "hello"})
1083+
keep_sse_alive = asyncio.Event()
1084+
app_task, _, result = await json_rpc_sse(
1085+
app,
1086+
method="tools/call",
1087+
params={
1088+
"name": "sse_tool",
1089+
"arguments": {"value": "hello"},
1090+
},
1091+
request_id="req-sse",
1092+
keep_sse_alive=keep_sse_alive,
1093+
)
10571094

1058-
assert result == {"message": "Received: hello"}
1095+
await sse_connection_closed.wait()
1096+
await app_task
10591097

1060-
(tx,) = events
1098+
if (
1099+
isinstance(mcp, StandaloneFastMCP)
1100+
and FASTMCP_VERSION is not None
1101+
and FASTMCP_VERSION.startswith("2")
1102+
):
1103+
assert result["result"]["content"][0]["text"] == json.dumps(
1104+
{"message": "Received: hello"}, separators=(",", ":")
1105+
)
1106+
elif (
1107+
isinstance(mcp, StandaloneFastMCP) and FASTMCP_VERSION is not None
1108+
): # Checking for None is not precise.
1109+
assert result["result"]["content"][0]["text"] == json.dumps(
1110+
{"message": "Received: hello"}
1111+
)
1112+
else:
1113+
assert result["result"]["content"][0]["text"] == json.dumps(
1114+
{"message": "Received: hello"}, indent=2
1115+
)
1116+
1117+
transactions = [
1118+
event
1119+
for event in events
1120+
if event["type"] == "transaction" and event["transaction"] == "/sse"
1121+
]
1122+
assert len(transactions) == 1
1123+
tx = transactions[0]
10611124

10621125
# Find MCP spans
10631126
mcp_spans = [s for s in tx["spans"] if s["op"] == OP.MCP_SERVER]

0 commit comments

Comments
 (0)