Skip to content

Commit 1380ede

Browse files
committed
fix(stdio): bound EOF drain wait
1 parent aa6a7e6 commit 1380ede

3 files changed

Lines changed: 87 additions & 2 deletions

File tree

src/mcp/server/lowlevel/server.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ async def main():
6969

7070
logger = logging.getLogger(__name__)
7171

72+
DEFAULT_READ_EOF_DRAIN_TIMEOUT_SECONDS = 1.0
73+
7274
LifespanResultT = TypeVar("LifespanResultT", default=Any)
7375

7476
_ParamsT = TypeVar("_ParamsT", bound=BaseModel, default=BaseModel)
@@ -410,6 +412,9 @@ async def run(
410412
# to drain their responses via the still-open write stream (e.g. stdio
411413
# with bash-redirected stdin).
412414
drain_on_read_close: bool = False,
415+
# Maximum time to wait for in-flight handlers to drain after read EOF.
416+
# None means wait indefinitely.
417+
read_eof_drain_timeout_seconds: float | None = DEFAULT_READ_EOF_DRAIN_TIMEOUT_SECONDS,
413418
) -> None:
414419
async with self.lifespan(self) as lifespan_context:
415420
dispatcher: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(
@@ -421,6 +426,7 @@ async def run(
421426
# the initialized state instead of failing the init-gate.
422427
inline_methods=frozenset({"initialize"}),
423428
close_write_stream_on_read_close=not drain_on_read_close,
429+
read_eof_drain_timeout_seconds=read_eof_drain_timeout_seconds,
424430
)
425431
runner = ServerRunner(
426432
server=self,

src/mcp/shared/jsonrpc_dispatcher.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def __init__(
228228
raise_handler_exceptions: bool = False,
229229
inline_methods: frozenset[str] = frozenset(),
230230
close_write_stream_on_read_close: bool = True,
231+
read_eof_drain_timeout_seconds: float | None = None,
231232
) -> None: ...
232233
@overload
233234
def __init__(
@@ -240,6 +241,7 @@ def __init__(
240241
raise_handler_exceptions: bool = False,
241242
inline_methods: frozenset[str] = frozenset(),
242243
close_write_stream_on_read_close: bool = True,
244+
read_eof_drain_timeout_seconds: float | None = None,
243245
) -> None: ...
244246
def __init__(
245247
self,
@@ -251,6 +253,7 @@ def __init__(
251253
raise_handler_exceptions: bool = False,
252254
inline_methods: frozenset[str] = frozenset(),
253255
close_write_stream_on_read_close: bool = True,
256+
read_eof_drain_timeout_seconds: float | None = None,
254257
) -> None:
255258
self._read_stream = read_stream
256259
self._write_stream = write_stream
@@ -264,6 +267,7 @@ def __init__(
264267
self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode
265268
self._raise_handler_exceptions = raise_handler_exceptions
266269
self._close_write_stream_on_read_close = close_write_stream_on_read_close
270+
self._read_eof_drain_timeout_seconds = read_eof_drain_timeout_seconds
267271
# Request methods handled inline in the read loop (awaited before the
268272
# next message is dequeued) instead of spawned concurrently. Use for
269273
# methods whose side effects must be observable to the next message,
@@ -436,17 +440,22 @@ async def run(
436440
self._fan_out_closed()
437441
normal_eof = True
438442
finally:
439-
if not normal_eof:
443+
if not normal_eof or self._close_write_stream_on_read_close:
440444
# Transport closed abnormally: cancel in-flight handlers.
441445
# On normal EOF, let already-received handlers drain
442446
# their responses before the task group exits.
443447
tg.cancel_scope.cancel()
448+
elif self._read_eof_drain_timeout_seconds is not None:
449+
tg.cancel_scope.deadline = anyio.current_time() + self._read_eof_drain_timeout_seconds
444450
finally:
445451
# Covers the cancel/crash paths where the inline fan-out above is
446452
# never reached. Idempotent.
447453
self._running = False
448454
self._tg = None
449455
self._fan_out_closed()
456+
if not self._close_write_stream_on_read_close:
457+
with anyio.CancelScope(shield=True):
458+
await self._write_stream.aclose()
450459

451460
async def _dispatch(
452461
self,

tests/server/test_cancel_handling.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,13 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar
120120
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)
121121

122122
async def run_server():
123-
await server.run(server_read, server_write, server.create_initialization_options(), drain_on_read_close=True)
123+
await server.run(
124+
server_read,
125+
server_write,
126+
server.create_initialization_options(),
127+
drain_on_read_close=True,
128+
read_eof_drain_timeout_seconds=None,
129+
)
124130
server_run_returned.set()
125131

126132
init_req = JSONRPCRequest(
@@ -166,6 +172,70 @@ async def run_server():
166172
await server_run_returned.wait()
167173

168174

175+
@pytest.mark.anyio
176+
async def test_server_bounds_drain_on_read_eof_when_handler_never_finishes():
177+
handler_started = anyio.Event()
178+
handler_cancelled = anyio.Event()
179+
server_run_returned = anyio.Event()
180+
181+
async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
182+
handler_started.set()
183+
try:
184+
await anyio.sleep_forever()
185+
finally:
186+
handler_cancelled.set()
187+
raise AssertionError # pragma: no cover
188+
189+
server = Server("test", on_call_tool=handle_call_tool)
190+
191+
to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
192+
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)
193+
194+
async def run_server():
195+
await server.run(
196+
server_read,
197+
server_write,
198+
server.create_initialization_options(),
199+
drain_on_read_close=True,
200+
read_eof_drain_timeout_seconds=0.05,
201+
)
202+
server_run_returned.set()
203+
204+
init_req = JSONRPCRequest(
205+
jsonrpc="2.0",
206+
id=1,
207+
method="initialize",
208+
params=InitializeRequestParams(
209+
protocol_version=LATEST_PROTOCOL_VERSION,
210+
capabilities=ClientCapabilities(),
211+
client_info=Implementation(name="test", version="1.0"),
212+
).model_dump(by_alias=True, mode="json", exclude_none=True),
213+
)
214+
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
215+
call_req = JSONRPCRequest(
216+
jsonrpc="2.0",
217+
id=2,
218+
method="tools/call",
219+
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
220+
)
221+
222+
with anyio.fail_after(2):
223+
async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
224+
tg.start_soon(run_server)
225+
226+
await to_server.send(SessionMessage(init_req))
227+
await from_server.receive() # init response
228+
await to_server.send(SessionMessage(initialized))
229+
await to_server.send(SessionMessage(call_req))
230+
231+
await handler_started.wait()
232+
await to_server.aclose()
233+
234+
await server_run_returned.wait()
235+
236+
assert handler_cancelled.is_set()
237+
238+
169239
@pytest.mark.anyio
170240
async def test_server_reraises_handler_cancellation_when_server_is_cancelled():
171241
"""If the server task is cancelled (e.g. KeyboardInterrupt), in-flight

0 commit comments

Comments
 (0)