Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ select = [
]
ignore = ["PERF203", "PLC0415", "PLR0402"]

[tool.ruff.lint.flake8-tidy-imports.banned-api]
"pydantic.RootModel".msg = "Use `pydantic.TypeAdapter` instead."


[tool.ruff.lint.mccabe]
max-complexity = 24 # Default is 10

Expand Down
6 changes: 3 additions & 3 deletions src/mcp/client/experimental/task_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def build_capability(self) -> types.ClientTasksCapability | None:
def handles_request(request: types.ServerRequest) -> bool:
"""Check if this handler handles the given request type."""
return isinstance(
request.root,
request,
types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest,
)

Expand All @@ -259,7 +259,7 @@ async def handle_request(
types.ClientResult | types.ErrorData
)

match responder.request.root:
match responder.request:
case types.GetTaskRequest(params=params):
response = await self.get_task(ctx, params)
client_response = client_response_type.validate_python(response)
Expand All @@ -281,7 +281,7 @@ async def handle_request(
await responder.respond(client_response)

case _: # pragma: no cover
raise ValueError(f"Unhandled request type: {type(responder.request.root)}")
raise ValueError(f"Unhandled request type: {type(responder.request)}")


# Backwards compatibility aliases
Expand Down
38 changes: 14 additions & 24 deletions src/mcp/client/experimental/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,13 @@ async def call_tool_as_task(
_meta = types.RequestParams.Meta(**meta)

return await self._session.send_request(
types.ClientRequest(
types.CallToolRequest(
params=types.CallToolRequestParams(
name=name,
arguments=arguments,
task=types.TaskMetadata(ttl=ttl),
_meta=_meta,
),
)
types.CallToolRequest(
params=types.CallToolRequestParams(
name=name,
arguments=arguments,
task=types.TaskMetadata(ttl=ttl),
_meta=_meta,
),
),
types.CreateTaskResult,
)
Expand All @@ -115,10 +113,8 @@ async def get_task(self, task_id: str) -> types.GetTaskResult:
GetTaskResult containing the task status and metadata
"""
return await self._session.send_request(
types.ClientRequest(
types.GetTaskRequest(
params=types.GetTaskRequestParams(task_id=task_id),
)
types.GetTaskRequest(
params=types.GetTaskRequestParams(task_id=task_id),
),
types.GetTaskResult,
)
Expand All @@ -142,10 +138,8 @@ async def get_task_result(
The task result, validated against result_type
"""
return await self._session.send_request(
types.ClientRequest(
types.GetTaskPayloadRequest(
params=types.GetTaskPayloadRequestParams(task_id=task_id),
)
types.GetTaskPayloadRequest(
params=types.GetTaskPayloadRequestParams(task_id=task_id),
),
result_type,
)
Expand All @@ -164,9 +158,7 @@ async def list_tasks(
"""
params = types.PaginatedRequestParams(cursor=cursor) if cursor else None
return await self._session.send_request(
types.ClientRequest(
types.ListTasksRequest(params=params),
),
types.ListTasksRequest(params=params),
types.ListTasksResult,
)

Expand All @@ -180,10 +172,8 @@ async def cancel_task(self, task_id: str) -> types.CancelTaskResult:
CancelTaskResult with the updated task state
"""
return await self._session.send_request(
types.ClientRequest(
types.CancelTaskRequest(
params=types.CancelTaskRequestParams(task_id=task_id),
)
types.CancelTaskRequest(
params=types.CancelTaskRequestParams(task_id=task_id),
),
types.CancelTaskResult,
)
Expand Down
118 changes: 50 additions & 68 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,7 @@ def __init__(
sampling_capabilities: types.SamplingCapability | None = None,
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
) -> None:
super().__init__(
read_stream,
write_stream,
types.ServerRequest,
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)
super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds)
self._client_info = client_info or DEFAULT_CLIENT_INFO
self._sampling_callback = sampling_callback or _default_sampling_callback
self._sampling_capabilities = sampling_capabilities
Expand All @@ -143,6 +137,14 @@ def __init__(
# Experimental: Task handlers (use defaults if not provided)
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()

@property
def _receive_request_adapter(self) -> TypeAdapter[types.ServerRequest]:
return types.server_request_adapter

@property
def _receive_notification_adapter(self) -> TypeAdapter[types.ServerNotification]:
return types.server_notification_adapter

async def initialize(self) -> types.InitializeResult:
sampling = (
(self._sampling_capabilities or types.SamplingCapability())
Expand All @@ -167,20 +169,18 @@ async def initialize(self) -> types.InitializeResult:
)

result = await self.send_request(
types.ClientRequest(
types.InitializeRequest(
params=types.InitializeRequestParams(
protocol_version=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=sampling,
elicitation=elicitation,
experimental=None,
roots=roots,
tasks=self._task_handlers.build_capability(),
),
client_info=self._client_info,
types.InitializeRequest(
params=types.InitializeRequestParams(
protocol_version=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=sampling,
elicitation=elicitation,
experimental=None,
roots=roots,
tasks=self._task_handlers.build_capability(),
),
)
client_info=self._client_info,
),
),
types.InitializeResult,
)
Expand All @@ -190,7 +190,7 @@ async def initialize(self) -> types.InitializeResult:

self._server_capabilities = result.capabilities

await self.send_notification(types.ClientNotification(types.InitializedNotification()))
await self.send_notification(types.InitializedNotification())

return result

Expand Down Expand Up @@ -218,10 +218,7 @@ def experimental(self) -> ExperimentalClientFeatures:

async def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
return await self.send_request(
types.ClientRequest(types.PingRequest()),
types.EmptyResult,
)
return await self.send_request(types.PingRequest(), types.EmptyResult)

async def send_progress_notification(
self,
Expand All @@ -232,26 +229,20 @@ async def send_progress_notification(
) -> None:
"""Send a progress notification."""
await self.send_notification(
types.ClientNotification(
types.ProgressNotification(
params=types.ProgressNotificationParams(
progress_token=progress_token,
progress=progress,
total=total,
message=message,
),
types.ProgressNotification(
params=types.ProgressNotificationParams(
progress_token=progress_token,
progress=progress,
total=total,
message=message,
),
)
)

async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
"""Send a logging/setLevel request."""
return await self.send_request( # pragma: no cover
types.ClientRequest(
types.SetLevelRequest(
params=types.SetLevelRequestParams(level=level),
)
),
types.SetLevelRequest(params=types.SetLevelRequestParams(level=level)),
types.EmptyResult,
)

Expand All @@ -261,10 +252,7 @@ async def list_resources(self, *, params: types.PaginatedRequestParams | None =
Args:
params: Full pagination parameters including cursor and any future fields
"""
return await self.send_request(
types.ClientRequest(types.ListResourcesRequest(params=params)),
types.ListResourcesResult,
)
return await self.send_request(types.ListResourcesRequest(params=params), types.ListResourcesResult)

async def list_resource_templates(
self, *, params: types.PaginatedRequestParams | None = None
Expand All @@ -275,28 +263,28 @@ async def list_resource_templates(
params: Full pagination parameters including cursor and any future fields
"""
return await self.send_request(
types.ClientRequest(types.ListResourceTemplatesRequest(params=params)),
types.ListResourceTemplatesRequest(params=params),
types.ListResourceTemplatesResult,
)

async def read_resource(self, uri: str | AnyUrl) -> types.ReadResourceResult:
"""Send a resources/read request."""
return await self.send_request(
types.ClientRequest(types.ReadResourceRequest(params=types.ReadResourceRequestParams(uri=str(uri)))),
types.ReadResourceRequest(params=types.ReadResourceRequestParams(uri=str(uri))),
types.ReadResourceResult,
)

async def subscribe_resource(self, uri: str | AnyUrl) -> types.EmptyResult:
"""Send a resources/subscribe request."""
return await self.send_request( # pragma: no cover
types.ClientRequest(types.SubscribeRequest(params=types.SubscribeRequestParams(uri=str(uri)))),
types.SubscribeRequest(params=types.SubscribeRequestParams(uri=str(uri))),
types.EmptyResult,
)

async def unsubscribe_resource(self, uri: str | AnyUrl) -> types.EmptyResult:
"""Send a resources/unsubscribe request."""
return await self.send_request( # pragma: no cover
types.ClientRequest(types.UnsubscribeRequest(params=types.UnsubscribeRequestParams(uri=str(uri)))),
types.UnsubscribeRequest(params=types.UnsubscribeRequestParams(uri=str(uri))),
types.EmptyResult,
)

Expand All @@ -316,10 +304,8 @@ async def call_tool(
_meta = types.RequestParams.Meta(**meta)

result = await self.send_request(
types.ClientRequest(
types.CallToolRequest(
params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta),
)
types.CallToolRequest(
params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta),
),
types.CallToolResult,
request_read_timeout_seconds=read_timeout_seconds,
Expand Down Expand Up @@ -364,17 +350,15 @@ async def list_prompts(self, *, params: types.PaginatedRequestParams | None = No
params: Full pagination parameters including cursor and any future fields
"""
return await self.send_request(
types.ClientRequest(types.ListPromptsRequest(params=params)),
types.ListPromptsRequest(params=params),
types.ListPromptsResult,
)

async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
"""Send a prompts/get request."""
return await self.send_request(
types.ClientRequest(
types.GetPromptRequest(
params=types.GetPromptRequestParams(name=name, arguments=arguments),
)
types.GetPromptRequest(
params=types.GetPromptRequestParams(name=name, arguments=arguments),
),
types.GetPromptResult,
)
Expand All @@ -391,14 +375,12 @@ async def complete(
context = types.CompletionContext(arguments=context_arguments)

return await self.send_request(
types.ClientRequest(
types.CompleteRequest(
params=types.CompleteRequestParams(
ref=ref,
argument=types.CompletionArgument(**argument),
context=context,
),
)
types.CompleteRequest(
params=types.CompleteRequestParams(
ref=ref,
argument=types.CompletionArgument(**argument),
context=context,
),
),
types.CompleteResult,
)
Expand All @@ -410,7 +392,7 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None
params: Full pagination parameters including cursor and any future fields
"""
result = await self.send_request(
types.ClientRequest(types.ListToolsRequest(params=params)),
types.ListToolsRequest(params=params),
types.ListToolsResult,
)

Expand All @@ -423,7 +405,7 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None

async def send_roots_list_changed(self) -> None: # pragma: no cover
"""Send a roots/list_changed notification."""
await self.send_notification(types.ClientNotification(types.RootsListChangedNotification()))
await self.send_notification(types.RootsListChangedNotification())

async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
ctx = RequestContext[ClientSession, Any](
Expand All @@ -440,7 +422,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
return None

# Core request handling
match responder.request.root:
match responder.request:
case types.CreateMessageRequest(params=params):
with responder:
# Check if this is a task-augmented request
Expand Down Expand Up @@ -469,7 +451,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques

case types.PingRequest(): # pragma: no cover
with responder:
return await responder.respond(types.ClientResult(root=types.EmptyResult()))
return await responder.respond(types.EmptyResult())

case _: # pragma: no cover
pass # Task requests handled above by _task_handlers
Expand All @@ -486,7 +468,7 @@ async def _handle_incoming(
async def _received_notification(self, notification: types.ServerNotification) -> None:
"""Handle notifications from the server."""
# Process specific notification types
match notification.root:
match notification:
case types.LoggingMessageNotification(params=params):
await self._logging_callback(params)
case types.ElicitCompleteNotification(params=params):
Expand Down
Loading