Skip to content

Commit 421e65b

Browse files
committed
Synthesize request ids in DirectDispatcher
Server-initiated sampling/elicitation/roots requests over a ClientSession built with dispatcher=DirectDispatcher failed before the callback ran: the session requires a populated request id and direct dispatch carried none. DirectDispatcher now assigns per-instance monotonic ids to inbound requests (notifications keep None, which is how middleware distinguishes them). Adds a non-ping direct-dispatch test and bounds the indefinite awaits in the existing dispatcher= tests.
1 parent 48f2b01 commit 421e65b

4 files changed

Lines changed: 63 additions & 17 deletions

File tree

src/mcp/client/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ async def _on_request(
507507

508508
response: types.ClientResult | types.ErrorData
509509
if isinstance(request, types.PingRequest):
510-
# Answered without a context: direct dispatch carries no request id.
510+
# Answered without a context: ping has no callback that would need one.
511511
response = types.EmptyResult()
512512
else:
513513
assert dctx.request_id is not None # the callback-driving dispatchers always assign ids

src/mcp/shared/direct_dispatcher.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class _DirectDispatchContext:
5050
_back_request: _Request
5151
_back_notify: _Notify
5252
request_id: RequestId | None = None
53-
"""Always `None`: direct dispatch has no wire-level request id."""
53+
"""A dispatcher-synthesized id for requests; `None` for notifications."""
5454
message_metadata: MessageMetadata = None # TODO(maxisbey): remove for Context rework
5555
"""Always `None`: in-memory dispatch attaches no transport metadata."""
5656
_on_progress: ProgressFnT | None = None
@@ -91,6 +91,7 @@ def __init__(self, transport_ctx: TransportContext):
9191
self._peer: DirectDispatcher | None = None
9292
self._on_request: OnRequest | None = None
9393
self._on_notify: OnNotify | None = None
94+
self._next_id = 0
9495
self._ready = anyio.Event()
9596
self._closed = anyio.Event()
9697

@@ -128,13 +129,16 @@ async def run(
128129
def close(self) -> None:
129130
self._closed.set()
130131

131-
def _make_context(self, on_progress: ProgressFnT | None = None) -> _DirectDispatchContext:
132+
def _make_context(
133+
self, on_progress: ProgressFnT | None = None, request_id: RequestId | None = None
134+
) -> _DirectDispatchContext:
132135
assert self._peer is not None
133136
peer = self._peer
134137
return _DirectDispatchContext(
135138
transport=self._transport_ctx,
136139
_back_request=lambda m, p, o: peer._dispatch_request(m, p, o),
137140
_back_notify=lambda m, p: peer._dispatch_notify(m, p),
141+
request_id=request_id,
138142
_on_progress=on_progress,
139143
)
140144

@@ -147,7 +151,9 @@ async def _dispatch_request(
147151
await self._ready.wait()
148152
assert self._on_request is not None
149153
opts = opts or {}
150-
dctx = self._make_context(on_progress=opts.get("on_progress"))
154+
# Synthesize an id: the DispatchContext contract reserves None for notifications.
155+
self._next_id += 1
156+
dctx = self._make_context(on_progress=opts.get("on_progress"), request_id=self._next_id)
151157
try:
152158
with anyio.fail_after(opts.get("timeout")):
153159
try:

tests/client/test_session.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import anyio.abc
99
import anyio.streams.memory
1010
import pytest
11+
from pydantic import FileUrl
1112

1213
from mcp import types
1314
from mcp.client import ClientRequestContext
@@ -972,18 +973,54 @@ async def server_on_notify(
972973

973974
session = ClientSession(dispatcher=client_side)
974975
results: list[types.EmptyResult] = []
975-
async with anyio.create_task_group() as tg:
976-
await tg.start(server_side.run, server_on_request, server_on_notify)
977-
async with session:
978-
results.append(await session.send_ping(meta=None))
979-
# Server-to-client: direct dispatch delivers ping with no params member (no _meta injection).
980-
assert await server_side.send_raw_request("ping", None) == {}
981-
await session.send_notification(types.RootsListChangedNotification())
982-
server_side.close()
976+
with anyio.fail_after(5):
977+
async with anyio.create_task_group() as tg:
978+
await tg.start(server_side.run, server_on_request, server_on_notify)
979+
async with session:
980+
results.append(await session.send_ping(meta=None))
981+
# Server-to-client: direct dispatch delivers ping with no params member (no _meta injection).
982+
assert await server_side.send_raw_request("ping", None) == {}
983+
await session.send_notification(types.RootsListChangedNotification())
984+
server_side.close()
983985
assert results == [types.EmptyResult()]
984986
assert notified == ["notifications/roots/list_changed"]
985987

986988

989+
@pytest.mark.anyio
990+
async def test_direct_dispatch_roots_list_reaches_callback_with_synthesized_request_id():
991+
"""A server-initiated roots/list over dispatcher= reaches the registered callback and round-trips
992+
the result; the callback context carries an int request_id (SDK-defined: DirectDispatcher
993+
synthesizes ids)."""
994+
client_side, server_side = create_direct_dispatcher_pair()
995+
contexts: list[ClientRequestContext] = []
996+
997+
async def list_roots(context: ClientRequestContext) -> types.ListRootsResult:
998+
contexts.append(context)
999+
return types.ListRootsResult(roots=[types.Root(uri=FileUrl("file:///workspace"))])
1000+
1001+
async def server_on_request(
1002+
ctx: DispatchContext[TransportContext], method: str, params: dict[str, object] | None
1003+
) -> dict[str, object]:
1004+
raise NotImplementedError
1005+
1006+
async def server_on_notify(
1007+
ctx: DispatchContext[TransportContext], method: str, params: dict[str, object] | None
1008+
) -> None:
1009+
raise NotImplementedError
1010+
1011+
session = ClientSession(dispatcher=client_side, list_roots_callback=list_roots)
1012+
result: dict[str, Any] | None = None
1013+
with anyio.fail_after(5):
1014+
async with anyio.create_task_group() as tg:
1015+
await tg.start(server_side.run, server_on_request, server_on_notify)
1016+
async with session:
1017+
result = await server_side.send_raw_request("roots/list", None)
1018+
server_side.close()
1019+
assert result == {"roots": [{"uri": "file:///workspace"}]}
1020+
assert len(contexts) == 1
1021+
assert isinstance(contexts[0].request_id, int)
1022+
1023+
9871024
@pytest.mark.anyio
9881025
async def test_initialize_opts_out_of_cancel_on_abandon_while_other_requests_leave_it_unset():
9891026
"""`send_request` passes `cancel_on_abandon=False` for `initialize` — the spec forbids
@@ -1021,9 +1058,10 @@ async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
10211058
pass
10221059

10231060
dispatcher = RecordingDispatcher()
1024-
async with ClientSession(dispatcher=dispatcher) as session:
1025-
await session.initialize()
1026-
await session.send_ping()
1061+
with anyio.fail_after(5):
1062+
async with ClientSession(dispatcher=dispatcher) as session:
1063+
await session.initialize()
1064+
await session.send_ping()
10271065
opts_by_method = dict(dispatcher.calls)
10281066
assert opts_by_method["initialize"].get("cancel_on_abandon") is False
10291067
assert "cancel_on_abandon" not in opts_by_method["ping"]

tests/shared/test_dispatcher.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,13 +228,15 @@ async def test_ctx_message_metadata_is_none_when_transport_attaches_nothing(pair
228228

229229
@pytest.mark.anyio
230230
async def test_ctx_request_id_exposes_inbound_id(pair_factory: PairFactory):
231-
"""JSON-RPC carries the wire id through; direct dispatch has none."""
231+
"""Every dispatcher assigns each inbound request a distinct int id; JSON-RPC carries
232+
the wire id through, DirectDispatcher synthesizes one (SDK-defined)."""
232233
async with running_pair(pair_factory) as (client, _server, _crec, srec):
233234
with anyio.fail_after(5):
234235
await client.send_raw_request("tools/call", None)
235236
await client.send_raw_request("tools/call", None)
236237
a, b = (ctx.request_id for ctx in srec.contexts)
237-
assert (a is None and b is None) or (isinstance(a, int) and isinstance(b, int) and a != b)
238+
assert isinstance(a, int) and isinstance(b, int)
239+
assert a != b
238240

239241

240242
@pytest.mark.anyio

0 commit comments

Comments
 (0)