-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Expand file tree
/
Copy pathjsonrpc_dispatcher.py
More file actions
759 lines (684 loc) · 34.4 KB
/
Copy pathjsonrpc_dispatcher.py
File metadata and controls
759 lines (684 loc) · 34.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
"""JSON-RPC `Dispatcher` implementation.
Consumes the existing `SessionMessage`-based stream contract that all current
transports (stdio, SSE, streamable HTTP) speak. Owns request-id correlation,
the receive loop, per-request task isolation, cancellation/progress wiring, and
the single exception-to-wire boundary.
The MCP type layer (`ServerRunner`, `Context`, `Client`) sits above this and
sees only `(ctx, method, params) -> dict`. Transports sit below and see only
`SessionMessage` reads/writes.
The dispatcher is *mostly* MCP-agnostic - methods/params are opaque strings and
dicts - but it intercepts `notifications/cancelled` and
`notifications/progress` because request correlation, cancellation and
progress are exactly the wiring this layer exists to provide. Those few wire
shapes are extracted with structural `match` patterns (no casts, no
`mcp.types` model coupling); a malformed payload simply fails to match and
the correlation is skipped.
"""
from __future__ import annotations
import contextvars
import logging
from collections.abc import Awaitable, Callable, Mapping
from contextlib import AsyncExitStack
from dataclasses import dataclass, field
from typing import Any, Generic, Literal, TypeVar, cast, overload
import anyio
import anyio.abc
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from opentelemetry.trace import SpanKind
from pydantic import ValidationError
from mcp.shared._otel import inject_trace_context, otel_span
from mcp.shared._stream_protocols import ReadStream, WriteStream
from mcp.shared.dispatcher import CallOptions, Dispatcher, OnNotify, OnRequest, ProgressFnT
from mcp.shared.exceptions import MCPError, NoBackChannelError
from mcp.shared.message import (
ClientMessageMetadata,
MessageMetadata,
ServerMessageMetadata,
SessionMessage,
)
from mcp.shared.transport_context import TransportContext
from mcp.types import (
CONNECTION_CLOSED,
INTERNAL_ERROR,
INVALID_PARAMS,
REQUEST_CANCELLED,
REQUEST_TIMEOUT,
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
ProgressToken,
RequestId,
)
# Public API of this module: only the dispatcher class is exported; the
# dispatch context and helpers below are internal.
__all__ = ["JSONRPCDispatcher"]
# Module logger used for the dispatcher's debug/exception diagnostics.
logger = logging.getLogger(__name__)
# The concrete `TransportContext` subtype carried by each message's dispatch
# context; produced by the `transport_builder` given to `JSONRPCDispatcher`.
TransportT = TypeVar("TransportT", bound=TransportContext)
# Accepted values for `JSONRPCDispatcher(peer_cancel_mode=...)`.
PeerCancelMode = Literal["interrupt", "signal"]
"""How inbound `notifications/cancelled` is applied to a running handler.
`"interrupt"` (default) cancels the handler's scope. `"signal"` only sets
`ctx.cancel_requested` and lets the handler observe it cooperatively.
"""
def _coerce_id(request_id: RequestId) -> RequestId:
    """Normalize a request id so the string and int spellings of a number match.

    Ids allocated by `_allocate_id` are always ints, but a peer may echo them
    back as JSON strings (e.g. ``"3"``). Any string that `int()` accepts is
    converted; every other value (ints, non-numeric strings) is returned
    unchanged. This mirrors the lookup-time coercion done by the TypeScript
    SDK and `BaseSession`, so responses still correlate.

    NOTE: `int()` also tolerates surrounding whitespace and digit-group
    underscores, so ``" 3"`` and ``"1_0"`` coerce to ``3`` and ``10``.
    """
    if not isinstance(request_id, str):
        return request_id
    try:
        return int(request_id)
    except ValueError:
        # Non-numeric string id (e.g. a UUID): keep it verbatim.
        return request_id
@dataclass(slots=True)
class _Pending:
"""An outbound request awaiting its response."""
send: MemoryObjectSendStream[dict[str, Any] | ErrorData]
receive: MemoryObjectReceiveStream[dict[str, Any] | ErrorData]
on_progress: ProgressFnT | None = None
@dataclass(slots=True)
class _InFlight(Generic[TransportT]):
    """Bookkeeping for one inbound request whose handler is still running.

    The dispatcher keeps these keyed by (coerced) request id so that a later
    `notifications/cancelled` from the peer can locate the running handler.
    """

    # Cancel scope wrapping the handler call; cancelled on peer cancel in
    # "interrupt" mode.
    scope: anyio.CancelScope
    # The handler's dispatch context; its `cancel_requested` event is set on
    # any peer cancel.
    dctx: _JSONRPCDispatchContext[TransportT]
@dataclass
class _JSONRPCDispatchContext(Generic[TransportT]):
    """Per-message `DispatchContext` handed to handlers by `JSONRPCDispatcher`.

    One instance is built for every inbound request or notification. It
    exposes the transport context and the inbound request id, plus
    back-channel helpers (`notify`, `send_raw_request`, `progress`). Those
    helpers route outgoing traffic through the owning dispatcher, tagged with
    that request id. After `close()` the back-channel is shut: notifications
    are dropped and requests raise `NoBackChannelError`.
    """

    transport: TransportT
    _dispatcher: JSONRPCDispatcher[TransportT]
    _request_id: RequestId | None
    message_metadata: MessageMetadata = None  # TODO(maxisbey): remove for Context rework
    """The transport-attached `SessionMessage.metadata` for this inbound message.
    Carries `ServerMessageMetadata` (HTTP request, SSE stream-close callbacks)
    that the server lifts onto its request context. `None` for transports
    that attach nothing.
    """
    # Token from the inbound request's `_meta.progressToken`, if any.
    _progress_token: ProgressToken | None = None
    # Set by `close()` once the handler has finished.
    _closed: bool = False
    # Set when the peer sends `notifications/cancelled` for this request.
    cancel_requested: anyio.Event = field(default_factory=anyio.Event)

    @property
    def request_id(self) -> RequestId | None:
        """Id of the inbound request this context belongs to, or ``None``."""
        return self._request_id

    @property
    def can_send_request(self) -> bool:
        """True while the transport supports back-channel requests and the context is open."""
        transport_allows = self.transport.can_send_request
        return transport_allows and not self._closed

    async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
        """Send a notification tied to this request; dropped (debug-logged) once closed."""
        if not self._closed:
            await self._dispatcher.notify(method, params, _related_request_id=self._request_id)
            return
        logger.debug("dropped %s: dispatch context closed", method)

    async def send_raw_request(
        self,
        method: str,
        params: Mapping[str, Any] | None,
        opts: CallOptions | None = None,
    ) -> dict[str, Any]:
        """Issue a request to the peer on behalf of this inbound request.

        Raises:
            NoBackChannelError: The transport cannot send requests, or the
                context has been closed.
        """
        if self.can_send_request:
            return await self._dispatcher.send_raw_request(
                method, params, opts, _related_request_id=self._request_id
            )
        raise NoBackChannelError(method)

    async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None:
        """Emit `notifications/progress` if the inbound request supplied a progress token.

        Without a token this is a no-op. `total` and `message` are included
        only when they are not ``None``.
        """
        token = self._progress_token
        if token is None:
            return
        optional_fields = {"total": total, "message": message}
        params: dict[str, Any] = {"progressToken": token, "progress": progress}
        params.update({key: value for key, value in optional_fields.items() if value is not None})
        await self.notify("notifications/progress", params)

    def close(self) -> None:
        """Shut the back-channel: later `notify` calls are dropped, requests refused."""
        self._closed = True
def _default_transport_builder(_meta: MessageMetadata) -> TransportContext:
    """Fallback `transport_builder` used when the dispatcher is given none.

    The message metadata is ignored; every message gets a plain JSON-RPC
    transport context that permits back-channel requests.
    """
    context = TransportContext(kind="jsonrpc", can_send_request=True)
    return context
def _shielded_progress(fn: ProgressFnT) -> ProgressFnT:
    """Return a wrapper around progress callback `fn` that never lets an `Exception` escape.

    Progress callbacks are spawned as bare tasks in the dispatcher's task
    group. If one raised, the group would cancel the read loop and every
    in-flight request alongside it. The wrapper instead logs the failure
    with its traceback and returns normally, as the earlier receive loop
    did. Cancellation (a `BaseException`) still propagates.
    """

    async def _guarded(progress: float, total: float | None, message: str | None) -> None:
        try:
            await fn(progress, total, message)
        except Exception:
            logger.exception("progress callback raised")

    return _guarded
def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | None) -> MessageMetadata:
    """Pick the `SessionMessage.metadata` to attach to an outgoing message.

    There are three outcomes:

    * `related_request_id` is given: return `ServerMessageMetadata`, so
      streamable HTTP can route the message onto the SSE stream of the
      inbound request it belongs to. Metadata holds exactly one object, so
      any resumption hints in `opts` are dropped with a debug log. Requests
      made from a dispatch context are routed, not resumed.
    * Otherwise, if `opts` carries `resumption_token` and/or
      `on_resumption_token`: return `ClientMessageMetadata` holding those
      hints for the client transport.
    * Otherwise: return `None` (the common case).
    """
    hints = opts or {}
    token = hints.get("resumption_token")
    on_token = hints.get("on_resumption_token")
    has_resumption_hints = token is not None or on_token is not None

    if related_request_id is None:
        if has_resumption_hints:
            return ClientMessageMetadata(resumption_token=token, on_resumption_token_update=on_token)
        return None

    if has_resumption_hints:
        logger.debug(
            "dropping resumption hints: related_request_id %r takes precedence on metadata", related_request_id
        )
    return ServerMessageMetadata(related_request_id=related_request_id)
class JSONRPCDispatcher(Dispatcher[TransportT]):
"""`Dispatcher` over the existing `SessionMessage` stream contract.
Inherits the `Dispatcher` Protocol explicitly so pyright checks
conformance at the class definition rather than at first use.
"""
@overload
def __init__(
self: JSONRPCDispatcher[TransportContext],
read_stream: ReadStream[SessionMessage | Exception],
write_stream: WriteStream[SessionMessage],
*,
peer_cancel_mode: PeerCancelMode = "interrupt",
raise_handler_exceptions: bool = False,
inline_methods: frozenset[str] = frozenset(),
close_write_stream_on_read_close: bool = True,
read_eof_drain_timeout_seconds: float | None = None,
) -> None: ...
@overload
def __init__(
self,
read_stream: ReadStream[SessionMessage | Exception],
write_stream: WriteStream[SessionMessage],
*,
transport_builder: Callable[[MessageMetadata], TransportT],
peer_cancel_mode: PeerCancelMode = "interrupt",
raise_handler_exceptions: bool = False,
inline_methods: frozenset[str] = frozenset(),
close_write_stream_on_read_close: bool = True,
read_eof_drain_timeout_seconds: float | None = None,
) -> None: ...
def __init__(
self,
read_stream: ReadStream[SessionMessage | Exception],
write_stream: WriteStream[SessionMessage],
*,
transport_builder: Callable[[MessageMetadata], TransportT] | None = None,
peer_cancel_mode: PeerCancelMode = "interrupt",
raise_handler_exceptions: bool = False,
inline_methods: frozenset[str] = frozenset(),
close_write_stream_on_read_close: bool = True,
read_eof_drain_timeout_seconds: float | None = None,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
# The overloads guarantee that when `transport_builder` is omitted,
# `TransportT` is `TransportContext`, so the default is type-correct;
# pyright can't see across overloads, hence the cast.
self._transport_builder = cast(
"Callable[[MessageMetadata], TransportT]",
transport_builder or _default_transport_builder,
)
self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode
self._raise_handler_exceptions = raise_handler_exceptions
self._close_write_stream_on_read_close = close_write_stream_on_read_close
self._read_eof_drain_timeout_seconds = read_eof_drain_timeout_seconds
# Request methods handled inline in the read loop (awaited before the
# next message is dequeued) instead of spawned concurrently. Use for
# methods whose side effects must be observable to the next message,
# e.g. `initialize`, so a pipelined follow-up sees the initialized state.
# Only suitable for handlers that complete quickly, since inline handling
# blocks dequeuing; a handler that awaits the peer (`send_raw_request`)
# while inline will deadlock because the parked read loop cannot dequeue
# the response.
self._inline_methods = inline_methods
self._next_id = 0
self._pending: dict[RequestId, _Pending] = {}
self._in_flight: dict[RequestId, _InFlight[TransportT]] = {}
self._tg: anyio.abc.TaskGroup | None = None
self._running = False
async def send_raw_request(
self,
method: str,
params: Mapping[str, Any] | None,
opts: CallOptions | None = None,
*,
_related_request_id: RequestId | None = None,
) -> dict[str, Any]:
"""Send a JSON-RPC request and await its response.
`_related_request_id` is set only by `_JSONRPCDispatchContext` when a
handler makes a server-to-client request mid-flight; it routes the
outgoing message onto the correct per-request SSE stream (SHTTP) via
`ServerMessageMetadata`. Top-level callers leave it `None`.
Raises:
MCPError: The peer responded with a JSON-RPC error; or
`REQUEST_TIMEOUT` if `opts["timeout"]` elapsed; or
`CONNECTION_CLOSED` if the dispatcher shut down while
awaiting the response.
RuntimeError: Called before `run()` has started or after it has
finished.
"""
if not self._running:
raise RuntimeError("JSONRPCDispatcher.send_raw_request called before run() / after close")
opts = opts or {}
request_id = self._allocate_id()
out_params = dict(params) if params is not None else {}
out_meta = dict(out_params.get("_meta") or {})
on_progress = opts.get("on_progress")
if on_progress is not None:
# The caller wants progress updates. The spec mechanism is: include
# `_meta.progressToken` on the request; the peer echoes that token on
# any `notifications/progress` it sends. We use the request id as the
# token so the receive loop can find this `_Pending.on_progress` by
# `_pending[token]` without a second lookup table.
out_meta["progressToken"] = request_id
out_params["_meta"] = out_meta
# buffer=1: at most one outcome is ever delivered. A `WouldBlock` from
# `_resolve_pending`/`_fan_out_closed` means the waiter already has an
# outcome and dropping the late/redundant signal is correct. buffer=0
# is unsafe - there's a window between registering `_pending[id]` and
# parking in `receive()` where a close signal would be lost.
send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1)
pending = _Pending(send=send, receive=receive, on_progress=on_progress)
self._pending[request_id] = pending
metadata = _outbound_metadata(_related_request_id, opts)
target = out_params.get("name")
span_name = f"MCP send {method}{f' {target}' if isinstance(target, str) else ''}"
# TODO(maxisbey): the otel span + inject below mirror
# BaseSession.send_request for parity. They belong in an outbound
# middleware (symmetric with otel_middleware on the inbound side) once
# that seam exists; the dispatcher should not own otel.
try:
with otel_span(
span_name,
kind=SpanKind.CLIENT,
attributes={"mcp.method.name": method, "jsonrpc.request.id": str(request_id)},
):
# Inject W3C trace context into _meta (SEP-414). With a no-op
# tracer this writes nothing, but `_meta` itself is still
# present on the wire (and the interaction suite pins that).
inject_trace_context(out_meta)
msg = JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params=out_params)
await self._write(msg, metadata)
with anyio.fail_after(opts.get("timeout")):
outcome = await receive.receive()
except TimeoutError:
# Spec-recommended courtesy: tell the peer we've given up so it can
# stop work and free resources. v1's BaseSession.send_request does
# NOT do this; it's new behaviour.
await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s", _related_request_id)
raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None
except anyio.get_cancelled_exc_class():
# Our caller's scope was cancelled. We're already inside a cancelled
# scope, so any bare `await` here re-raises immediately - shield to
# let the courtesy cancel notification go out before we propagate.
with anyio.CancelScope(shield=True):
await self._cancel_outbound(request_id, "caller cancelled", _related_request_id)
raise
finally:
# Always remove the waiter, even on cancel/timeout, so a late
# response from the peer (race) hits a closed stream and is dropped
# in `_dispatch` rather than leaking.
self._pending.pop(request_id, None)
send.close()
receive.close()
if isinstance(outcome, ErrorData):
raise MCPError(code=outcome.code, message=outcome.message, data=outcome.data)
return outcome
async def notify(
self,
method: str,
params: Mapping[str, Any] | None,
*,
_related_request_id: RequestId | None = None,
) -> None:
# Leave `params` unset (not explicitly None) when there are none:
# transports serialize with `exclude_unset=True`, and an explicit None
# would survive as `"params": null`, which JSON-RPC 2.0 forbids and
# strict peers (e.g. the TypeScript SDK's zod schemas) reject.
if params is not None:
msg = JSONRPCNotification(jsonrpc="2.0", method=method, params=dict(params))
else:
msg = JSONRPCNotification(jsonrpc="2.0", method=method)
await self._write(msg, _outbound_metadata(_related_request_id, None))
async def run(
self,
on_request: OnRequest,
on_notify: OnNotify,
*,
task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
) -> None:
"""Drive the receive loop until the read stream closes.
Each inbound request is handled in its own task in an internal task
group; `task_status.started()` fires once that group is open, so
`await tg.start(dispatcher.run, ...)` resumes when `send_raw_request`
is usable.
"""
normal_eof = False
try:
async with anyio.create_task_group() as tg:
self._tg = tg
self._running = True
task_status.started()
try:
async with AsyncExitStack() as stack:
await stack.enter_async_context(self._read_stream)
if self._close_write_stream_on_read_close:
await stack.enter_async_context(self._write_stream)
try:
async for item in self._read_stream:
# Duck-typed: `_context_streams.ContextReceiveStream`
# exposes `.last_context` (the sender's contextvars
# snapshot per message). Plain memory streams don't.
sender_ctx: contextvars.Context | None = getattr(
self._read_stream, "last_context", None
)
await self._dispatch(item, on_request, on_notify, sender_ctx)
except anyio.ClosedResourceError:
# The transport closed our receive end and we looped
# back to `__anext__` on the now-closed stream
# (stateless SHTTP teardown). Same as EOF.
logger.debug("read stream closed by transport; treating as EOF")
# Read stream EOF: wake any blocked `send_raw_request` waiters
# (callers outside this task group) with CONNECTION_CLOSED.
self._running = False
self._fan_out_closed()
normal_eof = True
finally:
if not normal_eof or self._close_write_stream_on_read_close:
# Transport closed abnormally: cancel in-flight handlers.
# On normal EOF, let already-received handlers drain
# their responses before the task group exits.
tg.cancel_scope.cancel()
elif self._read_eof_drain_timeout_seconds is not None:
tg.cancel_scope.deadline = anyio.current_time() + self._read_eof_drain_timeout_seconds
finally:
# Covers the cancel/crash paths where the inline fan-out above is
# never reached. Idempotent.
self._running = False
self._tg = None
self._fan_out_closed()
if not self._close_write_stream_on_read_close:
with anyio.CancelScope(shield=True):
await self._write_stream.aclose()
async def _dispatch(
self,
item: SessionMessage | Exception,
on_request: OnRequest,
on_notify: OnNotify,
sender_ctx: contextvars.Context | None,
) -> None:
"""Route one inbound item.
Everything here is `send_nowait` or `_spawn`; the only `await` is for
`inline_methods` requests, which deliberately block dequeuing until
handled. Any other `await` would let one slow message head-of-line
block the entire read loop.
"""
if isinstance(item, Exception):
logger.debug("transport yielded exception: %r", item)
return
metadata = item.metadata
msg = item.message
match msg:
case JSONRPCRequest():
await self._dispatch_request(msg, metadata, on_request, sender_ctx)
case JSONRPCNotification():
self._dispatch_notification(msg, metadata, on_notify, sender_ctx)
case JSONRPCResponse():
self._resolve_pending(msg.id, msg.result)
case JSONRPCError(): # pragma: no branch
# `id` may be None per JSON-RPC (parse error before id known).
# The match is exhaustive over JSONRPCMessage; the no-match arc
# on this final case is unreachable.
self._resolve_pending(msg.id, msg.error)
async def _dispatch_request(
self,
req: JSONRPCRequest,
metadata: MessageMetadata,
on_request: OnRequest,
sender_ctx: contextvars.Context | None,
) -> None:
progress_token: ProgressToken | None
match req.params:
# The bool guard matters: `int()` patterns match bool (a subclass),
# and `True == 1` would alias dict lookups to request id 1.
case {"_meta": {"progressToken": str() | int() as progress_token}} if not isinstance(progress_token, bool):
pass
case _:
progress_token = None
try:
transport_ctx = self._transport_builder(metadata)
except Exception:
# Containment boundary for the user-supplied builder: a raising
# builder must cost only this message, not the whole connection
# (the exception would otherwise escape into run()'s read loop).
logger.exception("transport_builder raised; rejecting request %r", req.id)
self._spawn(
self._write_error,
req.id,
ErrorData(code=INTERNAL_ERROR, message="transport context unavailable"),
sender_ctx=sender_ctx,
)
return
dctx = _JSONRPCDispatchContext(
transport=transport_ctx,
_dispatcher=self,
_request_id=req.id,
message_metadata=metadata,
_progress_token=progress_token,
)
scope = anyio.CancelScope()
# TODO(maxisbey): the spec puts request-id uniqueness on the sender;
# neither v1 nor the TS SDK guards a duplicate id here, so for now we
# blind-overwrite (parity). Revisit rejecting with INVALID_REQUEST.
# Coerced key so `notifications/cancelled` correlates regardless of
# whether the peer stringifies the id between request and cancel
# (`_dispatch_notification` coerces at lookup; responses still echo
# `req.id` verbatim).
self._in_flight[_coerce_id(req.id)] = _InFlight(scope=scope, dctx=dctx)
if req.method in self._inline_methods:
# Spawn (so `sender_ctx` applies, matching the concurrent path) but
# park the read loop until the handler returns; that's the inline
# ordering guarantee. Because the read loop is parked, a handler
# that awaits the peer here (e.g. `dctx.send_raw_request`) will
# deadlock: the response can never be dequeued.
done = anyio.Event()
async def _run_inline() -> None:
try:
await self._handle_request(req, dctx, scope, on_request)
finally:
done.set()
self._spawn(_run_inline, sender_ctx=sender_ctx)
await done.wait()
else:
self._spawn(self._handle_request, req, dctx, scope, on_request, sender_ctx=sender_ctx)
def _dispatch_notification(
self,
msg: JSONRPCNotification,
metadata: MessageMetadata,
on_notify: OnNotify,
sender_ctx: contextvars.Context | None,
) -> None:
"""Route one inbound notification.
`notifications/cancelled` and `notifications/progress` are intercepted
here because they correlate against JSON-RPC request IDs - the
`_in_flight` / `_pending` tables this layer owns - so no higher layer
can act on them. Both are still teed to `on_notify` afterwards, so
middleware and registered notification handlers observe every inbound
notification. See the module docstring for the design rationale.
"""
if msg.method == "notifications/cancelled":
match msg.params:
# The bool guards here and below matter: `int()` patterns match
# bool (a subclass), and `True == 1` would alias the dict lookup
# to the entry keyed by request id 1.
case {"requestId": str() | int() as rid} if (
not isinstance(rid, bool) and (in_flight := self._in_flight.get(_coerce_id(rid))) is not None
):
in_flight.dctx.cancel_requested.set()
if self._peer_cancel_mode == "interrupt":
in_flight.scope.cancel()
case _:
pass
# fall through: cancelled is also teed to on_notify so middleware
# and registered handlers can observe it (matches DirectDispatcher,
# which forwards every notification).
elif msg.method == "notifications/progress":
match msg.params:
case {"progressToken": str() | int() as token, "progress": int() | float() as progress} if (
not isinstance(token, bool)
and not isinstance(progress, bool)
and (pending := self._pending.get(_coerce_id(token))) is not None
and pending.on_progress is not None
):
total = msg.params.get("total")
message = msg.params.get("message")
self._spawn(
_shielded_progress(pending.on_progress),
float(progress),
float(total) if isinstance(total, int | float) else None,
message if isinstance(message, str) else None,
sender_ctx=sender_ctx,
)
case _:
pass
# fall through: progress is also teed to on_notify
try:
transport_ctx = self._transport_builder(metadata)
except Exception:
# Same containment boundary as `_dispatch_request`: a raising
# builder drops this notification instead of killing the read loop.
logger.exception("transport_builder raised; dropping notification %r", msg.method)
return
dctx = _JSONRPCDispatchContext(
transport=transport_ctx, _dispatcher=self, _request_id=None, message_metadata=metadata
)
self._spawn(on_notify, dctx, msg.method, msg.params, sender_ctx=sender_ctx)
def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None:
pending = self._pending.get(_coerce_id(request_id)) if request_id is not None else None
if pending is None:
logger.debug("dropping response for unknown/late request id %r", request_id)
return
try:
pending.send.send_nowait(outcome)
except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError):
logger.debug("waiter for request id %r already gone", request_id)
def _spawn(
self,
fn: Callable[..., Awaitable[Any]],
*args: object,
sender_ctx: contextvars.Context | None,
) -> None:
"""Schedule `fn(*args)` in the run() task group, propagating the sender's contextvars.
ASGI middleware (auth, OTel) sets contextvars on the request task that
wrote into the read stream. `Context.run(tg.start_soon, ...)` makes
the spawned handler inherit *that* context instead of the receive
loop's, so `auth_context_var` and OTel spans survive.
"""
assert self._tg is not None
if sender_ctx is not None:
sender_ctx.run(self._tg.start_soon, fn, *args)
else:
self._tg.start_soon(fn, *args)
def _fan_out_closed(self) -> None:
"""Wake every pending `send_raw_request` waiter with `CONNECTION_CLOSED`.
Synchronous (uses `send_nowait`) because it's called from `finally`
which may be inside a cancelled scope. Idempotent.
"""
closed = ErrorData(code=CONNECTION_CLOSED, message="connection closed")
for pending in self._pending.values():
try:
pending.send.send_nowait(closed)
except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError):
pass
self._pending.clear()
async def _handle_request(
self,
req: JSONRPCRequest,
dctx: _JSONRPCDispatchContext[TransportT],
scope: anyio.CancelScope,
on_request: OnRequest,
) -> None:
"""Run `on_request` for one inbound request and write its response.
This is the single exception-to-wire boundary: handler exceptions are
caught here and serialized to `JSONRPCError`. Nothing above this in
the stack constructs wire errors.
"""
try:
with scope:
try:
result = await on_request(dctx, req.method, req.params)
finally:
# Handler done: close the back-channel (detached work that
# later calls `dctx.send_raw_request()` should see
# `NoBackChannelError`) and drop from `_in_flight` so a
# late `notifications/cancelled` is a no-op rather than
# racing the result write below. No checkpoint between
# handler return and the pop, so the cancel can't
# interleave there.
dctx.close()
self._in_flight.pop(_coerce_id(req.id), None)
await self._write_result(req.id, result)
if scope.cancel_called:
# Peer-cancel: `_dispatch_notification` cancelled this scope
# while the handler was running. anyio swallows a scope's *own*
# cancel at __exit__, so execution lands here rather than the
# `except cancelled` arm below.
# TODO(maxisbey): spec says SHOULD NOT respond after cancel.
# The existing server always has, so match that for now.
await self._write_error(req.id, ErrorData(code=0, message="Request cancelled"))
except anyio.get_cancelled_exc_class():
# Outer-cancel: run()'s task group is shutting down. Any bare
# `await` here re-raises immediately, so shield the courtesy write.
with anyio.CancelScope(shield=True):
await self._write_error(req.id, ErrorData(code=REQUEST_CANCELLED, message="Request cancelled"))
raise
except MCPError as e:
await self._write_error(req.id, e.error)
except ValidationError:
# TODO(maxisbey): data="" is pinned compat with the existing
# server (which never leaked pydantic error text onto the wire).
# Consider putting the validation detail in `data` once the
# interaction suite's divergence entry is resolved.
await self._write_error(
req.id, ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="")
)
except Exception as e:
logger.exception("handler for %r raised", req.method)
# TODO(maxisbey): code=0 is pinned compat with the existing
# server's `_handle_request`. JSON-RPC says INTERNAL_ERROR
# (-32603); revisit once the suite's divergence entry is resolved.
await self._write_error(req.id, ErrorData(code=0, message=str(e)))
if self._raise_handler_exceptions:
raise
# No outer `_in_flight` pop here: the inner `finally` above already
# removes the entry on every path out of the handler, and a second
# pop after the awaited response writes could evict a newer request
# that reused the id during that window.
def _allocate_id(self) -> int:
self._next_id += 1
return self._next_id
async def _write(self, message: JSONRPCMessage, metadata: MessageMetadata = None) -> None:
await self._write_stream.send(SessionMessage(message=message, metadata=metadata))
async def _write_result(self, request_id: RequestId, result: dict[str, Any]) -> None:
try:
await self._write(JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result))
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
logger.debug("dropped result for %r: write stream closed", request_id)
async def _write_error(self, request_id: RequestId, error: ErrorData) -> None:
try:
await self._write(JSONRPCError(jsonrpc="2.0", id=request_id, error=error))
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
logger.debug("dropped error for %r: write stream closed", request_id)
async def _cancel_outbound(self, request_id: RequestId, reason: str, related_request_id: RequestId | None) -> None:
# Thread `related_request_id` so streamable-HTTP routes the cancel onto
# the same per-request SSE stream as the request it cancels; without it
# the notification falls through to the standalone GET stream and is
# dropped when no GET stream is open.
try:
await self.notify(
"notifications/cancelled",
{"requestId": request_id, "reason": reason},
_related_request_id=related_request_id,
)
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
pass