|
4 | 4 | using ModelContextProtocol.Protocol; |
5 | 5 | using ModelContextProtocol.Server; |
6 | 6 | using ModelContextProtocol.Tests.Utils; |
| 7 | +using System.Security.Claims; |
7 | 8 | using System.Text.Json.Nodes; |
8 | 9 |
|
9 | 10 | namespace ModelContextProtocol.Tests.Configuration; |
@@ -415,6 +416,214 @@ public async Task AddOutgoingMessageFilter_Can_Send_Additional_Messages() |
415 | 416 | Assert.Equal("extra", extraMessage); |
416 | 417 | } |
417 | 418 |
|
| 419 | + [Fact] |
| 420 | + public async Task AddIncomingMessageFilter_Items_Flow_To_Request_Filters() |
| 421 | + { |
| 422 | + string? capturedValue = null; |
| 423 | + |
| 424 | + McpServerBuilder |
| 425 | + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => |
| 426 | + { |
| 427 | + // Set an item in the message filter |
| 428 | + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsList) |
| 429 | + { |
| 430 | + context.Items["messageFilterKey"] = "messageFilterValue"; |
| 431 | + } |
| 432 | + await next(context, cancellationToken); |
| 433 | + }) |
| 434 | + .AddListToolsFilter((next) => async (request, cancellationToken) => |
| 435 | + { |
| 436 | + // Read the item in the request-specific filter |
| 437 | + if (request.Items.TryGetValue("messageFilterKey", out var value)) |
| 438 | + { |
| 439 | + capturedValue = value as string; |
| 440 | + } |
| 441 | + return await next(request, cancellationToken); |
| 442 | + }) |
| 443 | + .WithTools<TestTool>(); |
| 444 | + |
| 445 | + StartServer(); |
| 446 | + |
| 447 | + await using McpClient client = await CreateMcpClientForServer(); |
| 448 | + |
| 449 | + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); |
| 450 | + |
| 451 | + Assert.Equal("messageFilterValue", capturedValue); |
| 452 | + } |
| 453 | + |
| 454 | + [Fact] |
| 455 | + public async Task AddIncomingMessageFilter_Items_Flow_To_CallTool_Handler() |
| 456 | + { |
| 457 | + object? capturedValue = null; |
| 458 | + |
| 459 | + McpServerBuilder |
| 460 | + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => |
| 461 | + { |
| 462 | + // Set an item in the message filter for CallTool requests |
| 463 | + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsCall) |
| 464 | + { |
| 465 | + context.Items["toolContextKey"] = 42; |
| 466 | + } |
| 467 | + await next(context, cancellationToken); |
| 468 | + }) |
| 469 | + .AddCallToolFilter((next) => async (request, cancellationToken) => |
| 470 | + { |
| 471 | + // Read the item in the call tool filter |
| 472 | + if (request.Items.TryGetValue("toolContextKey", out var value)) |
| 473 | + { |
| 474 | + capturedValue = value; |
| 475 | + } |
| 476 | + return await next(request, cancellationToken); |
| 477 | + }) |
| 478 | + .WithTools<SimpleTool>(); |
| 479 | + |
| 480 | + StartServer(); |
| 481 | + |
| 482 | + await using McpClient client = await CreateMcpClientForServer(); |
| 483 | + |
| 484 | + await client.CallToolAsync("simple-tool", cancellationToken: TestContext.Current.CancellationToken); |
| 485 | + |
| 486 | + Assert.Equal(42, capturedValue); |
| 487 | + } |
| 488 | + |
| 489 | + [Fact] |
| 490 | + public async Task AddIncomingMessageFilter_User_Flows_To_CallTool_Handler() |
| 491 | + { |
| 492 | + ClaimsPrincipal? capturedUser = null; |
| 493 | + |
| 494 | + McpServerBuilder |
| 495 | + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => |
| 496 | + { |
| 497 | + // Set a custom user in the message filter for CallTool requests |
| 498 | + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsCall) |
| 499 | + { |
| 500 | + var claims = new[] { new Claim(ClaimTypes.Name, "TestUser"), new Claim(ClaimTypes.Role, "Admin") }; |
| 501 | + var identity = new ClaimsIdentity(claims, "TestAuth"); |
| 502 | + context.User = new ClaimsPrincipal(identity); |
| 503 | + } |
| 504 | + await next(context, cancellationToken); |
| 505 | + }) |
| 506 | + .AddCallToolFilter((next) => async (request, cancellationToken) => |
| 507 | + { |
| 508 | + // Read the user in the call tool filter |
| 509 | + capturedUser = request.User; |
| 510 | + return await next(request, cancellationToken); |
| 511 | + }) |
| 512 | + .WithTools<SimpleTool>(); |
| 513 | + |
| 514 | + StartServer(); |
| 515 | + |
| 516 | + await using McpClient client = await CreateMcpClientForServer(); |
| 517 | + |
| 518 | + await client.CallToolAsync("simple-tool", cancellationToken: TestContext.Current.CancellationToken); |
| 519 | + |
| 520 | + Assert.NotNull(capturedUser); |
| 521 | + Assert.Equal("TestUser", capturedUser.Identity?.Name); |
| 522 | + Assert.True(capturedUser.IsInRole("Admin")); |
| 523 | + } |
| 524 | + |
| 525 | + [Fact] |
| 526 | + public async Task AddIncomingMessageFilter_Items_Preserved_When_Context_Replaced() |
| 527 | + { |
| 528 | + object? firstFilterValue = null; |
| 529 | + object? secondFilterValue = null; |
| 530 | + |
| 531 | + McpServerBuilder |
| 532 | + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => |
| 533 | + { |
| 534 | + // First filter sets an item |
| 535 | + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsList) |
| 536 | + { |
| 537 | + context.Items["firstFilterKey"] = "firstFilterValue"; |
| 538 | + } |
| 539 | + await next(context, cancellationToken); |
| 540 | + }) |
| 541 | + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => |
| 542 | + { |
| 543 | + // Second filter creates a new context with a new JsonRpcRequest and adds an item |
| 544 | + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsList) |
| 545 | + { |
| 546 | + var newRequest = new JsonRpcRequest |
| 547 | + { |
| 548 | + Id = request.Id, |
| 549 | + Method = RequestMethods.ToolsList, |
| 550 | + Params = request.Params, |
| 551 | + Context = new JsonRpcMessageContext { RelatedTransport = request.Context?.RelatedTransport }, |
| 552 | + }; |
| 553 | + |
| 554 | + var newContext = new MessageContext(context.Server, newRequest); |
| 555 | + newContext.Items["secondFilterKey"] = "secondFilterValue"; |
| 556 | + |
| 557 | + await next(newContext, cancellationToken); |
| 558 | + return; |
| 559 | + } |
| 560 | + await next(context, cancellationToken); |
| 561 | + }) |
| 562 | + .AddListToolsFilter((next) => async (request, cancellationToken) => |
| 563 | + { |
| 564 | + // Request filter should see items from message filters |
| 565 | + request.Items.TryGetValue("firstFilterKey", out firstFilterValue); |
| 566 | + request.Items.TryGetValue("secondFilterKey", out secondFilterValue); |
| 567 | + return await next(request, cancellationToken); |
| 568 | + }) |
| 569 | + .WithTools<TestTool>(); |
| 570 | + |
| 571 | + StartServer(); |
| 572 | + |
| 573 | + await using McpClient client = await CreateMcpClientForServer(); |
| 574 | + |
| 575 | + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); |
| 576 | + |
| 577 | + Assert.Null(firstFilterValue); |
| 578 | + Assert.Equal("secondFilterValue", secondFilterValue); |
| 579 | + } |
| 580 | + |
| 581 | + [Fact] |
| 582 | + public async Task AddIncomingMessageFilter_Items_Flow_Through_Multiple_Request_Filters() |
| 583 | + { |
| 584 | + var observedValues = new List<string>(); |
| 585 | + |
| 586 | + McpServerBuilder |
| 587 | + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => |
| 588 | + { |
| 589 | + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsList) |
| 590 | + { |
| 591 | + context.Items["sharedKey"] = "fromMessageFilter"; |
| 592 | + } |
| 593 | + await next(context, cancellationToken); |
| 594 | + }) |
| 595 | + .AddListToolsFilter((next) => async (request, cancellationToken) => |
| 596 | + { |
| 597 | + // First request filter reads and modifies |
| 598 | + if (request.Items.TryGetValue("sharedKey", out var value)) |
| 599 | + { |
| 600 | + observedValues.Add((string)value!); |
| 601 | + request.Items["sharedKey"] = "modifiedByFilter1"; |
| 602 | + } |
| 603 | + return await next(request, cancellationToken); |
| 604 | + }) |
| 605 | + .AddListToolsFilter((next) => async (request, cancellationToken) => |
| 606 | + { |
| 607 | + // Second request filter should see modified value |
| 608 | + if (request.Items.TryGetValue("sharedKey", out var value)) |
| 609 | + { |
| 610 | + observedValues.Add((string)value!); |
| 611 | + } |
| 612 | + return await next(request, cancellationToken); |
| 613 | + }) |
| 614 | + .WithTools<TestTool>(); |
| 615 | + |
| 616 | + StartServer(); |
| 617 | + |
| 618 | + await using McpClient client = await CreateMcpClientForServer(); |
| 619 | + |
| 620 | + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); |
| 621 | + |
| 622 | + Assert.Equal(2, observedValues.Count); |
| 623 | + Assert.Equal("fromMessageFilter", observedValues[0]); |
| 624 | + Assert.Equal("modifiedByFilter1", observedValues[1]); |
| 625 | + } |
| 626 | + |
418 | 627 | [McpServerToolType] |
419 | 628 | public sealed class TestTool |
420 | 629 | { |
@@ -478,4 +687,14 @@ public static async Task<string> ReportProgress( |
478 | 687 | return "done"; |
479 | 688 | } |
480 | 689 | } |
| 690 | + |
| 691 | + [McpServerToolType] |
| 692 | + public sealed class SimpleTool |
| 693 | + { |
| 694 | + [McpServerTool(Name = "simple-tool")] |
| 695 | + public static string Execute() |
| 696 | + { |
| 697 | + return "success"; |
| 698 | + } |
| 699 | + } |
481 | 700 | } |
0 commit comments