Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/mcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .server.session import ServerSession
from .server.stdio import stdio_server
from .shared.exceptions import McpError, UrlElicitationRequiredError
from .shared.session import MessageMiddleware
from .types import (
CallToolRequest,
ClientCapabilities,
Expand All @@ -23,6 +24,8 @@
InitializeRequest,
InitializeResult,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
ListPromptsRequest,
Expand Down Expand Up @@ -87,8 +90,11 @@
"InitializeResult",
"InitializedNotification",
"JSONRPCError",
"JSONRPCMessage",
"JSONRPCNotification",
"JSONRPCRequest",
"JSONRPCResponse",
"MessageMiddleware",
"ListPromptsRequest",
"ListPromptsResult",
"ListResourcesRequest",
Expand Down
6 changes: 5 additions & 1 deletion src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
from mcp.shared.session import BaseSession, MessageMiddleware, ProgressFnT, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS

DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
Expand Down Expand Up @@ -123,13 +123,17 @@ def __init__(
*,
sampling_capabilities: types.SamplingCapability | None = None,
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
send_middleware: list["MessageMiddleware"] | None = None,
receive_middleware: list["MessageMiddleware"] | None = None,
) -> None:
super().__init__(
read_stream,
write_stream,
types.ServerRequest,
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
send_middleware=send_middleware,
receive_middleware=receive_middleware,
)
self._client_info = client_info or DEFAULT_CLIENT_INFO
self._sampling_callback = sampling_callback or _default_sampling_callback
Expand Down
13 changes: 12 additions & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import (
BaseSession,
MessageMiddleware,
RequestResponder,
)
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
Expand Down Expand Up @@ -91,8 +92,18 @@ def __init__(
write_stream: MemoryObjectSendStream[SessionMessage],
init_options: InitializationOptions,
stateless: bool = False,
*,
send_middleware: list["MessageMiddleware"] | None = None,
receive_middleware: list["MessageMiddleware"] | None = None,
) -> None:
super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
super().__init__(
read_stream,
write_stream,
types.ClientRequest,
types.ClientNotification,
send_middleware=send_middleware,
receive_middleware=receive_middleware,
)
self._initialization_state = (
InitializationState.Initialized if stateless else InitializationState.NotInitialized
)
Expand Down
50 changes: 44 additions & 6 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from collections.abc import Callable
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack
from datetime import timedelta
from types import TracebackType
Expand Down Expand Up @@ -43,6 +43,10 @@

RequestId = str | int

# Middleware type for transforming messages before sending or after receiving.
# Can be sync (returns JSONRPCMessage) or async (returns Awaitable[JSONRPCMessage]).
MessageMiddleware = Callable[[JSONRPCMessage], JSONRPCMessage | Awaitable[JSONRPCMessage]]


class ProgressFnT(Protocol):
"""Protocol for progress notification callbacks."""
Expand Down Expand Up @@ -190,6 +194,9 @@ def __init__(
receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
read_timeout_seconds: timedelta | None = None,
*,
send_middleware: list[MessageMiddleware] | None = None,
receive_middleware: list[MessageMiddleware] | None = None,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
Expand All @@ -202,6 +209,22 @@ def __init__(
self._progress_callbacks = {}
self._response_routers = []
self._exit_stack = AsyncExitStack()
self._send_middleware = send_middleware or []
self._receive_middleware = receive_middleware or []

async def _apply_middleware(
self, message: JSONRPCMessage, middleware_list: list[MessageMiddleware]
) -> JSONRPCMessage:
"""Apply a list of middleware functions to a message."""
import inspect
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move to top


for middleware in middleware_list:
result = middleware(message)
if inspect.isawaitable(result):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should inspect the function first instead of doing it on every single message

message = await result
else:
message = result # type: ignore[assignment]
return message

def add_response_router(self, router: ResponseRouter) -> None:
"""
Expand Down Expand Up @@ -278,7 +301,9 @@ async def send_request(
**request_data,
)

await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
message = JSONRPCMessage(jsonrpc_request)
message = await self._apply_middleware(message, self._send_middleware)
await self._write_stream.send(SessionMessage(message=message, metadata=metadata))

# request read timeout takes precedence over session read timeout
timeout = None
Expand Down Expand Up @@ -328,24 +353,30 @@ async def send_notification(
jsonrpc="2.0",
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
)
message = JSONRPCMessage(jsonrpc_notification)
message = await self._apply_middleware(message, self._send_middleware)
session_message = SessionMessage( # pragma: no cover
message=JSONRPCMessage(jsonrpc_notification),
message=message,
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
)
await self._write_stream.send(session_message)

async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
message = JSONRPCMessage(jsonrpc_error)
message = await self._apply_middleware(message, self._send_middleware)
session_message = SessionMessage(message=message)
await self._write_stream.send(session_message)
else:
jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0",
id=request_id,
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
message = JSONRPCMessage(jsonrpc_response)
message = await self._apply_middleware(message, self._send_middleware)
session_message = SessionMessage(message=message)
await self._write_stream.send(session_message)

async def _receive_loop(self) -> None:
Expand All @@ -357,7 +388,14 @@ async def _receive_loop(self) -> None:
async for message in self._read_stream:
if isinstance(message, Exception): # pragma: no cover
await self._handle_incoming(message)
elif isinstance(message.message.root, JSONRPCRequest):
continue

# Apply receive middleware to transform the message
if self._receive_middleware:
transformed_msg = await self._apply_middleware(message.message, self._receive_middleware)
message = SessionMessage(message=transformed_msg, metadata=message.metadata) # noqa: PLW2901

if isinstance(message.message.root, JSONRPCRequest):
try:
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
Expand Down
141 changes: 141 additions & 0 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,3 +768,144 @@ async def mock_server():
await session.initialize()

await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta)


@pytest.mark.anyio
async def test_client_session_send_middleware():
"""Test that send middleware can transform outgoing messages."""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

received_request = None
middleware_called = False

def add_custom_field(message: JSONRPCMessage) -> JSONRPCMessage:
"""Middleware that adds a custom field to initialize request params."""
nonlocal middleware_called
middleware_called = True

if isinstance(message.root, JSONRPCRequest):
# Add custom extension to the capabilities
data = message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
if data.get("method") == "initialize" and "params" in data:
if "capabilities" not in data["params"]:
data["params"]["capabilities"] = {}
# Add a custom extension field
data["params"]["capabilities"]["customExtension"] = {"enabled": True}
return JSONRPCMessage(JSONRPCRequest(**data))
return message

async def mock_server():
nonlocal received_request

session_message = await client_to_server_receive.receive()
jsonrpc_request = session_message.message
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
received_request = jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)

result = ServerResult(
InitializeResult(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ServerCapabilities(),
serverInfo=Implementation(name="mock-server", version="0.1.0"),
)
)

async with server_to_client_send:
await server_to_client_send.send(
SessionMessage(
JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
)
)
)
)
# Receive initialized notification
await client_to_server_receive.receive()

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
send_middleware=[add_custom_field],
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)
await session.initialize()

# Verify middleware was called and transformed the request
assert middleware_called
assert received_request is not None
assert "params" in received_request
assert "capabilities" in received_request["params"]
assert "customExtension" in received_request["params"]["capabilities"]
assert received_request["params"]["capabilities"]["customExtension"] == {"enabled": True}


@pytest.mark.anyio
async def test_client_session_async_middleware():
"""Test that async middleware works correctly."""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

middleware_called = False

async def async_middleware(message: JSONRPCMessage) -> JSONRPCMessage:
"""Async middleware that just passes through."""
nonlocal middleware_called
middleware_called = True
# Simulate some async work
await anyio.sleep(0)
return message

async def mock_server():
session_message = await client_to_server_receive.receive()
jsonrpc_request = session_message.message
assert isinstance(jsonrpc_request.root, JSONRPCRequest)

result = ServerResult(
InitializeResult(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ServerCapabilities(),
serverInfo=Implementation(name="mock-server", version="0.1.0"),
)
)

async with server_to_client_send:
await server_to_client_send.send(
SessionMessage(
JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
)
)
)
)
await client_to_server_receive.receive()

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
send_middleware=[async_middleware],
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)
await session.initialize()

assert middleware_called
Loading