Skip to content

Commit 8c1556b

Browse files
committed
test: remove new pragmas — exercise send_response via sampling, assert for send_message narrow
The E2E dispatcher test now triggers a server→client sampling request, so the client's response flows through spy.send_response. All five Dispatcher methods are now exercised in one round-trip. ServerSession.send_message: replace the if-not-isinstance-raise guard with an assert. Same type-narrowing for pyright; the assert line runs in every test; no coverage pragma needed.
1 parent 5ddf1dd commit 8c1556b

File tree

2 files changed

+41
-13
lines changed

2 files changed

+41
-13
lines changed

src/mcp/server/session.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -683,8 +683,7 @@ async def send_message(self, message: SessionMessage) -> None:
683683
Args:
684684
message: The session message to send
685685
"""
686-
if not isinstance(self._dispatcher, JSONRPCDispatcher): # pragma: no cover
687-
raise TypeError("send_message requires the default JSON-RPC dispatcher")
686+
assert isinstance(self._dispatcher, JSONRPCDispatcher), "send_message requires the default JSON-RPC dispatcher"
688687
await self._dispatcher._write_stream.send(message) # type: ignore[reportPrivateUsage]
689688

690689
async def _handle_incoming(self, req: ServerRequestResponder) -> None:

tests/shared/test_dispatcher.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import pytest
99

1010
from mcp.client.session import ClientSession
11-
from mcp.server.mcpserver import MCPServer
11+
from mcp.server.mcpserver import Context, MCPServer
12+
from mcp.shared._context import RequestContext
1213
from mcp.shared.dispatcher import (
1314
JSONRPCDispatcher,
1415
OnErrorFn,
@@ -17,7 +18,14 @@
1718
)
1819
from mcp.shared.memory import create_client_server_memory_streams
1920
from mcp.shared.message import MessageMetadata
20-
from mcp.types import ErrorData, RequestId
21+
from mcp.types import (
22+
CreateMessageRequestParams,
23+
CreateMessageResult,
24+
ErrorData,
25+
RequestId,
26+
SamplingMessage,
27+
TextContent,
28+
)
2129

2230
pytestmark = pytest.mark.anyio
2331

@@ -35,6 +43,7 @@ def __init__(self, inner: JSONRPCDispatcher) -> None:
3543
self._inner = inner
3644
self.sent_requests: list[dict[str, Any]] = []
3745
self.sent_notifications: list[dict[str, Any]] = []
46+
self.sent_responses: list[dict[str, Any] | ErrorData] = []
3847

3948
def set_handlers(self, on_request: OnRequestFn, on_notification: OnNotificationFn, on_error: OnErrorFn) -> None:
4049
self._inner.set_handlers(on_request, on_notification, on_error)
@@ -59,16 +68,33 @@ async def send_notification(
5968
await self._inner.send_notification(notification, related_request_id)
6069

6170
async def send_response(self, request_id: RequestId, response: dict[str, Any] | ErrorData) -> None:
62-
await self._inner.send_response(request_id, response) # pragma: no cover
71+
self.sent_responses.append(response)
72+
await self._inner.send_response(request_id, response)
6373

6474

6575
async def test_client_session_accepts_custom_dispatcher():
66-
"""ClientSession round-trips through a custom dispatcher end-to-end."""
76+
"""ClientSession round-trips through a custom dispatcher end-to-end, including
77+
a server-initiated request (sampling) so all five dispatcher methods fire."""
6778
app = MCPServer("test")
6879

6980
@app.tool()
70-
def greet(name: str) -> str:
71-
return f"Hello, {name}!"
81+
async def ask(question: str, ctx: Context) -> str:
82+
answer = await ctx.session.create_message(
83+
messages=[SamplingMessage(role="user", content=TextContent(type="text", text=question))],
84+
max_tokens=10,
85+
)
86+
assert isinstance(answer.content, TextContent)
87+
return answer.content.text
88+
89+
async def sampling_callback(
90+
context: RequestContext[ClientSession], params: CreateMessageRequestParams
91+
) -> CreateMessageResult:
92+
return CreateMessageResult(
93+
role="assistant",
94+
content=TextContent(type="text", text="42"),
95+
model="test",
96+
stop_reason="endTurn",
97+
)
7298

7399
async with create_client_server_memory_streams() as (client_streams, server_streams):
74100
client_read, client_write = client_streams
@@ -83,17 +109,20 @@ def greet(name: str) -> str:
83109
server = app._lowlevel_server # type: ignore[reportPrivateUsage]
84110
tg.start_soon(lambda: server.run(server_read, server_write, server.create_initialization_options()))
85111

86-
async with ClientSession(dispatcher=spy) as session:
112+
async with ClientSession(dispatcher=spy, sampling_callback=sampling_callback) as session:
87113
await session.initialize()
88-
result = await session.call_tool("greet", {"name": "world"})
89-
assert result.content[0].text == "Hello, world!" # type: ignore[union-attr]
114+
result = await session.call_tool("ask", {"question": "meaning of life?"})
115+
assert result.content[0].text == "42" # type: ignore[union-attr]
90116

91117
tg.cancel_scope.cancel()
92118

93-
# Initialize + call_tool + list_tools (output-schema refresh after the call).
119+
# initialize, tools/call (triggers sampling on the server), tools/list (schema refresh)
94120
assert [r["method"] for r in spy.sent_requests] == ["initialize", "tools/call", "tools/list"]
95-
# InitializedNotification.
96121
assert [n["method"] for n in spy.sent_notifications] == ["notifications/initialized"]
122+
# The server's sampling/createMessage request hit us; our response went back through the spy.
123+
assert len(spy.sent_responses) == 1
124+
response = spy.sent_responses[0]
125+
assert isinstance(response, dict) and response["model"] == "test"
97126

98127

99128
async def test_base_session_requires_streams_or_dispatcher():

0 commit comments

Comments
 (0)