Skip to content

Commit afc40d1

Browse files
authored
Route SendRequestAsync logic through outgoing message filters (#1465)
1 parent e700951 commit afc40d1

File tree

3 files changed

+393
-40
lines changed

3 files changed

+393
-40
lines changed

src/ModelContextProtocol.Core/McpSessionHandler.cs

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -576,15 +576,6 @@ public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, Canc
576576
AddTags(ref tags, activity, request, method, target);
577577
}
578578

579-
if (_logger.IsEnabled(LogLevel.Trace))
580-
{
581-
LogSendingRequestSensitive(EndpointName, request.Method, JsonSerializer.Serialize(request, McpJsonUtilities.JsonContext.Default.JsonRpcMessage));
582-
}
583-
else
584-
{
585-
LogSendingRequest(EndpointName, request.Method);
586-
}
587-
588579
await SendToRelatedTransportAsync(request, cancellationToken).ConfigureAwait(false);
589580

590581
// Now that the request has been sent, register for cancellation. If we registered before,
@@ -671,29 +662,17 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can
671662
AddTags(ref tags, activity, message, method, target);
672663
}
673664

674-
await _outgoingMessageFilter(async (msg, ct) =>
675-
{
676-
if (_logger.IsEnabled(LogLevel.Trace))
677-
{
678-
LogSendingMessageSensitive(EndpointName, JsonSerializer.Serialize(msg, McpJsonUtilities.JsonContext.Default.JsonRpcMessage));
679-
}
680-
else
681-
{
682-
LogSendingMessage(EndpointName);
683-
}
684-
685-
await SendToRelatedTransportAsync(msg, ct).ConfigureAwait(false);
665+
await SendToRelatedTransportAsync(message, cancellationToken).ConfigureAwait(false);
686666

687-
// If the sent notification was a cancellation notification, cancel the pending request's await, as either the
688-
// server won't be sending a response, or per the specification, the response should be ignored. There are inherent
689-
// race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
690-
if (msg is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification &&
691-
GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn &&
692-
_pendingRequests.TryRemove(cn.RequestId, out var tcs))
693-
{
694-
tcs.TrySetCanceled(default);
695-
}
696-
})(message, cancellationToken).ConfigureAwait(false);
667+
// If the sent notification was a cancellation notification, cancel the pending request's await, as either the
668+
// server won't be sending a response, or per the specification, the response should be ignored. There are inherent
669+
// race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
670+
if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification &&
671+
GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn &&
672+
_pendingRequests.TryRemove(cn.RequestId, out var tcs))
673+
{
674+
tcs.TrySetCanceled(default);
675+
}
697676
}
698677
catch (Exception ex) when (addTags)
699678
{
@@ -710,7 +689,33 @@ await _outgoingMessageFilter(async (msg, ct) =>
710689
// Streamable HTTP transport where the specification states that the server SHOULD include JSON-RPC responses in
711690
// the HTTP response body for the POST request containing the corresponding JSON-RPC request.
712691
private Task SendToRelatedTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken)
713-
=> (message.Context?.RelatedTransport ?? _transport).SendMessageAsync(message, cancellationToken);
692+
=> _outgoingMessageFilter((msg, ct) =>
693+
{
694+
if (msg is JsonRpcRequest request)
695+
{
696+
if (_logger.IsEnabled(LogLevel.Trace))
697+
{
698+
LogSendingRequestSensitive(EndpointName, request.Method, JsonSerializer.Serialize(msg, McpJsonUtilities.JsonContext.Default.JsonRpcMessage));
699+
}
700+
else
701+
{
702+
LogSendingRequest(EndpointName, request.Method);
703+
}
704+
}
705+
else
706+
{
707+
if (_logger.IsEnabled(LogLevel.Trace))
708+
{
709+
LogSendingMessageSensitive(EndpointName, JsonSerializer.Serialize(msg, McpJsonUtilities.JsonContext.Default.JsonRpcMessage));
710+
}
711+
else
712+
{
713+
LogSendingMessage(EndpointName);
714+
}
715+
}
716+
717+
return (msg.Context?.RelatedTransport ?? _transport).SendMessageAsync(msg, ct);
718+
})(message, cancellationToken);
714719

715720
private static CancelledNotificationParams? GetCancelledNotificationParams(JsonNode? notificationParams)
716721
{

tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using System.Diagnostics;
1010
using System.Net;
1111
using System.Security.Claims;
12+
using System.Text.Json.Nodes;
1213

1314
namespace ModelContextProtocol.AspNetCore.Tests;
1415

@@ -290,6 +291,200 @@ public async Task LongRunningToolCall_DoesNotTimeout_WhenNoEventStreamStore()
290291

291292
}
292293

294+
[Fact]
295+
public async Task IncomingFilter_SeesClientRequests()
296+
{
297+
var observedMethods = new List<string>();
298+
299+
Builder.Services.AddMcpServer()
300+
.WithHttpTransport(ConfigureStateless)
301+
.WithMessageFilters(filters => filters.AddIncomingFilter((next) => async (context, cancellationToken) =>
302+
{
303+
if (context.JsonRpcMessage is JsonRpcRequest request)
304+
{
305+
observedMethods.Add(request.Method);
306+
}
307+
308+
await next(context, cancellationToken);
309+
}))
310+
.WithTools<EchoHttpContextUserTools>();
311+
312+
Builder.Services.AddHttpContextAccessor();
313+
314+
await using var app = Builder.Build();
315+
app.MapMcp();
316+
await app.StartAsync(TestContext.Current.CancellationToken);
317+
318+
await using var client = await ConnectAsync();
319+
320+
await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
321+
await client.CallToolAsync("echo_with_user_name",
322+
new Dictionary<string, object?> { ["message"] = "hi" },
323+
cancellationToken: TestContext.Current.CancellationToken);
324+
325+
Assert.Contains(RequestMethods.Initialize, observedMethods);
326+
Assert.Contains(RequestMethods.ToolsList, observedMethods);
327+
Assert.Contains(RequestMethods.ToolsCall, observedMethods);
328+
}
329+
330+
[Fact]
331+
public async Task OutgoingFilter_SeesResponsesAndRequests()
332+
{
333+
Assert.SkipWhen(Stateless, "Server-originated requests are not supported in stateless mode.");
334+
335+
var observedMessageTypes = new List<string>();
336+
337+
Builder.Services.AddMcpServer()
338+
.WithHttpTransport(ConfigureStateless)
339+
.WithMessageFilters(filters => filters.AddOutgoingFilter((next) => async (context, cancellationToken) =>
340+
{
341+
var typeName = context.JsonRpcMessage switch
342+
{
343+
JsonRpcRequest request => $"request:{request.Method}",
344+
JsonRpcResponse r when r.Result is JsonObject obj && obj.ContainsKey("protocolVersion") => "initialize-response",
345+
JsonRpcResponse r when r.Result is JsonObject obj && obj.ContainsKey("tools") => "tools-list-response",
346+
JsonRpcResponse r when r.Result is JsonObject obj && obj.ContainsKey("content") => "tool-call-response",
347+
_ => null,
348+
};
349+
350+
if (typeName is not null)
351+
{
352+
observedMessageTypes.Add(typeName);
353+
}
354+
355+
await next(context, cancellationToken);
356+
}))
357+
.WithTools<ClaimsPrincipalTools>()
358+
.WithTools<SamplingRegressionTools>();
359+
360+
await using var app = Builder.Build();
361+
app.MapMcp();
362+
await app.StartAsync(TestContext.Current.CancellationToken);
363+
364+
var clientOptions = new McpClientOptions
365+
{
366+
Capabilities = new() { Sampling = new() },
367+
Handlers = new()
368+
{
369+
SamplingHandler = (_, _, _) => new(new CreateMessageResult
370+
{
371+
Content = [new TextContentBlock { Text = "sampled response" }],
372+
Model = "test-model",
373+
}),
374+
},
375+
};
376+
377+
await using var client = await ConnectAsync(clientOptions: clientOptions);
378+
379+
await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
380+
await client.CallToolAsync("echo_claims_principal",
381+
new Dictionary<string, object?> { ["message"] = "hi" },
382+
cancellationToken: TestContext.Current.CancellationToken);
383+
await client.CallToolAsync("sampling-tool",
384+
new Dictionary<string, object?> { ["prompt"] = "Hello" },
385+
cancellationToken: TestContext.Current.CancellationToken);
386+
387+
Assert.Contains("initialize-response", observedMessageTypes);
388+
Assert.Contains("tools-list-response", observedMessageTypes);
389+
Assert.Contains("tool-call-response", observedMessageTypes);
390+
Assert.Contains($"request:{RequestMethods.SamplingCreateMessage}", observedMessageTypes);
391+
}
392+
393+
[Fact]
394+
public async Task OutgoingFilter_MultipleFilters_ExecuteInOrder()
395+
{
396+
var executionOrder = new List<string>();
397+
398+
Builder.Services.AddMcpServer()
399+
.WithHttpTransport(ConfigureStateless)
400+
.WithMessageFilters(filters =>
401+
{
402+
filters.AddOutgoingFilter((next) => async (context, cancellationToken) =>
403+
{
404+
if (context.JsonRpcMessage is JsonRpcResponse r && r.Result is JsonObject obj && obj.ContainsKey("tools"))
405+
{
406+
executionOrder.Add("filter1-before");
407+
}
408+
409+
await next(context, cancellationToken);
410+
411+
if (context.JsonRpcMessage is JsonRpcResponse r2 && r2.Result is JsonObject obj2 && obj2.ContainsKey("tools"))
412+
{
413+
executionOrder.Add("filter1-after");
414+
}
415+
});
416+
417+
filters.AddOutgoingFilter((next) => async (context, cancellationToken) =>
418+
{
419+
if (context.JsonRpcMessage is JsonRpcResponse r && r.Result is JsonObject obj && obj.ContainsKey("tools"))
420+
{
421+
executionOrder.Add("filter2-before");
422+
}
423+
424+
await next(context, cancellationToken);
425+
426+
if (context.JsonRpcMessage is JsonRpcResponse r2 && r2.Result is JsonObject obj2 && obj2.ContainsKey("tools"))
427+
{
428+
executionOrder.Add("filter2-after");
429+
}
430+
});
431+
})
432+
.WithTools<ClaimsPrincipalTools>();
433+
434+
await using var app = Builder.Build();
435+
app.MapMcp();
436+
await app.StartAsync(TestContext.Current.CancellationToken);
437+
438+
await using var client = await ConnectAsync();
439+
440+
await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
441+
442+
Assert.Equal(["filter1-before", "filter2-before", "filter2-after", "filter1-after"], executionOrder);
443+
}
444+
445+
[Fact]
446+
public async Task OutgoingFilter_CanSendAdditionalMessages()
447+
{
448+
Builder.Services.AddMcpServer()
449+
.WithHttpTransport(ConfigureStateless)
450+
.WithMessageFilters(filters => filters.AddOutgoingFilter((next) => async (context, cancellationToken) =>
451+
{
452+
if (context.JsonRpcMessage is JsonRpcResponse response &&
453+
response.Result is JsonObject result && result.ContainsKey("tools"))
454+
{
455+
var extraNotification = new JsonRpcNotification
456+
{
457+
Method = "test/extra",
458+
Params = new JsonObject { ["message"] = "injected" },
459+
Context = new JsonRpcMessageContext { RelatedTransport = context.JsonRpcMessage.Context?.RelatedTransport },
460+
};
461+
462+
await next(new MessageContext(context.Server, extraNotification), cancellationToken);
463+
}
464+
465+
await next(context, cancellationToken);
466+
}))
467+
.WithTools<ClaimsPrincipalTools>();
468+
469+
await using var app = Builder.Build();
470+
app.MapMcp();
471+
await app.StartAsync(TestContext.Current.CancellationToken);
472+
473+
await using var client = await ConnectAsync();
474+
475+
var extraReceived = new TaskCompletionSource<string?>(TaskCreationOptions.RunContinuationsAsynchronously);
476+
await using var registration = client.RegisterNotificationHandler("test/extra", (notification, _) =>
477+
{
478+
extraReceived.TrySetResult(notification.Params?["message"]?.GetValue<string>());
479+
return default;
480+
});
481+
482+
await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
483+
484+
var extraMessage = await extraReceived.Task.WaitAsync(TimeSpan.FromSeconds(10), TestContext.Current.CancellationToken);
485+
Assert.Equal("injected", extraMessage);
486+
}
487+
293488
private ClaimsPrincipal CreateUser(string name)
294489
=> new(new ClaimsIdentity(
295490
[new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name)],

0 commit comments

Comments
 (0)