Skip to content

Commit 20aff7a

Browse files
committed
Harden ClientSession enter/exit and fault delivery
- Unwind the entered task group when __aenter__ is cancelled while the dispatcher is starting, instead of abandoning its cancel scope
- Deliver transport-level exceptions to message_handler concurrently and contained, like notifications, so a handler that awaits session I/O no longer deadlocks the read loop
- Route related_request_id=0 correctly in send_notification (ids are opaque)
- Document the dispatcher= constructor path in send_request's contract
1 parent 761500d commit 20aff7a

2 files changed

Lines changed: 249 additions & 33 deletions

File tree

src/mcp/client/session.py

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,10 @@ class ClientSession:
129129
130130
Transport-level `Exception` items reach `message_handler` only when the
131131
session builds its own dispatcher from streams, where it wires the
132-
dispatcher's `on_stream_exception` itself.
132+
dispatcher's `on_stream_exception` itself. Faults are delivered
133+
concurrently in the session's task group, like notifications — never
134+
inline in the read loop — so the handler may await session I/O, and one
135+
that raises costs that delivery, not the connection.
133136
"""
134137

135138
def __init__(
@@ -174,7 +177,26 @@ def __init__(
174177
async def __aenter__(self) -> Self:
175178
self._task_group = anyio.create_task_group()
176179
await self._task_group.__aenter__()
177-
await self._task_group.start(self._dispatcher.run, self._on_request, self._on_notify)
180+
try:
181+
await self._task_group.start(self._dispatcher.run, self._on_request, self._on_notify)
182+
except BaseException:
183+
# A cancellation landing here (e.g. the caller wrapped connect in
184+
# `move_on_after`) would abandon the entered task group, and anyio
185+
# later raises "exited non-innermost cancel scope" instead of a
186+
# clean timeout. Unwind the group before propagating; cancelling
187+
# its scope first keeps __aexit__ from blocking under the
188+
# still-active cancellation.
189+
task_group = self._task_group
190+
self._task_group = None
191+
task_group.cancel_scope.cancel()
192+
# Shield the group's own scope (not a new one: scope exits must
193+
# stay LIFO) so a pending outer cancellation cannot re-fire
194+
# inside __aexit__; the join is prompt because the scope is
195+
# cancelled. The original exception then propagates from the
196+
# `raise`; a child error supersedes it, raised by __aexit__.
197+
task_group.cancel_scope.shield = True
198+
await task_group.__aexit__(None, None, None)
199+
raise
178200
return self
179201

180202
async def __aexit__(
@@ -209,8 +231,10 @@ async def send_request(
209231
210232
Raises:
211233
MCPError: The server responded with an error, or the read timeout
212-
elapsed, or the connection closed while waiting.
213-
RuntimeError: Called before entering the context manager.
234+
elapsed, or the connection closed while sending or waiting.
235+
RuntimeError: Called before entering the context manager. Raised
236+
by the stream-built dispatcher; a user-supplied `dispatcher=`
237+
may not enforce this.
214238
"""
215239
data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
216240
method: str = data["method"]
@@ -249,7 +273,8 @@ async def send_notification(
249273
) -> None:
250274
"""Send a one-way notification. Usable before entering the context manager."""
251275
data = notification.model_dump(by_alias=True, mode="json", exclude_none=True)
252-
if related_request_id and isinstance(self._dispatcher, JSONRPCDispatcher):
276+
# `is not None`, not truthiness: request ids are opaque and 0 is valid.
277+
if related_request_id is not None and isinstance(self._dispatcher, JSONRPCDispatcher):
253278
await self._dispatcher.notify(data["method"], data.get("params"), _related_request_id=related_request_id)
254279
else:
255280
await self._dispatcher.notify(data["method"], data.get("params"))
@@ -561,5 +586,21 @@ async def _on_notify(
561586
await self._message_handler(notification)
562587

563588
async def _on_stream_exception(self, exc: Exception) -> None:
564-
"""Forward transport-level faults (connection errors, parse errors) to message_handler."""
565-
await self._message_handler(exc)
589+
"""Spawn delivery of a transport-level fault (connection error, parse error) to message_handler.
590+
591+
The dispatcher awaits this observer inline in its read loop, so the
592+
handler must not run here: a slow handler would head-of-line block the
593+
session, and one that awaits session I/O (e.g. sends a ping) would
594+
deadlock against the parked loop. Spawn it instead, with the same
595+
containment notification deliveries get.
596+
"""
597+
# The dispatcher only runs inside the task group entered in
598+
# __aenter__, so the group is always live when it calls back here.
599+
assert self._task_group is not None
600+
self._task_group.start_soon(self._deliver_stream_exception, exc)
601+
602+
async def _deliver_stream_exception(self, exc: Exception) -> None:
603+
try:
604+
await self._message_handler(exc)
605+
except Exception:
606+
logger.exception("message_handler raised on transport exception")

0 commit comments

Comments
 (0)