-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathsse.py
More file actions
221 lines (187 loc) · 10.1 KB
/
sse.py
File metadata and controls
221 lines (187 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import logging
from collections.abc import Callable
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import parse_qs, urljoin, urlparse, urlunparse
import anyio
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
from httpx_sse._exceptions import SSEError
from mcp import types
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__)
def remove_request_params(url: str) -> str:
    """Return ``url`` with its query string and fragment removed.

    ``sse_client`` passes the connection URL through this function before
    logging it, so request parameters do not end up in the debug log.

    Args:
        url: The URL to strip.

    Returns:
        The URL reduced to scheme, authority and path.
    """
    parsed = urlparse(url)
    if not parsed.path:
        # urljoin(url, "") returns ``url`` unchanged, which would keep the
        # query string. Rebuild the URL from scheme and authority instead.
        return urlunparse((parsed.scheme, parsed.netloc, "", "", "", ""))
    return urljoin(url, parsed.path)
def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None:
query_params = parse_qs(urlparse(endpoint_url).query)
return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0]
def _resolve_endpoint_url(base_url: str, endpoint: str) -> str:
    """Resolve an endpoint URL, preserving any reverse proxy/API gateway path prefix.

    When an MCP server sits behind a reverse proxy or API gateway that adds a
    path prefix (e.g., ``/gateway``), the server's endpoint events contain paths
    without that prefix. Standard ``urljoin`` drops the base URL's path prefix
    for absolute paths (starting with ``/``). This function detects and
    preserves such prefixes.

    Args:
        base_url: The URL the SSE connection was opened against.
        endpoint: The endpoint value to resolve: a full URL, a network-path
            reference (``//host/path``), an absolute path, or a relative path.

    Returns:
        The resolved endpoint URL. When a gateway prefix is detected, the
        result is rebuilt from the base URL's scheme and authority, the
        prefix plus the endpoint path, and the endpoint's query string.

    Example::

        >>> _resolve_endpoint_url(
        ...     "https://host/gateway/v1/sse",
        ...     "/v1/messages/?session_id=abc",
        ... )
        'https://host/gateway/v1/messages/?session_id=abc'
    """
    parsed_ep = urlparse(endpoint)

    # Full URL — use as-is
    if parsed_ep.scheme:
        return endpoint

    # Relative path (no leading /) — urljoin handles correctly.
    # A network-path reference ("//host/path") also starts with "/" but
    # carries its own authority; it must not be treated as an absolute path,
    # otherwise the prefix logic below would silently drop that authority.
    if parsed_ep.netloc or not endpoint.startswith("/"):
        return urljoin(base_url, endpoint)

    # For absolute paths, detect and preserve any gateway prefix.
    # Strategy: find the first path segment of the endpoint inside the base URL
    # path. If it appears at a position > 0, everything before it is the
    # gateway prefix that must be preserved.
    parsed_base = urlparse(base_url)
    base_path = parsed_base.path
    ep_path = parsed_ep.path
    ep_segments = [s for s in ep_path.split("/") if s]
    if ep_segments:
        first_seg = "/" + ep_segments[0]
        # Match the segment followed by "/" so only whole segments count ...
        idx = base_path.find(first_seg + "/")
        # ... or as the final segment of the base path.
        if idx < 0 and base_path.endswith(first_seg):
            idx = len(base_path) - len(first_seg)
        if idx > 0:
            prefix = base_path[:idx]
            return urlunparse(
                (
                    parsed_base.scheme,
                    parsed_base.netloc,
                    prefix + ep_path,
                    "",
                    parsed_ep.query,
                    "",
                )
            )

    # No prefix detected — fall back to standard resolution
    return urljoin(base_url, endpoint)
@asynccontextmanager
async def sse_client(
    url: str,
    headers: dict[str, Any] | None = None,
    timeout: float = 5.0,
    sse_read_timeout: float = 300.0,
    httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
    auth: httpx.Auth | None = None,
    on_session_created: Callable[[str], None] | None = None,
):
    """Client transport for SSE.

    Opens a GET SSE connection to `url`, waits for the server's "endpoint"
    event, and then runs two background tasks in a task group: one that
    forwards "message" events into the read stream, and one that POSTs every
    message taken from the write stream to the resolved endpoint URL.

    `sse_read_timeout` determines how long (in seconds) the client will wait for a new
    event before disconnecting. All other HTTP operations are controlled by `timeout`.

    Args:
        url: The SSE endpoint URL.
        headers: Optional headers to include in requests.
        timeout: HTTP timeout for regular operations (in seconds).
        sse_read_timeout: Timeout for SSE read operations (in seconds).
        httpx_client_factory: Factory function for creating the HTTPX client.
        auth: Optional HTTPX authentication handler.
        on_session_created: Optional callback invoked with the session ID when received.

    Yields:
        A ``(read_stream, write_stream)`` tuple. ``read_stream`` delivers
        ``SessionMessage`` objects parsed from server events, or ``Exception``
        instances when parsing or reading fails. ``write_stream`` accepts
        ``SessionMessage`` objects to be sent to the server.
    """
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
    read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

    write_stream: MemoryObjectSendStream[SessionMessage]
    write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

    # Both in-memory streams are created with a buffer size argument of 0.
    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

    async with anyio.create_task_group() as tg:
        try:
            # Log the URL without its request parameters.
            logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
            async with httpx_client_factory(
                headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
            ) as client:
                async with aconnect_sse(
                    client,
                    "GET",
                    url,
                ) as event_source:
                    # Check the status of the initial SSE response before reading events.
                    event_source.response.raise_for_status()
                    logger.debug("SSE connection established")

                    async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
                        # Consumes server-sent events. Reports the resolved endpoint URL
                        # through task_status.started() and forwards parsed messages
                        # (or exceptions) to read_stream_writer.
                        try:
                            async for sse in event_source.aiter_sse():  # pragma: no branch
                                logger.debug(f"Received SSE event: {sse.event}")
                                match sse.event:
                                    case "endpoint":
                                        endpoint_url = _resolve_endpoint_url(url, sse.data)
                                        logger.debug(f"Received endpoint URL: {endpoint_url}")

                                        # Reject endpoints whose scheme or netloc differs from
                                        # the connection URL.
                                        # NOTE(review): this ValueError is raised inside the
                                        # enclosing try, so the `except Exception` clause below
                                        # handles it by sending it to read_stream_writer. This
                                        # happens before task_status.started() is called; confirm
                                        # that is the intended failure path.
                                        url_parsed = urlparse(url)
                                        endpoint_parsed = urlparse(endpoint_url)
                                        if (  # pragma: no cover
                                            url_parsed.netloc != endpoint_parsed.netloc
                                            or url_parsed.scheme != endpoint_parsed.scheme
                                        ):
                                            error_msg = (  # pragma: no cover
                                                f"Endpoint origin does not match connection origin: {endpoint_url}"
                                            )
                                            logger.error(error_msg)  # pragma: no cover
                                            raise ValueError(error_msg)  # pragma: no cover

                                        # Notify the caller of the session ID, if the endpoint
                                        # URL carries one.
                                        if on_session_created:
                                            session_id = _extract_session_id_from_endpoint(endpoint_url)
                                            if session_id:
                                                on_session_created(session_id)

                                        # Hand the endpoint URL back to the `await tg.start(...)`
                                        # call below.
                                        task_status.started(endpoint_url)

                                    case "message":
                                        # Skip empty data (keep-alive pings)
                                        if not sse.data:
                                            continue
                                        try:
                                            message = types.jsonrpc_message_adapter.validate_json(
                                                sse.data, by_name=False
                                            )
                                            logger.debug(f"Received server message: {message}")
                                        except Exception as exc:  # pragma: no cover
                                            # A message that fails validation is forwarded as an
                                            # exception; reading continues with the next event.
                                            logger.exception("Error parsing server message")  # pragma: no cover
                                            await read_stream_writer.send(exc)  # pragma: no cover
                                            continue  # pragma: no cover

                                        session_message = SessionMessage(message)
                                        await read_stream_writer.send(session_message)
                                    case _:  # pragma: no cover
                                        logger.warning(f"Unknown SSE event: {sse.event}")  # pragma: no cover
                        except SSEError as sse_exc:  # pragma: lax no cover
                            # SSE protocol errors are logged and re-raised rather than
                            # forwarded on the read stream.
                            logger.exception("Encountered SSE exception")
                            raise sse_exc
                        except Exception as exc:  # pragma: lax no cover
                            # Any other error is forwarded to the consumer of read_stream.
                            logger.exception("Error in sse_reader")
                            await read_stream_writer.send(exc)
                        finally:
                            await read_stream_writer.aclose()

                    async def post_writer(endpoint_url: str):
                        # POSTs each outgoing message from write_stream_reader to the
                        # endpoint URL as JSON. Errors are logged, not propagated.
                        try:
                            async with write_stream_reader:
                                async for session_message in write_stream_reader:
                                    logger.debug(f"Sending client message: {session_message}")
                                    response = await client.post(
                                        endpoint_url,
                                        json=session_message.message.model_dump(
                                            by_alias=True,
                                            mode="json",
                                            exclude_unset=True,
                                        ),
                                    )
                                    response.raise_for_status()
                                    logger.debug(f"Client message sent successfully: {response.status_code}")
                        except Exception:  # pragma: lax no cover
                            logger.exception("Error in post_writer")
                        finally:
                            await write_stream.aclose()

                    # Start the reader and wait for the value it passes to
                    # task_status.started(), i.e. the resolved endpoint URL.
                    endpoint_url = await tg.start(sse_reader)
                    logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
                    tg.start_soon(post_writer, endpoint_url)

                    try:
                        yield read_stream, write_stream
                    finally:
                        # Cancel the background tasks once the caller leaves the context.
                        tg.cancel_scope.cancel()
        finally:
            # Close the sending ends of both streams on every exit path.
            await read_stream_writer.aclose()
            await write_stream.aclose()