@@ -66,6 +66,8 @@ internal static bool SupportsPrimingEvent(string? protocolVersion)
6666 private readonly ITransport _transport ;
6767 private readonly RequestHandlers _requestHandlers ;
6868 private readonly NotificationHandlers _notificationHandlers ;
69+ private readonly JsonRpcMessageFilter _incomingMessageFilter ;
70+ private readonly JsonRpcMessageFilter _outgoingMessageFilter ;
6971 private readonly long _sessionStartingTimestamp = Stopwatch . GetTimestamp ( ) ;
7072
7173 private readonly DistributedContextPropagator _propagator = DistributedContextPropagator . Current ;
@@ -95,13 +97,17 @@ internal static bool SupportsPrimingEvent(string? protocolVersion)
9597 /// <param name="endpointName">The name of the endpoint for logging and debug purposes.</param>
9698 /// <param name="requestHandlers">A collection of request handlers.</param>
9799 /// <param name="notificationHandlers">A collection of notification handlers.</param>
100+ /// <param name="incomingMessageFilter">A filter that wraps incoming message processing. Takes the next handler and returns a wrapped handler. If null, a passthrough filter is used.</param>
101+ /// <param name="outgoingMessageFilter">A filter that wraps outgoing message processing. Takes the next handler and returns a wrapped handler. If null, a passthrough filter is used.</param>
98102 /// <param name="logger">The logger.</param>
99103 public McpSessionHandler (
100104 bool isServer ,
101105 ITransport transport ,
102106 string endpointName ,
103107 RequestHandlers requestHandlers ,
104108 NotificationHandlers notificationHandlers ,
109+ JsonRpcMessageFilter ? incomingMessageFilter ,
110+ JsonRpcMessageFilter ? outgoingMessageFilter ,
105111 ILogger logger )
106112 {
107113 Throw . IfNull ( transport ) ;
@@ -120,7 +126,9 @@ public McpSessionHandler(
120126 EndpointName = endpointName ;
121127 _requestHandlers = requestHandlers ;
122128 _notificationHandlers = notificationHandlers ;
123- _logger = logger ?? NullLogger . Instance ;
129+ _incomingMessageFilter = incomingMessageFilter ?? ( next => next ) ;
130+ _outgoingMessageFilter = outgoingMessageFilter ?? ( next => next ) ;
131+ _logger = logger ;
124132 LogSessionCreated ( EndpointName , _sessionId , _transportKind ) ;
125133 }
126134
@@ -309,36 +317,16 @@ private async Task HandleMessageAsync(JsonRpcMessage message, CancellationToken
309317 AddTags ( ref tags , activity , message , method , target ) ;
310318 }
311319
312- switch ( message )
320+ var filteredHandler = _incomingMessageFilter ( async ( msg , ct ) =>
313321 {
314- case JsonRpcRequest request :
315- LogRequestHandlerCalled ( EndpointName , request . Method ) ;
316- long requestStartingTimestamp = Stopwatch . GetTimestamp ( ) ;
317- try
318- {
319- var result = await HandleRequest ( request , cancellationToken ) . ConfigureAwait ( false ) ;
320- LogRequestHandlerCompleted ( EndpointName , request . Method , GetElapsed ( requestStartingTimestamp ) . TotalMilliseconds ) ;
321- AddResponseTags ( ref tags , activity , result , method ) ;
322- }
323- catch ( Exception ex )
324- {
325- LogRequestHandlerException ( EndpointName , request . Method , GetElapsed ( requestStartingTimestamp ) . TotalMilliseconds , ex ) ;
326- throw ;
327- }
328- break ;
329-
330- case JsonRpcNotification notification :
331- await HandleNotification ( notification , cancellationToken ) . ConfigureAwait ( false ) ;
332- break ;
333-
334- case JsonRpcMessageWithId messageWithId :
335- HandleMessageWithId ( message , messageWithId ) ;
336- break ;
322+ var result = await HandleMessageCoreAsync ( msg , ct ) . ConfigureAwait ( false ) ;
323+ if ( addTags && result is not null )
324+ {
325+ AddResponseTags ( ref tags , activity , result , method ) ;
326+ }
327+ } ) ;
337328
338- default :
339- LogEndpointHandlerUnexpectedMessageType ( EndpointName , message . GetType ( ) . Name ) ;
340- break ;
341- }
329+ await filteredHandler ( message , cancellationToken ) . ConfigureAwait ( false ) ;
342330 }
343331 catch ( Exception e ) when ( addTags )
344332 {
@@ -351,7 +339,40 @@ private async Task HandleMessageAsync(JsonRpcMessage message, CancellationToken
351339 }
352340 }
353341
354- private async Task HandleNotification ( JsonRpcNotification notification , CancellationToken cancellationToken )
342+ private async Task < JsonNode ? > HandleMessageCoreAsync ( JsonRpcMessage message , CancellationToken cancellationToken )
343+ {
344+ switch ( message )
345+ {
346+ case JsonRpcRequest request :
347+ LogRequestHandlerCalled ( EndpointName , request . Method ) ;
348+ long requestStartingTimestamp = Stopwatch . GetTimestamp ( ) ;
349+ try
350+ {
351+ var result = await HandleRequestAsync ( request , cancellationToken ) . ConfigureAwait ( false ) ;
352+ LogRequestHandlerCompleted ( EndpointName , request . Method , GetElapsed ( requestStartingTimestamp ) . TotalMilliseconds ) ;
353+ return result ;
354+ }
355+ catch ( Exception ex )
356+ {
357+ LogRequestHandlerException ( EndpointName , request . Method , GetElapsed ( requestStartingTimestamp ) . TotalMilliseconds , ex ) ;
358+ throw ;
359+ }
360+
361+ case JsonRpcNotification notification :
362+ await HandleNotificationAsync ( notification , cancellationToken ) . ConfigureAwait ( false ) ;
363+ return null ;
364+
365+ case JsonRpcMessageWithId messageWithId :
366+ HandleMessageWithId ( message , messageWithId ) ;
367+ return null ;
368+
369+ default :
370+ LogEndpointHandlerUnexpectedMessageType ( EndpointName , message . GetType ( ) . Name ) ;
371+ return null ;
372+ }
373+ }
374+
375+ private async Task HandleNotificationAsync ( JsonRpcNotification notification , CancellationToken cancellationToken )
355376 {
356377 // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.)
357378 if ( notification . Method == NotificationMethods . CancelledNotification )
@@ -387,7 +408,7 @@ private void HandleMessageWithId(JsonRpcMessage message, JsonRpcMessageWithId me
387408 }
388409 }
389410
390- private async Task < JsonNode ? > HandleRequest ( JsonRpcRequest request , CancellationToken cancellationToken )
411+ private async Task < JsonNode ? > HandleRequestAsync ( JsonRpcRequest request , CancellationToken cancellationToken )
391412 {
392413 if ( ! _requestHandlers . TryGetValue ( request . Method , out var handler ) )
393414 {
@@ -586,26 +607,31 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can
586607 AddTags ( ref tags , activity , message , method , target ) ;
587608 }
588609
589- if ( _logger . IsEnabled ( LogLevel . Trace ) )
610+ var filteredHandler = _outgoingMessageFilter ( async ( msg , ct ) =>
590611 {
591- LogSendingMessageSensitive ( EndpointName , JsonSerializer . Serialize ( message , McpJsonUtilities . JsonContext . Default . JsonRpcMessage ) ) ;
592- }
593- else
594- {
595- LogSendingMessage ( EndpointName ) ;
596- }
612+ if ( _logger . IsEnabled ( LogLevel . Trace ) )
613+ {
614+ LogSendingMessageSensitive ( EndpointName , JsonSerializer . Serialize ( msg , McpJsonUtilities . JsonContext . Default . JsonRpcMessage ) ) ;
615+ }
616+ else
617+ {
618+ LogSendingMessage ( EndpointName ) ;
619+ }
597620
598- await SendToRelatedTransportAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
621+ await SendToRelatedTransportAsync ( msg , ct ) . ConfigureAwait ( false ) ;
599622
600- // If the sent notification was a cancellation notification, cancel the pending request's await, as either the
601- // server won't be sending a response, or per the specification, the response should be ignored. There are inherent
602- // race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
603- if ( message is JsonRpcNotification { Method : NotificationMethods . CancelledNotification } notification &&
604- GetCancelledNotificationParams ( notification . Params ) is CancelledNotificationParams cn &&
605- _pendingRequests . TryRemove ( cn . RequestId , out var tcs ) )
606- {
607- tcs . TrySetCanceled ( default ) ;
608- }
623+ // If the sent notification was a cancellation notification, cancel the pending request's await, as either the
624+ // server won't be sending a response, or per the specification, the response should be ignored. There are inherent
625+ // race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
626+ if ( msg is JsonRpcNotification { Method : NotificationMethods . CancelledNotification } notification &&
627+ GetCancelledNotificationParams ( notification . Params ) is CancelledNotificationParams cn &&
628+ _pendingRequests . TryRemove ( cn . RequestId , out var tcs ) )
629+ {
630+ tcs . TrySetCanceled ( default ) ;
631+ }
632+ } ) ;
633+
634+ await filteredHandler ( message , cancellationToken ) . ConfigureAwait ( false ) ;
609635 }
610636 catch ( Exception ex ) when ( addTags )
611637 {
0 commit comments