|
1 | 1 | import json |
2 | 2 | import os |
| 3 | +import asyncio |
| 4 | +from urllib.parse import urlparse, parse_qs |
3 | 5 | import socket |
4 | 6 | import warnings |
5 | 7 | import brotli |
|
51 | 53 | from typing import TYPE_CHECKING |
52 | 54 |
|
53 | 55 | if TYPE_CHECKING: |
54 | | - from typing import Optional |
| 56 | + from typing import Any, Callable, MutableMapping, Optional |
55 | 57 | from collections.abc import Iterator |
56 | 58 |
|
57 | 59 | try: |
58 | | - from anyio import create_memory_object_stream, create_task_group |
| 60 | + from anyio import create_memory_object_stream, create_task_group, EndOfStream |
59 | 61 | from mcp.types import ( |
60 | 62 | JSONRPCMessage, |
61 | 63 | JSONRPCNotification, |
62 | 64 | JSONRPCRequest, |
63 | 65 | ) |
64 | 66 | from mcp.shared.message import SessionMessage |
| 67 | + from httpx import ( |
| 68 | + ASGITransport, |
| 69 | + Request as HttpxRequest, |
| 70 | + Response as HttpxResponse, |
| 71 | + AsyncByteStream, |
| 72 | + AsyncClient, |
| 73 | + ) |
65 | 74 | except ImportError: |
66 | 75 | create_memory_object_stream = None |
67 | 76 | create_task_group = None |
| 77 | + EndOfStream = None |
| 78 | + |
68 | 79 | JSONRPCMessage = None |
69 | 80 | JSONRPCNotification = None |
70 | 81 | JSONRPCRequest = None |
71 | 82 | SessionMessage = None |
72 | 83 |
|
| 84 | + ASGITransport = None |
| 85 | + HttpxRequest = None |
| 86 | + HttpxResponse = None |
| 87 | + AsyncByteStream = None |
| 88 | + AsyncClient = None |
| 89 | + |
73 | 90 |
|
74 | 91 | SENTRY_EVENT_SCHEMA = "./checkouts/data-schemas/relay/event.schema.json" |
75 | 92 |
|
@@ -787,6 +804,190 @@ def inner(events): |
787 | 804 | return inner |
788 | 805 |
|
789 | 806 |
|
| 807 | +@pytest.fixture() |
| 808 | +def json_rpc_sse(): |
| 809 | + class StreamingASGITransport(ASGITransport): |
| 810 | + """ |
| 811 | + Simple transport whose only purpose is to keep GET request alive in SSE connections, allowing |
| 812 | + tests involving SSE interactions to run in-process. |
| 813 | + """ |
| 814 | + |
| 815 | + def __init__( |
| 816 | + self, |
| 817 | + app: "Callable", |
| 818 | + keep_sse_alive: "asyncio.Event", |
| 819 | + ) -> None: |
| 820 | + self.keep_sse_alive = keep_sse_alive |
| 821 | + super().__init__(app) |
| 822 | + |
| 823 | + async def handle_async_request( |
| 824 | + self, request: "HttpxRequest" |
| 825 | + ) -> "HttpxResponse": |
| 826 | + scope = { |
| 827 | + "type": "http", |
| 828 | + "method": request.method, |
| 829 | + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], |
| 830 | + "path": request.url.path, |
| 831 | + "query_string": request.url.query, |
| 832 | + } |
| 833 | + |
| 834 | + is_streaming_sse = scope["method"] == "GET" and scope["path"] == "/sse" |
| 835 | + if not is_streaming_sse: |
| 836 | + return await super().handle_async_request(request) |
| 837 | + |
| 838 | + request_body = b"" |
| 839 | + if request.content: |
| 840 | + request_body = await request.aread() |
| 841 | + |
| 842 | + body_sender, body_receiver = create_memory_object_stream[bytes](0) # type: ignore |
| 843 | + |
| 844 | + async def receive() -> "dict[str, Any]": |
| 845 | + if self.keep_sse_alive.is_set(): |
| 846 | + return {"type": "http.disconnect"} |
| 847 | + |
| 848 | + await self.keep_sse_alive.wait() # Keep alive :) |
| 849 | + return { |
| 850 | + "type": "http.request", |
| 851 | + "body": request_body, |
| 852 | + "more_body": False, |
| 853 | + } |
| 854 | + |
| 855 | + async def send(message: "MutableMapping[str, Any]") -> None: |
| 856 | + if message["type"] == "http.response.body": |
| 857 | + body = message.get("body", b"") |
| 858 | + more_body = message.get("more_body", False) |
| 859 | + |
| 860 | + if body == b"" and not more_body: |
| 861 | + return |
| 862 | + |
| 863 | + if body: |
| 864 | + await body_sender.send(body) |
| 865 | + |
| 866 | + if not more_body: |
| 867 | + await body_sender.aclose() |
| 868 | + |
| 869 | + async def run_app(): |
| 870 | + await self.app(scope, receive, send) |
| 871 | + |
| 872 | + class StreamingBodyStream(AsyncByteStream): # type: ignore |
| 873 | + def __init__(self, receiver): |
| 874 | + self.receiver = receiver |
| 875 | + |
| 876 | + async def __aiter__(self): |
| 877 | + try: |
| 878 | + async for chunk in self.receiver: |
| 879 | + yield chunk |
| 880 | + except EndOfStream: # type: ignore |
| 881 | + pass |
| 882 | + |
| 883 | + stream = StreamingBodyStream(body_receiver) |
| 884 | + response = HttpxResponse(status_code=200, headers=[], stream=stream) # type: ignore |
| 885 | + |
| 886 | + asyncio.create_task(run_app()) |
| 887 | + return response |
| 888 | + |
| 889 | + def parse_sse_data_package(sse_chunk): |
| 890 | + sse_text = sse_chunk.decode("utf-8") |
| 891 | + json_str = sse_text.split("data: ")[1] |
| 892 | + return json.loads(json_str) |
| 893 | + |
| 894 | + async def inner( |
| 895 | + app, method: str, params, request_id: str, keep_sse_alive: "asyncio.Event" |
| 896 | + ): |
| 897 | + context = {} |
| 898 | + |
| 899 | + stream_complete = asyncio.Event() |
| 900 | + endpoint_parsed = asyncio.Event() |
| 901 | + |
| 902 | + # https://github.com/Kludex/starlette/issues/104#issuecomment-729087925 |
| 903 | + async with AsyncClient( # type: ignore |
| 904 | + transport=StreamingASGITransport(app=app, keep_sse_alive=keep_sse_alive), |
| 905 | + base_url="http://test", |
| 906 | + ) as client: |
| 907 | + |
| 908 | + async def parse_stream(): |
| 909 | + async with client.stream("GET", "/sse") as stream: |
| 910 | + # Read directly from stream.stream instead of aiter_bytes() |
| 911 | + async for chunk in stream.stream: |
| 912 | + if b"event: endpoint" in chunk: |
| 913 | + sse_text = chunk.decode("utf-8") |
| 914 | + url = sse_text.split("data: ")[1] |
| 915 | + |
| 916 | + parsed = urlparse(url) |
| 917 | + query_params = parse_qs(parsed.query) |
| 918 | + context["session_id"] = query_params["session_id"][0] |
| 919 | + endpoint_parsed.set() |
| 920 | + continue |
| 921 | + |
| 922 | + if b"event: message" in chunk and b"structuredContent" in chunk: |
| 923 | + context["response"] = parse_sse_data_package(chunk) |
| 924 | + break |
| 925 | + elif ( |
| 926 | + "result" in parse_sse_data_package(chunk) |
| 927 | + and "content" in parse_sse_data_package(chunk)["result"] |
| 928 | + ): |
| 929 | + context["response"] = parse_sse_data_package(chunk) |
| 930 | + break |
| 931 | + |
| 932 | + stream_complete.set() |
| 933 | + |
| 934 | + task = asyncio.create_task(parse_stream()) |
| 935 | + await endpoint_parsed.wait() |
| 936 | + |
| 937 | + await client.post( |
| 938 | + f"/messages/?session_id={context['session_id']}", |
| 939 | + headers={ |
| 940 | + "Content-Type": "application/json", |
| 941 | + }, |
| 942 | + json={ |
| 943 | + "jsonrpc": "2.0", |
| 944 | + "method": "initialize", |
| 945 | + "params": { |
| 946 | + "clientInfo": {"name": "test-client", "version": "1.0"}, |
| 947 | + "protocolVersion": "2025-11-25", |
| 948 | + "capabilities": {}, |
| 949 | + }, |
| 950 | + "id": request_id, |
| 951 | + }, |
| 952 | + ) |
| 953 | + |
| 954 | + # Notification response is mandatory. |
| 955 | + # https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle |
| 956 | + await client.post( |
| 957 | + f"/messages/?session_id={context['session_id']}", |
| 958 | + headers={ |
| 959 | + "Content-Type": "application/json", |
| 960 | + "mcp-session-id": context["session_id"], |
| 961 | + }, |
| 962 | + json={ |
| 963 | + "jsonrpc": "2.0", |
| 964 | + "method": "notifications/initialized", |
| 965 | + "params": {}, |
| 966 | + }, |
| 967 | + ) |
| 968 | + |
| 969 | + await client.post( |
| 970 | + f"/messages/?session_id={context['session_id']}", |
| 971 | + headers={ |
| 972 | + "Content-Type": "application/json", |
| 973 | + "mcp-session-id": context["session_id"], |
| 974 | + }, |
| 975 | + json={ |
| 976 | + "jsonrpc": "2.0", |
| 977 | + "method": method, |
| 978 | + "params": params, |
| 979 | + "id": request_id, |
| 980 | + }, |
| 981 | + ) |
| 982 | + |
| 983 | + await stream_complete.wait() |
| 984 | + keep_sse_alive.set() |
| 985 | + |
| 986 | + return task, context["session_id"], context["response"] |
| 987 | + |
| 988 | + return inner |
| 989 | + |
| 990 | + |
790 | 991 | class MockServerRequestHandler(BaseHTTPRequestHandler): |
791 | 992 | def do_GET(self): # noqa: N802 |
792 | 993 | # Process an HTTP GET request and return a response. |
|
0 commit comments