Skip to content

Commit 3cc16ef

Browse files
committed
refactor: connect-first stream lifecycle for sse and streamable_http
Apply the websocket_client pattern from #2266 to the other two transports: establish the network connection first, create memory streams only after it succeeds, then own all four stream ends plus the task group in a single merged async with as the innermost scope. This eliminates the try/finally + four explicit aclose() calls. If the connection fails, no streams were ever created — nothing to clean up. The multi-CM async with unwinds in reverse order on exit, so tg.__aexit__ waits for cancelled tasks to finish before any stream end closes. streamable_http has one outer async with (the AsyncExitStack for the conditional httpx client), which is clean on all Python versions. sse has two unavoidable outer layers (httpx_client_factory feeds into aconnect_sse — data dependency, can't merge). On 3.14, coverage.py's static analysis sees a phantom branch on the innermost multi-CM line: each __aexit__ gets a POP_JUMP_IF_TRUE for 'did it suppress the exception?', which memory streams never do. One targeted pragma on the line we own, documented inline. Behavior change: sse_client's ConnectError is no longer wrapped in an ExceptionGroup, since the task group is never entered when the connection fails. Updated the regression test to match.
1 parent e1fd62e commit 3cc16ef

File tree

3 files changed

+144
-150
lines changed

3 files changed

+144
-150
lines changed

src/mcp/client/sse.py

Lines changed: 103 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -57,108 +57,107 @@ async def sse_client(
5757
write_stream: MemoryObjectSendStream[SessionMessage]
5858
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
5959

60-
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
61-
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
62-
63-
async with anyio.create_task_group() as tg:
64-
try:
65-
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
66-
async with httpx_client_factory(
67-
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
68-
) as client:
69-
async with aconnect_sse(
70-
client,
71-
"GET",
72-
url,
73-
) as event_source:
74-
event_source.response.raise_for_status()
75-
logger.debug("SSE connection established")
76-
77-
async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
78-
try:
79-
async for sse in event_source.aiter_sse(): # pragma: no branch
80-
logger.debug(f"Received SSE event: {sse.event}")
81-
match sse.event:
82-
case "endpoint":
83-
endpoint_url = urljoin(url, sse.data)
84-
logger.debug(f"Received endpoint URL: {endpoint_url}")
85-
86-
url_parsed = urlparse(url)
87-
endpoint_parsed = urlparse(endpoint_url)
88-
if ( # pragma: no cover
89-
url_parsed.netloc != endpoint_parsed.netloc
90-
or url_parsed.scheme != endpoint_parsed.scheme
91-
):
92-
error_msg = ( # pragma: no cover
93-
f"Endpoint origin does not match connection origin: {endpoint_url}"
94-
)
95-
logger.error(error_msg) # pragma: no cover
96-
raise ValueError(error_msg) # pragma: no cover
97-
98-
if on_session_created:
99-
session_id = _extract_session_id_from_endpoint(endpoint_url)
100-
if session_id:
101-
on_session_created(session_id)
102-
103-
task_status.started(endpoint_url)
104-
105-
case "message":
106-
# Skip empty data (keep-alive pings)
107-
if not sse.data:
108-
continue
109-
try:
110-
message = types.jsonrpc_message_adapter.validate_json(
111-
sse.data, by_name=False
112-
)
113-
logger.debug(f"Received server message: {message}")
114-
except Exception as exc: # pragma: no cover
115-
logger.exception("Error parsing server message") # pragma: no cover
116-
await read_stream_writer.send(exc) # pragma: no cover
117-
continue # pragma: no cover
118-
119-
session_message = SessionMessage(message)
120-
await read_stream_writer.send(session_message)
121-
case _: # pragma: no cover
122-
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
123-
except SSEError as sse_exc: # pragma: lax no cover
124-
logger.exception("Encountered SSE exception")
125-
raise sse_exc
126-
except Exception as exc: # pragma: lax no cover
127-
logger.exception("Error in sse_reader")
128-
await read_stream_writer.send(exc)
129-
finally:
130-
await read_stream_writer.aclose()
131-
132-
async def post_writer(endpoint_url: str):
133-
try:
134-
async with write_stream_reader:
135-
async for session_message in write_stream_reader:
136-
logger.debug(f"Sending client message: {session_message}")
137-
response = await client.post(
138-
endpoint_url,
139-
json=session_message.message.model_dump(
140-
by_alias=True,
141-
mode="json",
142-
exclude_unset=True,
143-
),
60+
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
61+
async with httpx_client_factory(
62+
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
63+
) as client:
64+
async with aconnect_sse(
65+
client,
66+
"GET",
67+
url,
68+
) as event_source:
69+
event_source.response.raise_for_status()
70+
logger.debug("SSE connection established")
71+
72+
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
73+
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
74+
75+
async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
76+
try:
77+
async for sse in event_source.aiter_sse(): # pragma: no branch
78+
logger.debug(f"Received SSE event: {sse.event}")
79+
match sse.event:
80+
case "endpoint":
81+
endpoint_url = urljoin(url, sse.data)
82+
logger.debug(f"Received endpoint URL: {endpoint_url}")
83+
84+
url_parsed = urlparse(url)
85+
endpoint_parsed = urlparse(endpoint_url)
86+
if ( # pragma: no cover
87+
url_parsed.netloc != endpoint_parsed.netloc
88+
or url_parsed.scheme != endpoint_parsed.scheme
89+
):
90+
error_msg = ( # pragma: no cover
91+
f"Endpoint origin does not match connection origin: {endpoint_url}"
14492
)
145-
response.raise_for_status()
146-
logger.debug(f"Client message sent successfully: {response.status_code}")
147-
except Exception: # pragma: lax no cover
148-
logger.exception("Error in post_writer")
149-
finally:
150-
await write_stream.aclose()
151-
152-
endpoint_url = await tg.start(sse_reader)
153-
logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
154-
tg.start_soon(post_writer, endpoint_url)
155-
156-
try:
157-
yield read_stream, write_stream
158-
finally:
159-
tg.cancel_scope.cancel()
160-
finally:
161-
await read_stream_writer.aclose()
162-
await write_stream.aclose()
163-
await read_stream.aclose()
164-
await write_stream_reader.aclose()
93+
logger.error(error_msg) # pragma: no cover
94+
raise ValueError(error_msg) # pragma: no cover
95+
96+
if on_session_created:
97+
session_id = _extract_session_id_from_endpoint(endpoint_url)
98+
if session_id:
99+
on_session_created(session_id)
100+
101+
task_status.started(endpoint_url)
102+
103+
case "message":
104+
# Skip empty data (keep-alive pings)
105+
if not sse.data:
106+
continue
107+
try:
108+
message = types.jsonrpc_message_adapter.validate_json(sse.data, by_name=False)
109+
logger.debug(f"Received server message: {message}")
110+
except Exception as exc: # pragma: no cover
111+
logger.exception("Error parsing server message") # pragma: no cover
112+
await read_stream_writer.send(exc) # pragma: no cover
113+
continue # pragma: no cover
114+
115+
session_message = SessionMessage(message)
116+
await read_stream_writer.send(session_message)
117+
case _: # pragma: no cover
118+
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
119+
except SSEError as sse_exc: # pragma: lax no cover
120+
logger.exception("Encountered SSE exception")
121+
raise sse_exc
122+
except Exception as exc: # pragma: lax no cover
123+
logger.exception("Error in sse_reader")
124+
await read_stream_writer.send(exc)
125+
finally:
126+
await read_stream_writer.aclose()
127+
128+
async def post_writer(endpoint_url: str):
129+
try:
130+
async with write_stream_reader:
131+
async for session_message in write_stream_reader:
132+
logger.debug(f"Sending client message: {session_message}")
133+
response = await client.post(
134+
endpoint_url,
135+
json=session_message.message.model_dump(
136+
by_alias=True,
137+
mode="json",
138+
exclude_unset=True,
139+
),
140+
)
141+
response.raise_for_status()
142+
logger.debug(f"Client message sent successfully: {response.status_code}")
143+
except Exception: # pragma: lax no cover
144+
logger.exception("Error in post_writer")
145+
finally:
146+
await write_stream.aclose()
147+
148+
# On Python 3.14, coverage.py reports a phantom branch arc on this
149+
# line (->yield) when nested two async-with levels deep. The branch
150+
# is the unreachable "did __aexit__ suppress?" arm for memory streams.
151+
async with ( # pragma: no branch
152+
read_stream_writer,
153+
read_stream,
154+
write_stream,
155+
write_stream_reader,
156+
anyio.create_task_group() as tg,
157+
):
158+
endpoint_url = await tg.start(sse_reader)
159+
logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
160+
tg.start_soon(post_writer, endpoint_url)
161+
162+
yield read_stream, write_stream
163+
tg.cancel_scope.cancel()

src/mcp/client/streamable_http.py

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -533,9 +533,6 @@ async def streamable_http_client(
533533
Example:
534534
See examples/snippets/clients/ for usage patterns.
535535
"""
536-
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
537-
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
538-
539536
# Determine if we need to create and manage the client
540537
client_provided = http_client is not None
541538
client = http_client
@@ -546,36 +543,40 @@ async def streamable_http_client(
546543

547544
transport = StreamableHTTPTransport(url)
548545

549-
async with anyio.create_task_group() as tg:
550-
try:
551-
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
552-
553-
async with contextlib.AsyncExitStack() as stack:
554-
# Only manage client lifecycle if we created it
555-
if not client_provided:
556-
await stack.enter_async_context(client)
557-
558-
def start_get_stream() -> None:
559-
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
560-
561-
tg.start_soon(
562-
transport.post_writer,
563-
client,
564-
write_stream_reader,
565-
read_stream_writer,
566-
write_stream,
567-
start_get_stream,
568-
tg,
569-
)
546+
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
547+
548+
async with contextlib.AsyncExitStack() as stack:
549+
# Only manage client lifecycle if we created it
550+
if not client_provided:
551+
await stack.enter_async_context(client)
552+
553+
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
554+
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
555+
556+
async with (
557+
read_stream_writer,
558+
read_stream,
559+
write_stream,
560+
write_stream_reader,
561+
anyio.create_task_group() as tg,
562+
):
563+
564+
def start_get_stream() -> None:
565+
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
566+
567+
tg.start_soon(
568+
transport.post_writer,
569+
client,
570+
write_stream_reader,
571+
read_stream_writer,
572+
write_stream,
573+
start_get_stream,
574+
tg,
575+
)
570576

571-
try:
572-
yield read_stream, write_stream
573-
finally:
574-
if transport.session_id and terminate_on_close:
575-
await transport.terminate_session(client)
576-
tg.cancel_scope.cancel()
577-
finally:
578-
await read_stream_writer.aclose()
579-
await write_stream.aclose()
580-
await read_stream.aclose()
581-
await write_stream_reader.aclose()
577+
try:
578+
yield read_stream, write_stream
579+
finally:
580+
if transport.session_id and terminate_on_close:
581+
await transport.terminate_session(client)
582+
tg.cancel_scope.cancel()

tests/client/test_transport_stream_cleanup.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,23 +58,17 @@ def hook(args: "sys.UnraisableHookArgs") -> None: # pragma: no cover
5858

5959
@pytest.mark.anyio
6060
async def test_sse_client_closes_all_streams_on_connection_error(free_tcp_port: int) -> None:
61-
"""sse_client must close all 4 stream ends when the connection fails.
61+
"""sse_client creates streams only after the SSE connection succeeds, so a
62+
ConnectError propagates directly with nothing to leak.
6263
63-
Before the fix, only read_stream_writer and write_stream were closed in
64-
the finally block. read_stream and write_stream_reader were leaked.
64+
Before the fix, streams were created before connecting and only 2 of 4 were
65+
closed in the finally block.
6566
"""
6667
with _assert_no_memory_stream_leak():
67-
# sse_client enters a task group BEFORE connecting, so anyio wraps the
68-
# ConnectError from aconnect_sse in an ExceptionGroup.
69-
with pytest.raises(Exception) as exc_info: # noqa: B017
68+
with pytest.raises(httpx.ConnectError):
7069
async with sse_client(f"http://127.0.0.1:{free_tcp_port}/sse"):
7170
pytest.fail("should not reach here") # pragma: no cover
7271

73-
assert exc_info.group_contains(httpx.ConnectError)
74-
# exc_info holds the traceback → holds frame locals → keeps leaked
75-
# streams alive. Must drop it before gc.collect() can detect a leak.
76-
del exc_info
77-
7872

7973
@pytest.mark.anyio
8074
async def test_streamable_http_client_closes_all_streams_on_exit() -> None:

0 commit comments

Comments (0)