Skip to content

Commit 47616ac

Browse files
committed
Close the write stream only after the task-group join
run() entered both streams inside its own task group, so at teardown the write stream closed before in-flight handlers sent their final answers: the shutdown CONNECTION_CLOSED response was deterministically dropped on the EOF path and raced the close on the cancel path. The write stream's scope now wraps the task group, so scope exits order the join strictly before the close and teardown writes always land. The shutdown-delivery test becomes a real memory-stream pin, and the wedged-shutdown test's synthetic stream is replaced by a plain unread one.
1 parent 1672788 commit 47616ac

2 files changed

Lines changed: 47 additions & 60 deletions

File tree

src/mcp/shared/jsonrpc_dispatcher.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -384,31 +384,33 @@ async def run(
384384
`task_status.started()` fires once `send_raw_request` is usable.
385385
"""
386386
try:
387-
async with anyio.create_task_group() as tg:
388-
self._tg = tg
389-
self._running = True
390-
task_status.started()
391-
try:
392-
async with self._read_stream, self._write_stream:
393-
try:
394-
async for item in self._read_stream:
395-
# Duck-typed: only `ContextReceiveStream` carries the
396-
# sender's per-message contextvars snapshot.
397-
sender_ctx: contextvars.Context | None = getattr(
398-
self._read_stream, "last_context", None
399-
)
400-
await self._dispatch(item, on_request, on_notify, sender_ctx)
401-
except anyio.ClosedResourceError:
402-
# Receive end closed under us (stateless SHTTP teardown); same as EOF.
403-
logger.debug("read stream closed by transport; treating as EOF")
404-
# EOF: wake blocked `send_raw_request` waiters with CONNECTION_CLOSED.
405-
self._running = False
406-
self._closed = True
407-
self._fan_out_closed()
408-
finally:
409-
# Cancel in-flight handlers; otherwise the task-group join
410-
# waits on handlers whose callers are already gone.
411-
tg.cancel_scope.cancel()
387+
# LIFO exits: the write stream closes only after the task-group join, so teardown writes still land.
388+
async with self._write_stream:
389+
async with anyio.create_task_group() as tg:
390+
self._tg = tg
391+
self._running = True
392+
task_status.started()
393+
try:
394+
async with self._read_stream:
395+
try:
396+
async for item in self._read_stream:
397+
# Duck-typed: only `ContextReceiveStream` carries the
398+
# sender's per-message contextvars snapshot.
399+
sender_ctx: contextvars.Context | None = getattr(
400+
self._read_stream, "last_context", None
401+
)
402+
await self._dispatch(item, on_request, on_notify, sender_ctx)
403+
except anyio.ClosedResourceError:
404+
# Receive end closed under us (stateless SHTTP teardown); same as EOF.
405+
logger.debug("read stream closed by transport; treating as EOF")
406+
# EOF: wake blocked `send_raw_request` waiters with CONNECTION_CLOSED.
407+
self._running = False
408+
self._closed = True
409+
self._fan_out_closed()
410+
finally:
411+
# Cancel in-flight handlers; otherwise the task-group join
412+
# waits on handlers whose callers are already gone.
413+
tg.cancel_scope.cancel()
412414
finally:
413415
# Covers cancel/crash paths that skip the inline fan-out; idempotent.
414416
self._running = False

tests/shared/test_jsonrpc_dispatcher.py

Lines changed: 20 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ async def drive() -> None:
306306

307307
@pytest.mark.anyio
308308
async def test_run_closes_write_stream_on_exit():
309-
"""run() enters both streams; the write end is released on EOF."""
309+
"""run() owns both streams; the write end is released once the EOF teardown completes."""
310310
c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32)
311311
s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32)
312312
server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send)
@@ -819,29 +819,11 @@ async def test_shutdown_error_response_write_is_bounded_when_the_transport_is_we
819819
caplog: pytest.LogCaptureFixture,
820820
):
821821
"""Cancelling the task group hosting run() completes even when the shutdown error write wedges:
822-
only `_SHUTDOWN_WRITE_TIMEOUT` releases the join (SDK-defined). The fake stream is needed
823-
because run()'s teardown closes a memory stream, which would wake the blocked send."""
824-
825-
class WedgedWriteStream:
826-
async def send(self, item: SessionMessage) -> None:
827-
await anyio.sleep_forever()
828-
829-
async def aclose(self) -> None:
830-
raise NotImplementedError
831-
832-
async def __aenter__(self) -> "WedgedWriteStream":
833-
return self
834-
835-
async def __aexit__(
836-
self,
837-
exc_type: type[BaseException] | None,
838-
exc_val: BaseException | None,
839-
exc_tb: TracebackType | None,
840-
) -> bool | None:
841-
return None
842-
822+
only `_SHUTDOWN_WRITE_TIMEOUT` releases the join (SDK-defined). A 0-buffer stream nobody reads
823+
expresses the wedge: run() closes its write stream only after the join, so the send stays parked."""
843824
c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1)
844-
server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, WedgedWriteStream())
825+
s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0)
826+
server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send)
845827
handler_started = anyio.Event()
846828

847829
async def park(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]:
@@ -863,19 +845,19 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) ->
863845
await handler_started.wait()
864846
tg.cancel_scope.cancel()
865847
finally:
866-
c2s_send.close()
867-
c2s_recv.close()
848+
for s in (c2s_send, c2s_recv, s2c_send, s2c_recv):
849+
s.close()
868850
# The warning proves the bound (not a completed write) released the join.
869851
assert "shutdown error response for request" in caplog.text
870852

871853

872854
@pytest.mark.anyio
873855
async def test_shutdown_answers_in_flight_request_with_connection_closed():
874-
"""Cancelling run() answers a still-running request with CONNECTION_CLOSED (SDK-defined). The
875-
recording stream is needed because run()'s exit would close a memory stream before the shielded write lands."""
856+
"""Read-stream EOF answers a still-running request with CONNECTION_CLOSED (SDK-defined):
857+
run() keeps the write stream open until the task-group join, so the shielded teardown write lands."""
876858
c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4)
877-
recording = RecordingWriteStream()
878-
server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, recording)
859+
s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4)
860+
server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send)
879861
handler_started = anyio.Event()
880862

881863
async def park(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]:
@@ -892,13 +874,16 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) ->
892874
await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None)))
893875
with anyio.fail_after(5):
894876
await handler_started.wait()
895-
tg.cancel_scope.cancel()
877+
c2s_send.close() # EOF: run() cancels the parked handler, which must still answer
878+
with anyio.fail_after(5):
879+
answer = await s2c_recv.receive()
880+
assert isinstance(answer, SessionMessage)
881+
assert answer.message == JSONRPCError(
882+
jsonrpc="2.0", id=1, error=ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
883+
)
896884
finally:
897-
c2s_send.close()
898-
c2s_recv.close()
899-
assert [m.message for m in recording.sent] == [
900-
JSONRPCError(jsonrpc="2.0", id=1, error=ErrorData(code=CONNECTION_CLOSED, message="Connection closed"))
901-
]
885+
for s in (c2s_send, c2s_recv, s2c_send, s2c_recv):
886+
s.close()
902887

903888

904889
@pytest.mark.anyio

0 commit comments

Comments
 (0)