Skip to content

Commit 4e761db

Browse files
committed
Add support for message filters
This commit adds support for server-side message filters that can intercept and process incoming and outgoing JSON-RPC messages. It also includes: - Message filter infrastructure (McpMessageFilter, McpMessageHandler) - MessageContext for accessing message metadata during filter execution - Extended RequestContext with message filter support - McpServerBuilderExtensions for registering message filters - Comprehensive tests for message filter functionality
1 parent 84dcfbc commit 4e761db

File tree

14 files changed

+867
-100
lines changed

14 files changed

+867
-100
lines changed

src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,6 @@ private async Task<string> GetAccessTokenAsync(HttpResponseMessage response, boo
260260
// Get auth server metadata
261261
var authServerMetadata = await GetAuthServerMetadataAsync(selectedAuthServer, cancellationToken).ConfigureAwait(false);
262262

263-
// Store auth server metadata for future refresh operations
264-
_authServerMetadata = authServerMetadata;
265-
266263
// The existing access token must be invalid to have resulted in a 401 response, but refresh might still work.
267264
var resourceUri = GetRequiredResourceUri(protectedResourceMetadata);
268265

@@ -296,6 +293,9 @@ await _tokenCache.GetTokensAsync(cancellationToken).ConfigureAwait(false) is { R
296293
}
297294
}
298295

296+
// Store auth server metadata for future refresh operations
297+
_authServerMetadata = authServerMetadata;
298+
299299
// Perform the OAuth flow
300300
return await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false);
301301
}

src/ModelContextProtocol.Core/Client/McpClientImpl.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,15 @@ internal McpClientImpl(ITransport transport, string endpointName, McpClientOptio
5858

5959
RegisterHandlers(options, notificationHandlers, requestHandlers);
6060

61-
_sessionHandler = new McpSessionHandler(isServer: false, transport, endpointName, requestHandlers, notificationHandlers, _logger);
61+
_sessionHandler = new McpSessionHandler(
62+
isServer: false,
63+
transport,
64+
endpointName,
65+
requestHandlers,
66+
notificationHandlers,
67+
incomingMessageFilter: null,
68+
outgoingMessageFilter: null,
69+
_logger);
6270
}
6371

6472
private void RegisterHandlers(McpClientOptions options, NotificationHandlers notificationHandlers, RequestHandlers requestHandlers)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using ModelContextProtocol.Protocol;
2+
3+
namespace ModelContextProtocol;
4+
5+
/// <summary>
6+
/// Represents a filter that wraps the processing of incoming JSON-RPC messages.
7+
/// </summary>
8+
/// <param name="next">The next handler in the pipeline.</param>
9+
/// <returns>A wrapped handler that processes messages and optionally delegates to the next handler.</returns>
10+
internal delegate Func<JsonRpcMessage, CancellationToken, Task> JsonRpcMessageFilter(Func<JsonRpcMessage, CancellationToken, Task> next);

src/ModelContextProtocol.Core/McpSessionHandler.cs

Lines changed: 74 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ internal static bool SupportsPrimingEvent(string? protocolVersion)
6666
private readonly ITransport _transport;
6767
private readonly RequestHandlers _requestHandlers;
6868
private readonly NotificationHandlers _notificationHandlers;
69+
private readonly JsonRpcMessageFilter _incomingMessageFilter;
70+
private readonly JsonRpcMessageFilter _outgoingMessageFilter;
6971
private readonly long _sessionStartingTimestamp = Stopwatch.GetTimestamp();
7072

7173
private readonly DistributedContextPropagator _propagator = DistributedContextPropagator.Current;
@@ -95,13 +97,17 @@ internal static bool SupportsPrimingEvent(string? protocolVersion)
9597
/// <param name="endpointName">The name of the endpoint for logging and debug purposes.</param>
9698
/// <param name="requestHandlers">A collection of request handlers.</param>
9799
/// <param name="notificationHandlers">A collection of notification handlers.</param>
100+
/// <param name="incomingMessageFilter">A filter that wraps incoming message processing. Takes the next handler and returns a wrapped handler. If null, a passthrough filter is used.</param>
101+
/// <param name="outgoingMessageFilter">A filter that wraps outgoing message processing. Takes the next handler and returns a wrapped handler. If null, a passthrough filter is used.</param>
98102
/// <param name="logger">The logger.</param>
99103
public McpSessionHandler(
100104
bool isServer,
101105
ITransport transport,
102106
string endpointName,
103107
RequestHandlers requestHandlers,
104108
NotificationHandlers notificationHandlers,
109+
JsonRpcMessageFilter? incomingMessageFilter,
110+
JsonRpcMessageFilter? outgoingMessageFilter,
105111
ILogger logger)
106112
{
107113
Throw.IfNull(transport);
@@ -120,7 +126,9 @@ public McpSessionHandler(
120126
EndpointName = endpointName;
121127
_requestHandlers = requestHandlers;
122128
_notificationHandlers = notificationHandlers;
123-
_logger = logger ?? NullLogger.Instance;
129+
_incomingMessageFilter = incomingMessageFilter ?? (next => next);
130+
_outgoingMessageFilter = outgoingMessageFilter ?? (next => next);
131+
_logger = logger;
124132
LogSessionCreated(EndpointName, _sessionId, _transportKind);
125133
}
126134

@@ -309,36 +317,16 @@ private async Task HandleMessageAsync(JsonRpcMessage message, CancellationToken
309317
AddTags(ref tags, activity, message, method, target);
310318
}
311319

312-
switch (message)
320+
var filteredHandler = _incomingMessageFilter(async (msg, ct) =>
313321
{
314-
case JsonRpcRequest request:
315-
LogRequestHandlerCalled(EndpointName, request.Method);
316-
long requestStartingTimestamp = Stopwatch.GetTimestamp();
317-
try
318-
{
319-
var result = await HandleRequest(request, cancellationToken).ConfigureAwait(false);
320-
LogRequestHandlerCompleted(EndpointName, request.Method, GetElapsed(requestStartingTimestamp).TotalMilliseconds);
321-
AddResponseTags(ref tags, activity, result, method);
322-
}
323-
catch (Exception ex)
324-
{
325-
LogRequestHandlerException(EndpointName, request.Method, GetElapsed(requestStartingTimestamp).TotalMilliseconds, ex);
326-
throw;
327-
}
328-
break;
329-
330-
case JsonRpcNotification notification:
331-
await HandleNotification(notification, cancellationToken).ConfigureAwait(false);
332-
break;
333-
334-
case JsonRpcMessageWithId messageWithId:
335-
HandleMessageWithId(message, messageWithId);
336-
break;
322+
var result = await HandleMessageCoreAsync(msg, ct).ConfigureAwait(false);
323+
if (addTags && result is not null)
324+
{
325+
AddResponseTags(ref tags, activity, result, method);
326+
}
327+
});
337328

338-
default:
339-
LogEndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name);
340-
break;
341-
}
329+
await filteredHandler(message, cancellationToken).ConfigureAwait(false);
342330
}
343331
catch (Exception e) when (addTags)
344332
{
@@ -351,7 +339,40 @@ private async Task HandleMessageAsync(JsonRpcMessage message, CancellationToken
351339
}
352340
}
353341

354-
private async Task HandleNotification(JsonRpcNotification notification, CancellationToken cancellationToken)
342+
private async Task<JsonNode?> HandleMessageCoreAsync(JsonRpcMessage message, CancellationToken cancellationToken)
343+
{
344+
switch (message)
345+
{
346+
case JsonRpcRequest request:
347+
LogRequestHandlerCalled(EndpointName, request.Method);
348+
long requestStartingTimestamp = Stopwatch.GetTimestamp();
349+
try
350+
{
351+
var result = await HandleRequestAsync(request, cancellationToken).ConfigureAwait(false);
352+
LogRequestHandlerCompleted(EndpointName, request.Method, GetElapsed(requestStartingTimestamp).TotalMilliseconds);
353+
return result;
354+
}
355+
catch (Exception ex)
356+
{
357+
LogRequestHandlerException(EndpointName, request.Method, GetElapsed(requestStartingTimestamp).TotalMilliseconds, ex);
358+
throw;
359+
}
360+
361+
case JsonRpcNotification notification:
362+
await HandleNotificationAsync(notification, cancellationToken).ConfigureAwait(false);
363+
return null;
364+
365+
case JsonRpcMessageWithId messageWithId:
366+
HandleMessageWithId(message, messageWithId);
367+
return null;
368+
369+
default:
370+
LogEndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name);
371+
return null;
372+
}
373+
}
374+
375+
private async Task HandleNotificationAsync(JsonRpcNotification notification, CancellationToken cancellationToken)
355376
{
356377
// Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.)
357378
if (notification.Method == NotificationMethods.CancelledNotification)
@@ -387,7 +408,7 @@ private void HandleMessageWithId(JsonRpcMessage message, JsonRpcMessageWithId me
387408
}
388409
}
389410

390-
private async Task<JsonNode?> HandleRequest(JsonRpcRequest request, CancellationToken cancellationToken)
411+
private async Task<JsonNode?> HandleRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken)
391412
{
392413
if (!_requestHandlers.TryGetValue(request.Method, out var handler))
393414
{
@@ -586,26 +607,31 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can
586607
AddTags(ref tags, activity, message, method, target);
587608
}
588609

589-
if (_logger.IsEnabled(LogLevel.Trace))
610+
var filteredHandler = _outgoingMessageFilter(async (msg, ct) =>
590611
{
591-
LogSendingMessageSensitive(EndpointName, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage));
592-
}
593-
else
594-
{
595-
LogSendingMessage(EndpointName);
596-
}
612+
if (_logger.IsEnabled(LogLevel.Trace))
613+
{
614+
LogSendingMessageSensitive(EndpointName, JsonSerializer.Serialize(msg, McpJsonUtilities.JsonContext.Default.JsonRpcMessage));
615+
}
616+
else
617+
{
618+
LogSendingMessage(EndpointName);
619+
}
597620

598-
await SendToRelatedTransportAsync(message, cancellationToken).ConfigureAwait(false);
621+
await SendToRelatedTransportAsync(msg, ct).ConfigureAwait(false);
599622

600-
// If the sent notification was a cancellation notification, cancel the pending request's await, as either the
601-
// server won't be sending a response, or per the specification, the response should be ignored. There are inherent
602-
// race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
603-
if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification &&
604-
GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn &&
605-
_pendingRequests.TryRemove(cn.RequestId, out var tcs))
606-
{
607-
tcs.TrySetCanceled(default);
608-
}
623+
// If the sent notification was a cancellation notification, cancel the pending request's await, as either the
624+
// server won't be sending a response, or per the specification, the response should be ignored. There are inherent
625+
// race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
626+
if (msg is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification &&
627+
GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn &&
628+
_pendingRequests.TryRemove(cn.RequestId, out var tcs))
629+
{
630+
tcs.TrySetCanceled(default);
631+
}
632+
});
633+
634+
await filteredHandler(message, cancellationToken).ConfigureAwait(false);
609635
}
610636
catch (Exception ex) when (addTags)
611637
{
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
namespace ModelContextProtocol.Server;
2+
3+
/// <summary>
4+
/// Delegate type for applying filters to JSON-RPC messages.
5+
/// </summary>
6+
/// <param name="next">The next message handler in the pipeline.</param>
7+
/// <returns>The next message handler wrapped with the filter.</returns>
8+
/// <remarks>
9+
/// <para>
10+
/// Message filters allow you to intercept and process JSON-RPC messages before they reach
11+
/// their respective handlers (incoming) or before they are sent (outgoing). This is useful for implementing
12+
/// cross-cutting concerns that need to apply to all message types, such as logging, authentication, rate limiting,
13+
/// redaction, or request tracing.
14+
/// </para>
15+
/// <para>
16+
/// Filters are applied in the order they are registered, with the first registered filter being the outermost.
17+
/// Each filter receives the next handler in the pipeline and can choose to:
18+
/// <list type="bullet">
19+
/// <item><description>Call the next handler to continue processing (await next(context, cancellationToken))</description></item>
20+
/// <item><description>Skip the default handlers entirely by not calling next</description></item>
21+
/// <item><description>Perform operations before and/or after calling next</description></item>
22+
/// <item><description>Catch and handle exceptions from inner handlers</description></item>
23+
/// </list>
24+
/// </para>
25+
/// <para>
26+
/// For request-specific filters, use <see cref="McpRequestFilter{TParams, TResult}"/> instead.
27+
/// </para>
28+
/// </remarks>
29+
public delegate McpMessageHandler McpMessageFilter(
30+
McpMessageHandler next);
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
namespace ModelContextProtocol.Server;
2+
3+
/// <summary>
4+
/// Delegate type for handling incoming JSON-RPC messages.
5+
/// </summary>
6+
/// <param name="context">The message context containing the JSON-RPC message and other metadata.</param>
7+
/// <param name="cancellationToken">A cancellation token to cancel the operation.</param>
8+
/// <returns>A task representing the asynchronous operation.</returns>
9+
/// <remarks>
10+
/// <para>
11+
/// This delegate can handle any type of JSON-RPC message, including requests, notifications, responses, and errors.
12+
/// Use this for implementing cross-cutting concerns that need to intercept all message types,
13+
/// such as logging, authentication, rate limiting, or request tracing.
14+
/// </para>
15+
/// <para>
16+
/// For request-specific handling, use <see cref="McpRequestHandler{TParams, TResult}"/> instead.
17+
/// </para>
18+
/// </remarks>
19+
public delegate Task McpMessageHandler(
20+
MessageContext context,
21+
CancellationToken cancellationToken);

src/ModelContextProtocol.Core/Server/McpServerFilters.cs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,39 @@ namespace ModelContextProtocol.Server;
1111
/// </remarks>
1212
public sealed class McpServerFilters
1313
{
14+
/// <summary>
15+
/// Gets the filters for all incoming JSON-RPC messages.
16+
/// </summary>
17+
/// <remarks>
18+
/// <para>
19+
/// These filters intercept all incoming JSON-RPC messages before they are processed by the server,
20+
/// including requests, notifications, responses, and errors. The filters can perform logging,
21+
/// authentication, rate limiting, or other cross-cutting concerns that apply to all message types.
22+
/// </para>
23+
/// <para>
24+
/// Message filters are applied before request-specific filters. If a message filter does not call
25+
/// the next handler in the pipeline, the default handlers will not be executed.
26+
/// </para>
27+
/// </remarks>
28+
public List<McpMessageFilter> IncomingMessageFilters { get; } = [];
29+
30+
/// <summary>
31+
/// Gets the filters for all outgoing JSON-RPC messages.
32+
/// </summary>
33+
/// <remarks>
34+
/// <para>
35+
/// These filters intercept all outgoing JSON-RPC messages before they are sent to the client,
36+
/// including responses, notifications, and errors. The filters can perform logging,
37+
/// redaction, auditing, or other cross-cutting concerns that apply to all message types.
38+
/// </para>
39+
/// <para>
40+
/// If a message filter does not call the next handler in the pipeline, the message will not be sent.
41+
/// Filters may also call the next handler multiple times with different messages to emit additional
42+
/// server-to-client messages.
43+
/// </para>
44+
/// </remarks>
45+
public List<McpMessageFilter> OutgoingMessageFilters { get; } = [];
46+
1447
/// <summary>
1548
/// Gets the filters for the list-tools handler pipeline.
1649
/// </summary>

src/ModelContextProtocol.Core/Server/McpServerImpl.cs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,17 @@ void Register<TPrimitive>(McpServerPrimitiveCollection<TPrimitive>? collection,
118118
}
119119

120120
// And initialize the session.
121-
_sessionHandler = new McpSessionHandler(isServer: true, _sessionTransport, _endpointName!, _requestHandlers, _notificationHandlers, _logger);
121+
var incomingMessageFilter = BuildMessageFilterPipeline(options.Filters.IncomingMessageFilters);
122+
var outgoingMessageFilter = BuildMessageFilterPipeline(options.Filters.OutgoingMessageFilters);
123+
_sessionHandler = new McpSessionHandler(
124+
isServer: true,
125+
_sessionTransport,
126+
_endpointName!,
127+
_requestHandlers,
128+
_notificationHandlers,
129+
incomingMessageFilter,
130+
outgoingMessageFilter,
131+
_logger);
122132
}
123133

124134
/// <inheritdoc/>
@@ -875,6 +885,37 @@ private static McpRequestHandler<TParams, TResult> BuildFilterPipeline<TParams,
875885
return current;
876886
}
877887

888+
private JsonRpcMessageFilter BuildMessageFilterPipeline(List<McpMessageFilter> filters)
889+
{
890+
if (filters.Count == 0)
891+
{
892+
return next => next;
893+
}
894+
895+
return next =>
896+
{
897+
// Build the handler chain from the filters.
898+
// The innermost handler calls the provided 'next' delegate with the message from the context.
899+
McpMessageHandler baseHandler = async (context, cancellationToken) =>
900+
{
901+
await next(context.JsonRpcMessage, cancellationToken).ConfigureAwait(false);
902+
};
903+
904+
var current = baseHandler;
905+
for (int i = filters.Count - 1; i >= 0; i--)
906+
{
907+
current = filters[i](current);
908+
}
909+
910+
// Return the handler that creates a MessageContext and invokes the pipeline.
911+
return async (message, cancellationToken) =>
912+
{
913+
var context = new MessageContext(new DestinationBoundMcpServer(this, message.Context?.RelatedTransport), message);
914+
await current(context, cancellationToken).ConfigureAwait(false);
915+
};
916+
};
917+
}
918+
878919
private void UpdateEndpointNameWithClientInfo()
879920
{
880921
if (ClientInfo is null)

0 commit comments

Comments
 (0)