Skip to content

Commit ac209ab

Browse files
committed
Make ClientRequestContext a concrete class
RequestContext[ClientSession] was the only instantiation of the generic left in the tree (the server seat has ServerRequestContext), so the public ClientRequestContext alias becomes the real dataclass and the private mcp.shared._context module is deleted. request_id is now always populated: the client only builds a context for inbound requests, and ping is answered before any context exists.
1 parent d367b17 commit ac209ab

11 files changed

Lines changed: 75 additions & 107 deletions

File tree

docs/migration.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -634,11 +634,11 @@ server = Server("my-server", on_call_tool=handle_call_tool)
634634

635635
The `mcp.shared.context` module has been removed. `RequestContext` is now split into `ClientRequestContext` (in `mcp.client.context`) and `ServerRequestContext` (in `mcp.server.context`).
636636

637-
The `RequestContext` class has been split to separate shared fields from server-specific fields. The shared `RequestContext` now only takes 1 type parameter (the session type) instead of 3.
637+
The split separates shared fields from server-specific fields. There is no shared `RequestContext` generic anymore — each concrete class fixes its session type.
638638

639639
**`RequestContext` changes:**
640640

641-
- Type parameters reduced from `RequestContext[SessionT, LifespanContextT, RequestT]` to `RequestContext[SessionT]`
641+
- The `RequestContext[SessionT, LifespanContextT, RequestT]` generic no longer exists; use `ClientRequestContext` or `ServerRequestContext[LifespanContextT, RequestT]`
642642
- Server-specific fields (`lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) moved to new `ServerRequestContext` class in `mcp.server.context`
643643

644644
**Before (v1):**
@@ -1188,6 +1188,7 @@ Behavior changes:
11881188
- **A raising request callback** is answered with `code=0` and the exception text. v1 flattened every callback exception to `INVALID_PARAMS`. Callbacks that want a specific error response should return `ErrorData` (unchanged) or raise `MCPError`. One carve-out: a callback that raises pydantic's `ValidationError` is still answered with `INVALID_PARAMS` (`"Invalid request parameters"`, empty `data`) because the dispatcher cannot distinguish it from inbound-params validation — this conflation is pre-existing v1 behavior, and a revisit is pending.
11891189
- **`send_request` before entering the context manager** raises `RuntimeError` immediately; v1 wrote to the transport and hung until the timeout. `send_notification` before entry still works.
11901190
- **`send_notification` no longer takes `related_request_id`, and `send_request` no longer accepts `ServerMessageMetadata`.** The hint was never serialized by any client transport in v1 or v2 — it exists for the server's streamable-HTTP stream routing. Progress and response correlation via `progressToken` and the request id is unaffected.
1191+
- **The private `mcp.shared._context.RequestContext` generic is deleted.** Client callbacks now receive the concrete `mcp.client.ClientRequestContext`, whose `request_id` is always populated (the client only builds a context for inbound requests). Annotations spelled `RequestContext[ClientSession]` become `ClientRequestContext`.
11911192

11921193
`mcp.shared.session` is now a compatibility module: `ProgressFnT` is re-exported (its home is `mcp.shared.dispatcher`), and `RequestResponder` remains as a typing-only stub so `MessageHandlerFnT` annotations keep importing — it has been unreachable at runtime since the server-side swap. `RequestResponder.respond()` no longer exists.
11931194

src/mcp/client/context.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,5 @@
11
"""Request context for MCP client handlers."""
22

3-
from mcp.client.session import ClientSession
4-
from mcp.shared._context import RequestContext
3+
from mcp.client.session import ClientRequestContext
54

6-
ClientRequestContext = RequestContext[ClientSession]
7-
"""Context for handling incoming requests in a client session.
8-
9-
This context is passed to client-side callbacks (sampling, elicitation, list_roots) when the server sends requests
10-
to the client.
11-
12-
Attributes:
13-
request_id: The unique identifier for this request.
14-
meta: Optional metadata associated with the request.
15-
session: The client session handling this request.
16-
"""
5+
__all__ = ["ClientRequestContext"]

src/mcp/client/session.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import logging
44
from collections.abc import Mapping
5+
from dataclasses import dataclass
56
from types import TracebackType
67
from typing import Any, Protocol, cast, get_args
78

@@ -14,15 +15,14 @@
1415
from mcp import types
1516
from mcp.client._transport import ReadStream, WriteStream
1617
from mcp.shared._compat import resync_tracer
17-
from mcp.shared._context import RequestContext
1818
from mcp.shared.dispatcher import CallOptions, DispatchContext, Dispatcher
1919
from mcp.shared.exceptions import MCPError
2020
from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher
2121
from mcp.shared.message import ClientMessageMetadata, SessionMessage
2222
from mcp.shared.session import ProgressFnT, RequestResponder
2323
from mcp.shared.transport_context import TransportContext
2424
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
25-
from mcp.types._types import RequestParamsMeta
25+
from mcp.types import RequestId, RequestParamsMeta
2626

2727
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
2828

@@ -31,25 +31,34 @@
3131
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
3232

3333

34+
@dataclass(kw_only=True)
35+
class ClientRequestContext:
36+
"""Context for a server-initiated request, passed to the sampling/elicitation/list-roots callbacks."""
37+
38+
session: ClientSession
39+
request_id: RequestId
40+
meta: RequestParamsMeta | None = None
41+
42+
3443
class SamplingFnT(Protocol):
3544
async def __call__(
3645
self,
37-
context: RequestContext[ClientSession],
46+
context: ClientRequestContext,
3847
params: types.CreateMessageRequestParams,
3948
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch
4049

4150

4251
class ElicitationFnT(Protocol):
4352
async def __call__(
4453
self,
45-
context: RequestContext[ClientSession],
54+
context: ClientRequestContext,
4655
params: types.ElicitRequestParams,
4756
) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch
4857

4958

5059
class ListRootsFnT(Protocol):
5160
async def __call__(
52-
self, context: RequestContext[ClientSession]
61+
self, context: ClientRequestContext
5362
) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch
5463

5564

@@ -71,7 +80,7 @@ async def _default_message_handler(
7180

7281

7382
async def _default_sampling_callback(
74-
context: RequestContext[ClientSession],
83+
context: ClientRequestContext,
7584
params: types.CreateMessageRequestParams,
7685
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData:
7786
return types.ErrorData(
@@ -81,7 +90,7 @@ async def _default_sampling_callback(
8190

8291

8392
async def _default_elicitation_callback(
84-
context: RequestContext[ClientSession],
93+
context: ClientRequestContext,
8594
params: types.ElicitRequestParams,
8695
) -> types.ElicitResult | types.ErrorData:
8796
return types.ErrorData(
@@ -91,7 +100,7 @@ async def _default_elicitation_callback(
91100

92101

93102
async def _default_list_roots_callback(
94-
context: RequestContext[ClientSession],
103+
context: ClientRequestContext,
95104
) -> types.ListRootsResult | types.ErrorData:
96105
return types.ErrorData(
97106
code=types.INVALID_REQUEST,
@@ -496,19 +505,22 @@ async def _on_request(
496505
payload["params"] = dict(params)
497506
request = types.server_request_adapter.validate_python(payload, by_name=False)
498507

499-
ctx = RequestContext[ClientSession](
500-
request_id=dctx.request_id, meta=request.params.meta if request.params else None, session=self
501-
)
502508
response: types.ClientResult | types.ErrorData
503-
match request:
504-
case types.CreateMessageRequest(params=sampling_params):
505-
response = await self._sampling_callback(ctx, sampling_params)
506-
case types.ElicitRequest(params=elicit_params):
507-
response = await self._elicitation_callback(ctx, elicit_params)
508-
case types.ListRootsRequest():
509-
response = await self._list_roots_callback(ctx)
510-
case types.PingRequest(): # pragma: no branch
511-
response = types.EmptyResult()
509+
if isinstance(request, types.PingRequest):
510+
# Answered without a context: direct dispatch carries no request id.
511+
response = types.EmptyResult()
512+
else:
513+
assert dctx.request_id is not None # the callback-driving dispatchers always assign ids
514+
ctx = ClientRequestContext(
515+
session=self, request_id=dctx.request_id, meta=request.params.meta if request.params else None
516+
)
517+
match request:
518+
case types.CreateMessageRequest(params=sampling_params):
519+
response = await self._sampling_callback(ctx, sampling_params)
520+
case types.ElicitRequest(params=elicit_params):
521+
response = await self._elicitation_callback(ctx, elicit_params)
522+
case types.ListRootsRequest(): # pragma: no branch
523+
response = await self._list_roots_callback(ctx)
512524
client_response = ClientResponse.validate_python(response)
513525
if isinstance(client_response, types.ErrorData):
514526
raise MCPError.from_error_data(client_response)

src/mcp/shared/_context.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

tests/client/test_list_roots_callback.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
from pydantic import FileUrl
33

44
from mcp import Client
5-
from mcp.client.session import ClientSession
5+
from mcp.client import ClientRequestContext
66
from mcp.server.mcpserver import Context, MCPServer
7-
from mcp.shared._context import RequestContext
87
from mcp.types import ListRootsResult, Root, TextContent
98

109

@@ -20,7 +19,7 @@ async def test_list_roots_callback():
2019
)
2120

2221
async def list_roots_callback(
23-
context: RequestContext[ClientSession],
22+
context: ClientRequestContext,
2423
) -> ListRootsResult:
2524
return callback_return
2625

tests/client/test_sampling_callback.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import pytest
22

33
from mcp import Client
4-
from mcp.client.session import ClientSession
4+
from mcp.client import ClientRequestContext
55
from mcp.server.mcpserver import Context, MCPServer
6-
from mcp.shared._context import RequestContext
76
from mcp.types import (
87
CreateMessageRequestParams,
98
CreateMessageResult,
@@ -26,7 +25,7 @@ async def test_sampling_callback():
2625
)
2726

2827
async def sampling_callback(
29-
context: RequestContext[ClientSession],
28+
context: ClientRequestContext,
3029
params: CreateMessageRequestParams,
3130
) -> CreateMessageResult:
3231
return callback_return
@@ -71,7 +70,7 @@ async def test_create_message_backwards_compat_single_content():
7170
)
7271

7372
async def sampling_callback(
74-
context: RequestContext[ClientSession],
73+
context: ClientRequestContext,
7574
params: CreateMessageRequestParams,
7675
) -> CreateMessageResult:
7776
return callback_return

tests/client/test_session.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
import pytest
1111

1212
from mcp import types
13+
from mcp.client import ClientRequestContext
1314
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
14-
from mcp.shared._context import RequestContext
1515
from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair
1616
from mcp.shared.dispatcher import CallOptions, DispatchContext, OnNotify, OnRequest
1717
from mcp.shared.message import SessionMessage
@@ -425,7 +425,7 @@ async def test_client_capabilities_with_custom_callbacks():
425425
received_capabilities = None
426426

427427
async def custom_sampling_callback( # pragma: no cover
428-
context: RequestContext[ClientSession],
428+
context: ClientRequestContext,
429429
params: types.CreateMessageRequestParams,
430430
) -> types.CreateMessageResult | types.ErrorData:
431431
return types.CreateMessageResult(
@@ -435,7 +435,7 @@ async def custom_sampling_callback( # pragma: no cover
435435
)
436436

437437
async def custom_list_roots_callback( # pragma: no cover
438-
context: RequestContext[ClientSession],
438+
context: ClientRequestContext,
439439
) -> types.ListRootsResult | types.ErrorData:
440440
return types.ListRootsResult(roots=[])
441441

@@ -509,7 +509,7 @@ async def test_client_capabilities_with_sampling_tools():
509509
received_capabilities = None
510510

511511
async def custom_sampling_callback( # pragma: no cover
512-
context: RequestContext[ClientSession],
512+
context: ClientRequestContext,
513513
params: types.CreateMessageRequestParams,
514514
) -> types.CreateMessageResult | types.ErrorData:
515515
return types.CreateMessageResult(

tests/server/mcpserver/test_elicitation.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from pydantic import BaseModel, Field
77

88
from mcp import Client, types
9-
from mcp.client.session import ClientSession, ElicitationFnT
9+
from mcp.client import ClientRequestContext
10+
from mcp.client.session import ElicitationFnT
1011
from mcp.server.mcpserver import Context, MCPServer
11-
from mcp.shared._context import RequestContext
1212
from mcp.types import ElicitRequestParams, ElicitResult, TextContent
1313

1414

@@ -64,7 +64,7 @@ async def test_elicitation_accept_returns_the_users_answer_to_the_tool():
6464
create_ask_user_tool(mcp)
6565

6666
# Create a custom handler for elicitation requests
67-
async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams):
67+
async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams):
6868
if params.message == "Tool wants to ask: What is your name?":
6969
return ElicitResult(action="accept", content={"answer": "Test User"})
7070
else: # pragma: no cover
@@ -81,7 +81,7 @@ async def test_elicitation_decline_reaches_the_tool_without_content():
8181
mcp = MCPServer(name="ElicitationDeclineServer")
8282
create_ask_user_tool(mcp)
8383

84-
async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams):
84+
async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams):
8585
return ElicitResult(action="decline")
8686

8787
await call_tool_and_assert(
@@ -119,9 +119,7 @@ class InvalidNestedSchema(BaseModel):
119119
create_validation_tool("nested_model", InvalidNestedSchema)
120120

121121
# Dummy callback (won't be called due to validation failure)
122-
async def elicitation_callback(
123-
context: RequestContext[ClientSession], params: ElicitRequestParams
124-
): # pragma: no cover
122+
async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): # pragma: no cover
125123
return ElicitResult(action="accept", content={})
126124

127125
async with Client(mcp, elicitation_callback=elicitation_callback) as client:
@@ -176,7 +174,7 @@ async def optional_tool(ctx: Context) -> str:
176174

177175
for content, expected in test_cases:
178176

179-
async def callback(context: RequestContext[ClientSession], params: ElicitRequestParams):
177+
async def callback(context: ClientRequestContext, params: ElicitRequestParams):
180178
return ElicitResult(action="accept", content=content)
181179

182180
await call_tool_and_assert(mcp, callback, "optional_tool", {}, expected)
@@ -194,9 +192,7 @@ async def invalid_optional_tool(ctx: Context) -> str:
194192
except TypeError as e:
195193
return f"Validation failed: {str(e)}"
196194

197-
async def elicitation_callback(
198-
context: RequestContext[ClientSession], params: ElicitRequestParams
199-
): # pragma: no cover
195+
async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): # pragma: no cover
200196
return ElicitResult(action="accept", content={})
201197

202198
await call_tool_and_assert(
@@ -219,7 +215,7 @@ async def valid_multiselect_tool(ctx: Context) -> str:
219215
return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}"
220216
return f"User {result.action}" # pragma: no cover
221217

222-
async def multiselect_callback(context: RequestContext[ClientSession], params: ElicitRequestParams):
218+
async def multiselect_callback(context: ClientRequestContext, params: ElicitRequestParams):
223219
if "Please provide tags" in params.message:
224220
return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]})
225221
return ElicitResult(action="decline") # pragma: no cover
@@ -239,7 +235,7 @@ async def optional_multiselect_tool(ctx: Context) -> str:
239235
return f"Name: {result.data.name}, Tags: {tags_str}"
240236
return f"User {result.action}" # pragma: no cover
241237

242-
async def optional_multiselect_callback(context: RequestContext[ClientSession], params: ElicitRequestParams):
238+
async def optional_multiselect_callback(context: ClientRequestContext, params: ElicitRequestParams):
243239
if "Please provide optional tags" in params.message:
244240
return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]})
245241
return ElicitResult(action="decline") # pragma: no cover
@@ -273,7 +269,7 @@ async def defaults_tool(ctx: Context) -> str:
273269
return f"User {result.action}"
274270

275271
# First verify that defaults are present in the JSON schema sent to clients
276-
async def callback_schema_verify(context: RequestContext[ClientSession], params: ElicitRequestParams):
272+
async def callback_schema_verify(context: ClientRequestContext, params: ElicitRequestParams):
277273
# Verify the schema includes defaults
278274
assert isinstance(params, types.ElicitRequestFormParams), "Expected form mode elicitation"
279275
schema = params.requested_schema
@@ -295,7 +291,7 @@ async def callback_schema_verify(context: RequestContext[ClientSession], params:
295291
)
296292

297293
# Test overriding defaults
298-
async def callback_override(context: RequestContext[ClientSession], params: ElicitRequestParams):
294+
async def callback_override(context: ClientRequestContext, params: ElicitRequestParams):
299295
return ElicitResult(
300296
action="accept", content={"email": "john@example.com", "name": "John", "age": 25, "subscribe": False}
301297
)
@@ -371,7 +367,7 @@ async def select_color_legacy(ctx: Context) -> str:
371367
return f"User: {result.data.user_name}, Color: {result.data.color}"
372368
return f"User {result.action}" # pragma: no cover
373369

374-
async def enum_callback(context: RequestContext[ClientSession], params: ElicitRequestParams):
370+
async def enum_callback(context: ClientRequestContext, params: ElicitRequestParams):
375371
if "colors" in params.message and "legacy" not in params.message:
376372
return ElicitResult(action="accept", content={"user_name": "Bob", "favorite_colors": ["red", "green"]})
377373
elif "color" in params.message:

0 commit comments

Comments
 (0)