Skip to content

Commit 9745e92

Browse files
halter73Copilot
andcommitted
Route SendRequestAsync through outgoing filters and add message filter tests
Fix McpSessionHandler.SendRequestAsync to route through the outgoing message filter pipeline instead of calling SendToRelatedTransportAsync directly. This makes server-originated JSON-RPC requests (elicitation/create, sampling/createMessage, roots/list) visible to outgoing filters, matching the documented behavior that filters see all outgoing messages. Add AddOutgoingMessageFilter_Sees_ServerOriginatedRequests test to verify the fix catches the regression via stream transport. Add 5 message filter tests to MapMcpTests covering incoming filters, outgoing filters, server-originated requests, filter ordering, and additional message injection. These run across all HTTP transport variants (SSE, Streamable HTTP, and Stateless) via the existing MapMcpSseTests, MapMcpStreamableHttpTests, and MapMcpStatelessTests subclasses. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 51a4fde commit 9745e92

3 files changed

Lines changed: 279 additions & 6 deletions

File tree

src/ModelContextProtocol.Core/McpSessionHandler.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,10 @@ public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, Canc
585585
LogSendingRequest(EndpointName, request.Method);
586586
}
587587

588-
await SendToRelatedTransportAsync(request, cancellationToken).ConfigureAwait(false);
588+
await _outgoingMessageFilter(async (msg, ct) =>
589+
{
590+
await SendToRelatedTransportAsync(msg, ct).ConfigureAwait(false);
591+
})(request, cancellationToken).ConfigureAwait(false);
589592

590593
// Now that the request has been sent, register for cancellation. If we registered before,
591594
// a cancellation request could arrive before the server knew about that request ID, in which

tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using System.Diagnostics;
1010
using System.Net;
1111
using System.Security.Claims;
12+
using System.Text.Json.Nodes;
1213

1314
namespace ModelContextProtocol.AspNetCore.Tests;
1415

@@ -290,6 +291,225 @@ public async Task LongRunningToolCall_DoesNotTimeout_WhenNoEventStreamStore()
290291

291292
}
292293

294+
[Fact]
295+
public async Task IncomingFilter_SeesClientRequests()
296+
{
297+
var observedMethods = new List<string>();
298+
299+
Builder.Services.AddMcpServer()
300+
.WithHttpTransport(ConfigureStateless)
301+
.WithMessageFilters(filters => filters.AddIncomingFilter((next) => async (context, cancellationToken) =>
302+
{
303+
if (context.JsonRpcMessage is JsonRpcRequest request)
304+
{
305+
observedMethods.Add(request.Method);
306+
}
307+
308+
await next(context, cancellationToken);
309+
}))
310+
.WithTools<EchoHttpContextUserTools>();
311+
312+
Builder.Services.AddHttpContextAccessor();
313+
314+
await using var app = Builder.Build();
315+
app.MapMcp();
316+
await app.StartAsync(TestContext.Current.CancellationToken);
317+
318+
await using var client = await ConnectAsync();
319+
320+
await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
321+
await client.CallToolAsync("echo_with_user_name",
322+
new Dictionary<string, object?> { ["message"] = "hi" },
323+
cancellationToken: TestContext.Current.CancellationToken);
324+
325+
Assert.Contains(RequestMethods.Initialize, observedMethods);
326+
Assert.Contains(RequestMethods.ToolsList, observedMethods);
327+
Assert.Contains(RequestMethods.ToolsCall, observedMethods);
328+
}
329+
330+
[Fact]
331+
public async Task OutgoingFilter_SeesResponses()
332+
{
333+
var observedMessageTypes = new List<string>();
334+
335+
Builder.Services.AddMcpServer()
336+
.WithHttpTransport(ConfigureStateless)
337+
.WithMessageFilters(filters => filters.AddOutgoingFilter((next) => async (context, cancellationToken) =>
338+
{
339+
var typeName = context.JsonRpcMessage switch
340+
{
341+
JsonRpcResponse r when r.Result is JsonObject obj && obj.ContainsKey("protocolVersion") => "initialize-response",
342+
JsonRpcResponse r when r.Result is JsonObject obj && obj.ContainsKey("tools") => "tools-list-response",
343+
JsonRpcResponse r when r.Result is JsonObject obj && obj.ContainsKey("content") => "tool-call-response",
344+
_ => null,
345+
};
346+
347+
if (typeName is not null)
348+
{
349+
observedMessageTypes.Add(typeName);
350+
}
351+
352+
await next(context, cancellationToken);
353+
}))
354+
.WithTools<ClaimsPrincipalTools>();
355+
356+
await using var app = Builder.Build();
357+
app.MapMcp();
358+
await app.StartAsync(TestContext.Current.CancellationToken);
359+
360+
await using var client = await ConnectAsync();
361+
362+
await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
363+
await client.CallToolAsync("echo_claims_principal",
364+
new Dictionary<string, object?> { ["message"] = "hi" },
365+
cancellationToken: TestContext.Current.CancellationToken);
366+
367+
Assert.Contains("initialize-response", observedMessageTypes);
368+
Assert.Contains("tools-list-response", observedMessageTypes);
369+
Assert.Contains("tool-call-response", observedMessageTypes);
370+
}
371+
372+
[Fact]
373+
public async Task OutgoingFilter_SeesServerOriginatedRequests()
374+
{
375+
Assert.SkipWhen(Stateless, "Server-originated requests are not supported in stateless mode.");
376+
377+
var observedMethods = new List<string>();
378+
379+
Builder.Services.AddMcpServer()
380+
.WithHttpTransport(ConfigureStateless)
381+
.WithMessageFilters(filters => filters.AddOutgoingFilter((next) => async (context, cancellationToken) =>
382+
{
383+
if (context.JsonRpcMessage is JsonRpcRequest request)
384+
{
385+
observedMethods.Add(request.Method);
386+
}
387+
388+
await next(context, cancellationToken);
389+
}))
390+
.WithTools<SamplingRegressionTools>();
391+
392+
await using var app = Builder.Build();
393+
app.MapMcp();
394+
await app.StartAsync(TestContext.Current.CancellationToken);
395+
396+
var clientOptions = new McpClientOptions
397+
{
398+
Capabilities = new() { Sampling = new() },
399+
Handlers = new()
400+
{
401+
SamplingHandler = (_, _, _) => new(new CreateMessageResult
402+
{
403+
Content = [new TextContentBlock { Text = "sampled response" }],
404+
Model = "test-model",
405+
}),
406+
},
407+
};
408+
409+
await using var client = await ConnectAsync(clientOptions: clientOptions);
410+
411+
await client.CallToolAsync("sampling-tool",
412+
new Dictionary<string, object?> { ["prompt"] = "Hello" },
413+
cancellationToken: TestContext.Current.CancellationToken);
414+
415+
Assert.Contains(RequestMethods.SamplingCreateMessage, observedMethods);
416+
}
417+
418+
[Fact]
419+
public async Task OutgoingFilter_MultipleFilters_ExecuteInOrder()
420+
{
421+
var executionOrder = new List<string>();
422+
423+
Builder.Services.AddMcpServer()
424+
.WithHttpTransport(ConfigureStateless)
425+
.WithMessageFilters(filters =>
426+
{
427+
filters.AddOutgoingFilter((next) => async (context, cancellationToken) =>
428+
{
429+
if (context.JsonRpcMessage is JsonRpcResponse r && r.Result is JsonObject obj && obj.ContainsKey("tools"))
430+
{
431+
executionOrder.Add("filter1-before");
432+
}
433+
434+
await next(context, cancellationToken);
435+
436+
if (context.JsonRpcMessage is JsonRpcResponse r2 && r2.Result is JsonObject obj2 && obj2.ContainsKey("tools"))
437+
{
438+
executionOrder.Add("filter1-after");
439+
}
440+
});
441+
442+
filters.AddOutgoingFilter((next) => async (context, cancellationToken) =>
443+
{
444+
if (context.JsonRpcMessage is JsonRpcResponse r && r.Result is JsonObject obj && obj.ContainsKey("tools"))
445+
{
446+
executionOrder.Add("filter2-before");
447+
}
448+
449+
await next(context, cancellationToken);
450+
451+
if (context.JsonRpcMessage is JsonRpcResponse r2 && r2.Result is JsonObject obj2 && obj2.ContainsKey("tools"))
452+
{
453+
executionOrder.Add("filter2-after");
454+
}
455+
});
456+
})
457+
.WithTools<ClaimsPrincipalTools>();
458+
459+
await using var app = Builder.Build();
460+
app.MapMcp();
461+
await app.StartAsync(TestContext.Current.CancellationToken);
462+
463+
await using var client = await ConnectAsync();
464+
465+
await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
466+
467+
Assert.Equal(["filter1-before", "filter2-before", "filter2-after", "filter1-after"], executionOrder);
468+
}
469+
470+
[Fact]
471+
public async Task OutgoingFilter_CanSendAdditionalMessages()
472+
{
473+
Builder.Services.AddMcpServer()
474+
.WithHttpTransport(ConfigureStateless)
475+
.WithMessageFilters(filters => filters.AddOutgoingFilter((next) => async (context, cancellationToken) =>
476+
{
477+
if (context.JsonRpcMessage is JsonRpcResponse response &&
478+
response.Result is JsonObject result && result.ContainsKey("tools"))
479+
{
480+
var extraNotification = new JsonRpcNotification
481+
{
482+
Method = "test/extra",
483+
Params = new JsonObject { ["message"] = "injected" },
484+
Context = new JsonRpcMessageContext { RelatedTransport = context.JsonRpcMessage.Context?.RelatedTransport },
485+
};
486+
487+
await next(new MessageContext(context.Server, extraNotification), cancellationToken);
488+
}
489+
490+
await next(context, cancellationToken);
491+
}))
492+
.WithTools<ClaimsPrincipalTools>();
493+
494+
await using var app = Builder.Build();
495+
app.MapMcp();
496+
await app.StartAsync(TestContext.Current.CancellationToken);
497+
498+
await using var client = await ConnectAsync();
499+
500+
var extraReceived = new TaskCompletionSource<string?>(TaskCreationOptions.RunContinuationsAsynchronously);
501+
await using var registration = client.RegisterNotificationHandler("test/extra", (notification, _) =>
502+
{
503+
extraReceived.TrySetResult(notification.Params?["message"]?.GetValue<string>());
504+
return default;
505+
});
506+
507+
await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
508+
509+
var extraMessage = await extraReceived.Task.WaitAsync(TimeSpan.FromSeconds(10), TestContext.Current.CancellationToken);
510+
Assert.Equal("injected", extraMessage);
511+
}
512+
293513
private ClaimsPrincipal CreateUser(string name)
294514
=> new(new ClaimsIdentity(
295515
[new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name)],

tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsMessageFilterTests.cs

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,46 @@ public async Task AddOutgoingMessageFilter_Can_Send_Additional_Messages()
428428
Assert.Equal("extra", extraMessage);
429429
}
430430

431+
[Fact]
432+
public async Task AddOutgoingMessageFilter_Sees_ServerOriginatedRequests()
433+
{
434+
var observedMethods = new List<string>();
435+
436+
McpServerBuilder
437+
.WithMessageFilters(filters => filters.AddOutgoingFilter((next) => async (context, cancellationToken) =>
438+
{
439+
if (context.JsonRpcMessage is JsonRpcRequest request)
440+
{
441+
observedMethods.Add(request.Method);
442+
}
443+
444+
await next(context, cancellationToken);
445+
}))
446+
.WithTools<SamplingTool>();
447+
448+
StartServer();
449+
450+
var clientOptions = new McpClientOptions
451+
{
452+
Capabilities = new() { Sampling = new() },
453+
Handlers = new()
454+
{
455+
SamplingHandler = (_, _, _) => new(new CreateMessageResult
456+
{
457+
Content = [new TextContentBlock { Text = "sampled" }],
458+
Model = "test-model",
459+
}),
460+
},
461+
};
462+
463+
await using McpClient client = await CreateMcpClientForServer(clientOptions);
464+
465+
await client.CallToolAsync("sampling-tool", new Dictionary<string, object?> { ["prompt"] = "Hello" },
466+
cancellationToken: TestContext.Current.CancellationToken);
467+
468+
Assert.Contains(RequestMethods.SamplingCreateMessage, observedMethods);
469+
}
470+
431471
[Fact]
432472
public async Task AddIncomingMessageFilter_Items_Flow_To_Request_Filters()
433473
{
@@ -644,7 +684,6 @@ public async Task AddIncomingMessageFilter_Items_Flow_Through_Multiple_Request_F
644684
Assert.Equal("modifiedByFilter1", observedValues[1]);
645685
}
646686

647-
[McpServerToolType]
648687
public sealed class TestTool
649688
{
650689
[McpServerTool]
@@ -654,7 +693,6 @@ public static string TestToolMethod()
654693
}
655694
}
656695

657-
[McpServerPromptType]
658696
public sealed class TestPrompt
659697
{
660698
[McpServerPrompt]
@@ -668,7 +706,6 @@ public static Task<GetPromptResult> TestPromptMethod()
668706
}
669707
}
670708

671-
[McpServerResourceType]
672709
public sealed class TestResource
673710
{
674711
[McpServerResource(UriTemplate = "test://resource/{id}")]
@@ -678,7 +715,6 @@ public static string TestResourceMethod(string id)
678715
}
679716
}
680717

681-
[McpServerToolType]
682718
public sealed class ProgressTool
683719
{
684720
[McpServerTool(Name = "progress-tool")]
@@ -708,7 +744,6 @@ public static async Task<string> ReportProgress(
708744
}
709745
}
710746

711-
[McpServerToolType]
712747
public sealed class SimpleTool
713748
{
714749
[McpServerTool(Name = "simple-tool")]
@@ -717,4 +752,19 @@ public static string Execute()
717752
return "success";
718753
}
719754
}
755+
756+
public sealed class SamplingTool
757+
{
758+
[McpServerTool(Name = "sampling-tool")]
759+
public static async Task<string> SampleAsync(McpServer server, string prompt, CancellationToken cancellationToken)
760+
{
761+
var result = await server.SampleAsync(new CreateMessageRequestParams
762+
{
763+
Messages = [new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = prompt }] }],
764+
MaxTokens = 100,
765+
}, cancellationToken);
766+
767+
return $"Sampled: {Assert.IsType<TextContentBlock>(Assert.Single(result.Content)).Text}";
768+
}
769+
}
720770
}

0 commit comments

Comments
 (0)