Skip to content

Commit 32565b3

Browse files
committed
Address review feedback: ExceptionGroup bug, reverse dependency, type widening
- Replace task group + ctx.run(tg.start_soon, ...) with direct await sender_ctx.run(handler, msg) to avoid ExceptionGroup wrapping that would prevent ClosedResourceError from being caught - Move ReadStream/WriteStream protocols to mcp.shared._stream_protocols so shared/server modules don't depend on client internals - Restore write stream type narrowing in MessageStream (SessionMessage only, not SessionMessage | Exception) - Remove unused T_Item TypeVar
1 parent 1b34f2e commit 32565b3

File tree

10 files changed

+63
-60
lines changed

10 files changed

+63
-60
lines changed

src/mcp/client/_transport.py

Lines changed: 3 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,57 +3,12 @@
33
from __future__ import annotations
44

55
from contextlib import AbstractAsyncContextManager
6-
from types import TracebackType
7-
from typing import Protocol, TypeVar, runtime_checkable
8-
9-
from typing_extensions import Self
6+
from typing import Protocol
107

8+
from mcp.shared._stream_protocols import ReadStream, WriteStream
119
from mcp.shared.message import SessionMessage
1210

13-
T_co = TypeVar("T_co", covariant=True)
14-
T_contra = TypeVar("T_contra", contravariant=True)
15-
16-
17-
@runtime_checkable
18-
class ReadStream(Protocol[T_co]): # pragma: no branch
19-
"""Protocol for reading items from a stream.
20-
21-
Both ``MemoryObjectReceiveStream`` and ``ContextReceiveStream`` satisfy
22-
this protocol. Consumers that need the sender's context should use
23-
``getattr(stream, 'last_context', None)``.
24-
"""
25-
26-
async def receive(self) -> T_co: ... # pragma: no branch
27-
async def aclose(self) -> None: ... # pragma: no branch
28-
def __aiter__(self) -> ReadStream[T_co]: ... # pragma: no branch
29-
async def __anext__(self) -> T_co: ... # pragma: no branch
30-
async def __aenter__(self) -> Self: ... # pragma: no branch
31-
async def __aexit__( # pragma: no branch
32-
self,
33-
exc_type: type[BaseException] | None,
34-
exc_val: BaseException | None,
35-
exc_tb: TracebackType | None,
36-
) -> bool | None: ...
37-
38-
39-
@runtime_checkable
40-
class WriteStream(Protocol[T_contra]): # pragma: no branch
41-
"""Protocol for writing items to a stream.
42-
43-
Both ``MemoryObjectSendStream`` and ``ContextSendStream`` satisfy
44-
this protocol.
45-
"""
46-
47-
async def send(self, item: T_contra, /) -> None: ... # pragma: no branch
48-
async def aclose(self) -> None: ... # pragma: no branch
49-
async def __aenter__(self) -> Self: ... # pragma: no branch
50-
async def __aexit__( # pragma: no branch
51-
self,
52-
exc_type: type[BaseException] | None,
53-
exc_val: BaseException | None,
54-
exc_tb: TracebackType | None,
55-
) -> bool | None: ...
56-
11+
__all__ = ["ReadStream", "WriteStream", "Transport", "TransportStreams"]
5712

5813
TransportStreams = tuple[ReadStream[SessionMessage | Exception], WriteStream[SessionMessage]]
5914

src/mcp/client/sse.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,7 @@ async def _send_message(session_message: SessionMessage) -> None:
143143
async for session_message in write_stream_reader:
144144
sender_ctx = write_stream_reader.last_context
145145
if sender_ctx is not None:
146-
async with anyio.create_task_group() as tg:
147-
sender_ctx.run(tg.start_soon, _send_message, session_message)
146+
await sender_ctx.run(_send_message, session_message)
148147
else:
149148
await _send_message(session_message) # pragma: no cover
150149
except Exception: # pragma: lax no cover

src/mcp/client/streamable_http.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -482,8 +482,7 @@ async def handle_request_async():
482482
async for session_message in write_stream_reader:
483483
sender_ctx = write_stream_reader.last_context
484484
if sender_ctx is not None:
485-
async with anyio.create_task_group() as tg_local:
486-
sender_ctx.run(tg_local.start_soon, _handle_message, session_message)
485+
await sender_ctx.run(_handle_message, session_message)
487486
else:
488487
await _handle_message(session_message) # pragma: no cover
489488

src/mcp/server/lowlevel/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ async def main():
5252
from typing_extensions import TypeVar
5353

5454
from mcp import types
55-
from mcp.client._transport import ReadStream, WriteStream
5655
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
5756
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware
5857
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier
@@ -66,6 +65,7 @@ async def main():
6665
from mcp.server.streamable_http import EventStore
6766
from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager
6867
from mcp.server.transport_security import TransportSecuritySettings
68+
from mcp.shared._stream_protocols import ReadStream, WriteStream
6969
from mcp.shared.exceptions import MCPError
7070
from mcp.shared.message import ServerMessageMetadata, SessionMessage
7171
from mcp.shared.session import RequestResponder

src/mcp/server/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
3737
from pydantic import AnyUrl, TypeAdapter
3838

3939
from mcp import types
40-
from mcp.client._transport import ReadStream, WriteStream
4140
from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures
4241
from mcp.server.models import InitializationOptions
4342
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
43+
from mcp.shared._stream_protocols import ReadStream, WriteStream
4444
from mcp.shared.exceptions import StatelessModeNotSupported
4545
from mcp.shared.experimental.tasks.capabilities import check_tasks_capability
4646
from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY

src/mcp/server/streamable_http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
from starlette.responses import Response
2525
from starlette.types import Receive, Scope, Send
2626

27-
from mcp.client._transport import ReadStream, WriteStream
2827
from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings
2928
from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams
29+
from mcp.shared._stream_protocols import ReadStream, WriteStream
3030
from mcp.shared.message import ServerMessageMetadata, SessionMessage
3131
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
3232
from mcp.types import (

src/mcp/shared/_context_streams.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
2121

2222
T = TypeVar("T")
23-
T_Item = TypeVar("T_Item")
2423

2524
# Internal payload carried through the underlying raw stream.
2625
_Envelope = tuple[contextvars.Context, T]
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""Stream protocols for MCP transports.
2+
3+
These are general-purpose protocols satisfied by both ``MemoryObjectSendStream``/
4+
``MemoryObjectReceiveStream`` and the context-aware wrappers in ``_context_streams``.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
from types import TracebackType
10+
from typing import Protocol, TypeVar, runtime_checkable
11+
12+
from typing_extensions import Self
13+
14+
T_co = TypeVar("T_co", covariant=True)
15+
T_contra = TypeVar("T_contra", contravariant=True)
16+
17+
18+
@runtime_checkable
19+
class ReadStream(Protocol[T_co]): # pragma: no branch
20+
"""Protocol for reading items from a stream.
21+
22+
Consumers that need the sender's context should use
23+
``getattr(stream, 'last_context', None)``.
24+
"""
25+
26+
async def receive(self) -> T_co: ... # pragma: no branch
27+
async def aclose(self) -> None: ... # pragma: no branch
28+
def __aiter__(self) -> ReadStream[T_co]: ... # pragma: no branch
29+
async def __anext__(self) -> T_co: ... # pragma: no branch
30+
async def __aenter__(self) -> Self: ... # pragma: no branch
31+
async def __aexit__( # pragma: no branch
32+
self,
33+
exc_type: type[BaseException] | None,
34+
exc_val: BaseException | None,
35+
exc_tb: TracebackType | None,
36+
) -> bool | None: ...
37+
38+
39+
@runtime_checkable
40+
class WriteStream(Protocol[T_contra]): # pragma: no branch
41+
"""Protocol for writing items to a stream."""
42+
43+
async def send(self, item: T_contra, /) -> None: ... # pragma: no branch
44+
async def aclose(self) -> None: ... # pragma: no branch
45+
async def __aenter__(self) -> Self: ... # pragma: no branch
46+
async def __aexit__( # pragma: no branch
47+
self,
48+
exc_type: type[BaseException] | None,
49+
exc_val: BaseException | None,
50+
exc_tb: TracebackType | None,
51+
) -> bool | None: ...

src/mcp/shared/memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams
99
from mcp.shared.message import SessionMessage
1010

11-
MessageStream = tuple[ContextReceiveStream[SessionMessage | Exception], ContextSendStream[SessionMessage | Exception]]
11+
MessageStream = tuple[ContextReceiveStream[SessionMessage | Exception], ContextSendStream[SessionMessage]]
1212

1313

1414
@asynccontextmanager

src/mcp/shared/session.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pydantic import BaseModel, TypeAdapter
1313
from typing_extensions import Self
1414

15-
from mcp.client._transport import ReadStream, WriteStream
15+
from mcp.shared._stream_protocols import ReadStream, WriteStream
1616
from mcp.shared.exceptions import MCPError
1717
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
1818
from mcp.shared.response_router import ResponseRouter
@@ -417,8 +417,8 @@ async def _handle_session_message(message: SessionMessage) -> None:
417417

418418
sender_ctx: contextvars.Context | None = getattr(self._read_stream, "last_context", None)
419419
if sender_ctx is not None:
420-
async with anyio.create_task_group() as tg:
421-
sender_ctx.run(tg.start_soon, _handle_session_message, message)
420+
coro = sender_ctx.run(_handle_session_message, message)
421+
await coro
422422
else:
423423
await _handle_session_message(message)
424424

0 commit comments

Comments
 (0)