diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs index 3f5870700..44946ea80 100644 --- a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs @@ -30,15 +30,6 @@ public void Configure(McpServerOptions options) public void PostConfigure(string? name, McpServerOptions options) { - CheckListToolsFilter(options); - CheckCallToolFilter(options); - - CheckListResourcesFilter(options); - CheckListResourceTemplatesFilter(options); - CheckReadResourceFilter(options); - - CheckListPromptsFilter(options); - CheckGetPromptFilter(options); } private void ConfigureListToolsFilter(McpServerOptions options) @@ -59,26 +50,6 @@ await FilterAuthorizedItemsAsync( }); } - private static void CheckListToolsFilter(McpServerOptions options) - { - options.Filters.Request.ListToolsFilters.Add(next => - { - var toolCollection = options.ToolCollection; - return async (context, cancellationToken) => - { - var result = await next(context, cancellationToken); - - if (HasAuthorizationMetadata(result.Tools.Select(tool => toolCollection is not null && toolCollection.TryGetPrimitive(tool.Name, out var serverTool) ? serverTool : null)) - && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) - { - throw new InvalidOperationException("Authorization filter was not invoked for tools/list operation, but authorization metadata was found on the tools. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); - } - - return result; - }; - }); - } - private void ConfigureCallToolFilter(McpServerOptions options) { options.Filters.Request.CallToolFilters.Add(next => async (context, cancellationToken) => @@ -95,20 +66,6 @@ private void ConfigureCallToolFilter(McpServerOptions options) }); } - private static void CheckCallToolFilter(McpServerOptions options) - { - options.Filters.Request.CallToolFilters.Add(next => async (context, cancellationToken) => - { - if (HasAuthorizationMetadata(context.MatchedPrimitive) - && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) - { - throw new InvalidOperationException("Authorization filter was not invoked for tools/call operation, but authorization metadata was found on the tool. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); - } - - return await next(context, cancellationToken); - }); - } - private void ConfigureListResourcesFilter(McpServerOptions options) { options.Filters.Request.ListResourcesFilters.Add(next => @@ -127,26 +84,6 @@ await FilterAuthorizedItemsAsync( }); } - private static void CheckListResourcesFilter(McpServerOptions options) - { - options.Filters.Request.ListResourcesFilters.Add(next => - { - var resourceCollection = options.ResourceCollection; - return async (context, cancellationToken) => - { - var result = await next(context, cancellationToken); - - if (HasAuthorizationMetadata(result.Resources.Select(resource => resourceCollection is not null && resourceCollection.TryGetPrimitive(resource.Uri, out var serverResource) ? serverResource : null)) - && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) - { - throw new InvalidOperationException("Authorization filter was not invoked for resources/list operation, but authorization metadata was found on the resources. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); - } - - return result; - }; - }); - } - private void ConfigureListResourceTemplatesFilter(McpServerOptions options) { options.Filters.Request.ListResourceTemplatesFilters.Add(next => @@ -165,26 +102,6 @@ await FilterAuthorizedItemsAsync( }); } - private static void CheckListResourceTemplatesFilter(McpServerOptions options) - { - options.Filters.Request.ListResourceTemplatesFilters.Add(next => - { - var resourceCollection = options.ResourceCollection; - return async (context, cancellationToken) => - { - var result = await next(context, cancellationToken); - - if (HasAuthorizationMetadata(result.ResourceTemplates.Select(resourceTemplate => resourceCollection is not null && resourceCollection.TryGetPrimitive(resourceTemplate.UriTemplate, out var serverResource) ? serverResource : null)) - && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) - { - throw new InvalidOperationException("Authorization filter was not invoked for resources/templates/list operation, but authorization metadata was found on the resource templates. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); - } - - return result; - }; - }); - } - private void ConfigureReadResourceFilter(McpServerOptions options) { options.Filters.Request.ReadResourceFilters.Add(next => async (context, cancellationToken) => @@ -201,20 +118,6 @@ private void ConfigureReadResourceFilter(McpServerOptions options) }); } - private static void CheckReadResourceFilter(McpServerOptions options) - { - options.Filters.Request.ReadResourceFilters.Add(next => async (context, cancellationToken) => - { - if (HasAuthorizationMetadata(context.MatchedPrimitive) - && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) - { - throw new InvalidOperationException("Authorization filter was not invoked for resources/read operation, but authorization metadata was found on the resource. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); - } - - return await next(context, cancellationToken); - }); - } - private void ConfigureListPromptsFilter(McpServerOptions options) { options.Filters.Request.ListPromptsFilters.Add(next => @@ -233,26 +136,6 @@ await FilterAuthorizedItemsAsync( }); } - private static void CheckListPromptsFilter(McpServerOptions options) - { - options.Filters.Request.ListPromptsFilters.Add(next => - { - var promptCollection = options.PromptCollection; - return async (context, cancellationToken) => - { - var result = await next(context, cancellationToken); - - if (HasAuthorizationMetadata(result.Prompts.Select(prompt => promptCollection is not null && promptCollection.TryGetPrimitive(prompt.Name, out var serverPrompt) ? serverPrompt : null)) - && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) - { - throw new InvalidOperationException("Authorization filter was not invoked for prompts/list operation, but authorization metadata was found on the prompts. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); - } - - return result; - }; - }); - } - private void ConfigureGetPromptFilter(McpServerOptions options) { options.Filters.Request.GetPromptFilters.Add(next => async (context, cancellationToken) => @@ -269,20 +152,6 @@ private void ConfigureGetPromptFilter(McpServerOptions options) }); } - private static void CheckGetPromptFilter(McpServerOptions options) - { - options.Filters.Request.GetPromptFilters.Add(next => async (context, cancellationToken) => - { - if (HasAuthorizationMetadata(context.MatchedPrimitive) - && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) - { - throw new InvalidOperationException("Authorization filter was not invoked for prompts/get operation, but authorization metadata was found on the prompt. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); - } - - return await next(context, cancellationToken); - }); - } - /// /// Filters a collection of items based on authorization policies in their metadata. /// For list operations where we need to filter results by authorization. @@ -338,7 +207,7 @@ private async ValueTask GetAuthorizationResultAsync( /// The authorization policy provider. /// The endpoint metadata collection. /// The combined authorization policy, or null if no authorization is required. - private static async ValueTask CombineAsync(IAuthorizationPolicyProvider policyProvider, IReadOnlyList endpointMetadata) + internal static async ValueTask CombineAsync(IAuthorizationPolicyProvider policyProvider, IReadOnlyList endpointMetadata) { // https://github.com/dotnet/aspnetcore/issues/63365 tracks adding this as public API to AuthorizationPolicy itself. // Copied from https://github.com/dotnet/aspnetcore/blob/9f2977bf9cfb539820983bda3bedf81c8cda9f20/src/Security/Authorization/Policy/src/AuthorizationMiddleware.cs#L116-L138 @@ -374,7 +243,7 @@ private async ValueTask GetAuthorizationResultAsync( : AuthorizationPolicy.Combine(policy, reqPolicyBuilder.Build()); } - private static bool HasAuthorizationMetadata([NotNullWhen(true)] IMcpServerPrimitive? primitive) + internal static bool HasAuthorizationMetadata([NotNullWhen(true)] IMcpServerPrimitive? primitive) { // If no primitive was found for this request or there is IAllowAnonymous metadata anywhere on the class or method, // the request should go through as normal. @@ -385,7 +254,4 @@ private static bool HasAuthorizationMetadata([NotNullWhen(true)] IMcpServerPrimi return primitive.Metadata.Any(static m => m is IAuthorizeData or AuthorizationPolicy or IAuthorizationRequirementData); } - - private static bool HasAuthorizationMetadata(IEnumerable primitives) - => primitives.Any(HasAuthorizationMetadata); -} \ No newline at end of file +} diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs index bcdf53584..cb0f5a421 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs @@ -64,6 +64,10 @@ public static IMcpServerBuilder AddAuthorizationFilters(this IMcpServerBuilder b // Allow the authorization filters to get added multiple times in case other middleware changes the matched primitive. builder.Services.AddTransient, AuthorizationFilterSetup>(); + // Signal to the HTTP transport that authorization filters are handling access control, + // so the pre-flight incremental scope consent check (SEP-835) should be skipped. + builder.Services.Configure(static o => o.AuthorizationFiltersRegistered = true); + return builder; } diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs index 648cb86df..6bfa16553 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -188,4 +188,13 @@ public class HttpServerTransportOptions /// Gets or sets the time provider that's used for testing the . /// public TimeProvider TimeProvider { get; set; } = TimeProvider.System; + + /// + /// Gets a value indicating whether authorization filters have been registered via + /// AddAuthorizationFilters. + /// When , the MCP filter pipeline handles authorization (hiding unauthorized primitives and returning MCP errors). + /// When (the default), the HTTP transport performs a pre-flight authorization check that returns + /// HTTP 403 with WWW-Authenticate: Bearer error="insufficient_scope" for incremental scope consent (SEP-835). + /// + internal bool AuthorizationFiltersRegistered { get; set; } } diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index ec28eff84..3d3779ec6 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -1,6 +1,8 @@ -using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.WebUtilities; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; @@ -41,6 +43,9 @@ internal sealed class StreamableHttpHandler( private static readonly JsonTypeInfo s_messageTypeInfo = GetRequiredJsonTypeInfo(); private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); + private static readonly JsonTypeInfo s_callToolParamsTypeInfo = GetRequiredJsonTypeInfo(); + private static readonly JsonTypeInfo s_getPromptParamsTypeInfo = GetRequiredJsonTypeInfo(); + private static readonly JsonTypeInfo s_readResourceParamsTypeInfo = GetRequiredJsonTypeInfo(); private static bool AllowNewSessionForNonInitializeRequests { get; } = AppContext.TryGetSwitch("ModelContextProtocol.AspNetCore.AllowNewSessionForNonInitializeRequests", out var enabled) && enabled; @@ -87,6 +92,11 @@ await WriteJsonRpcErrorAsync(context, await using var _ = await session.AcquireReferenceAsync(context.RequestAborted); + if (await TryHandleInsufficientScopeAsync(context, session, message)) + { + return; + } + InitializeSseResponse(context); var wroteResponse = await session.Transport.HandlePostRequestAsync(message, context.Response.Body, context.RequestAborted); if (!wroteResponse) @@ -463,6 +473,127 @@ private static Task WriteJsonRpcErrorAsync(HttpContext context, string errorMess return Results.Json(jsonRpcError, s_errorTypeInfo, statusCode: statusCode).ExecuteAsync(context); } + /// + /// Performs a pre-flight authorization check for invocation requests (tools/call, prompts/get, resources/read) + /// when has not been called. + /// If the request targets a primitive with + /// metadata and the caller is not authorized, writes an HTTP 403 response with a + /// WWW-Authenticate: Bearer error="insufficient_scope" header to trigger incremental scope consent (SEP-835). + /// + /// if a 403 response was written and request processing should stop; otherwise . + private async ValueTask TryHandleInsufficientScopeAsync(HttpContext context, StreamableHttpSession session, JsonRpcMessage message) + { + // Only applicable when AddAuthorizationFilters has NOT been called. + // If it was called, the MCP filter pipeline handles authorization (hiding + MCP errors). + if (httpServerTransportOptions.Value.AuthorizationFiltersRegistered) + { + return false; + } + + // Only handle invocation requests that target a specific primitive. + if (message is not JsonRpcRequest request) + { + return false; + } + + var serverOptions = session.Server.ServerOptions; + IMcpServerPrimitive? primitive = null; + + switch (request.Method) + { + case RequestMethods.ToolsCall: + { + var toolParams = request.Params is { } p ? System.Text.Json.JsonSerializer.Deserialize(p, s_callToolParamsTypeInfo) : null; + if (toolParams?.Name is { } toolName && serverOptions.ToolCollection is { } tools + && tools.TryGetPrimitive(toolName, out var tool)) + { + primitive = tool; + } + break; + } + case RequestMethods.PromptsGet: + { + var promptParams = request.Params is { } p ? System.Text.Json.JsonSerializer.Deserialize(p, s_getPromptParamsTypeInfo) : null; + if (promptParams?.Name is { } promptName && serverOptions.PromptCollection is { } prompts + && prompts.TryGetPrimitive(promptName, out var prompt)) + { + primitive = prompt; + } + break; + } + case RequestMethods.ResourcesRead: + { + var resourceParams = request.Params is { } p ? System.Text.Json.JsonSerializer.Deserialize(p, s_readResourceParamsTypeInfo) : null; + if (resourceParams?.Uri is { } resourceUri && serverOptions.ResourceCollection is { } resources) + { + // First try an exact match, then fall back to URI template matching. + if (resources.TryGetPrimitive(resourceUri, out var resource) && !resource.IsTemplated) + { + primitive = resource; + } + else + { + foreach (var resourceTemplate in resources) + { + if (resourceTemplate.IsMatch(resourceUri)) + { + primitive = resourceTemplate; + break; + } + } + } + } + break; + } + default: + return false; + } + + if (!AuthorizationFilterSetup.HasAuthorizationMetadata(primitive)) + { + return false; + } + + // Evaluate the authorization policy for this primitive. + var policyProvider = context.RequestServices.GetService(); + if (policyProvider is null) + { + // No authorization infrastructure configured; skip the pre-flight check. + return false; + } + + var policy = await AuthorizationFilterSetup.CombineAsync(policyProvider, primitive.Metadata); + if (policy is null) + { + return false; + } + + var authService = context.RequestServices.GetRequiredService(); + var authResult = await authService.AuthorizeAsync(context.User ?? new ClaimsPrincipal(new ClaimsIdentity()), context, policy); + if (authResult.Succeeded) + { + return false; + } + + // Authorization failed. Build a WWW-Authenticate header with error="insufficient_scope". + // Extract the scope from IAuthorizeData.Roles (the standard pattern for incremental scope consent). + var scope = primitive.Metadata + .OfType() + .Select(static a => a.Roles) + .FirstOrDefault(static r => !string.IsNullOrEmpty(r)); + + // Build the resource_metadata URL using the default well-known path for this endpoint. + var resourceMetadataUri = $"{context.Request.Scheme}://{context.Request.Host}{context.Request.PathBase}/.well-known/oauth-protected-resource{context.Request.Path}"; + + var wwwAuthenticate = string.IsNullOrEmpty(scope) + ? $"Bearer error=\"insufficient_scope\", resource_metadata=\"{resourceMetadataUri}\"" + : $"Bearer error=\"insufficient_scope\", scope=\"{scope}\", resource_metadata=\"{resourceMetadataUri}\""; + + context.Response.Headers[HeaderNames.WWWAuthenticate] = wwwAuthenticate; + await WriteJsonRpcErrorAsync(context, "Forbidden: Insufficient scope.", StatusCodes.Status403Forbidden, (int)McpErrorCode.InvalidRequest); + return true; + } + internal static void InitializeSseResponse(HttpContext context) { context.Response.Headers.ContentType = "text/event-stream"; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs index 76a7201d8..111ce7e64 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs @@ -8,6 +8,7 @@ using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; using System.ComponentModel; +using System.Net.Http; using System.Security.Claims; namespace ModelContextProtocol.AspNetCore.Tests; @@ -270,132 +271,109 @@ public async Task ListResources_Anonymous_OnlyReturnsAnonymousResources() } [Fact] - public async Task ListTools_WithoutAuthFilters_ThrowsInvalidOperationException() + public async Task ListTools_WithoutAuthFilters_ReturnsAllTools() { + // Without AddAuthorizationFilters(), all tools are visible in listings (incremental consent model). await using var app = await StartServerWithoutAuthFilters(builder => builder.WithTools()); var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => - await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken)); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.Equal("Request failed (remote): An error occurred.", exception.Message); - Assert.Contains(MockLoggerProvider.LogMessages, log => - log.LogLevel == LogLevel.Warning && - log.Exception is InvalidOperationException && - log.Exception.Message.Contains("Authorization filter was not invoked for tools/list operation") && - log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + // All tools should be visible regardless of [Authorize] attributes when AddAuthorizationFilters() is not called. + Assert.Equal(3, tools.Count); + var toolNames = tools.Select(t => t.Name).OrderBy(n => n).ToList(); + Assert.Equal(["admin_tool", "anonymous_tool", "authorized_tool"], toolNames); } [Fact] - public async Task CallTool_WithoutAuthFilters_ReturnsError() + public async Task CallTool_WithoutAuthFilters_ReturnsForbidden() { + // Without AddAuthorizationFilters(), calling an [Authorize] tool returns HTTP 403 (incremental consent model). await using var app = await StartServerWithoutAuthFilters(builder => builder.WithTools()); var client = await ConnectAsync(); - var toolResult = await client.CallToolAsync( + var exception = await Assert.ThrowsAsync(async () => + await client.CallToolAsync( "authorized_tool", new Dictionary { ["message"] = "test" }, - cancellationToken: TestContext.Current.CancellationToken); - - Assert.True(toolResult.IsError); + cancellationToken: TestContext.Current.CancellationToken)); - var errorContent = Assert.IsType(Assert.Single(toolResult.Content)); - Assert.Equal("An error occurred invoking 'authorized_tool'.", errorContent.Text); - Assert.Contains(MockLoggerProvider.LogMessages, log => - log.LogLevel == LogLevel.Error && - log.Exception is InvalidOperationException && - log.Exception.Message.Contains("Authorization filter was not invoked for tools/call operation") && - log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + Assert.Equal(System.Net.HttpStatusCode.Forbidden, exception.StatusCode); } [Fact] - public async Task ListPrompts_WithoutAuthFilters_ThrowsInvalidOperationException() + public async Task ListPrompts_WithoutAuthFilters_ReturnsAllPrompts() { + // Without AddAuthorizationFilters(), all prompts are visible in listings (incremental consent model). await using var app = await StartServerWithoutAuthFilters(builder => builder.WithPrompts()); var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => - await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken)); + var prompts = await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.Equal("Request failed (remote): An error occurred.", exception.Message); - Assert.Contains(MockLoggerProvider.LogMessages, log => - log.LogLevel == LogLevel.Warning && - log.Exception is InvalidOperationException && - log.Exception.Message.Contains("Authorization filter was not invoked for prompts/list operation") && - log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + // All prompts should be visible regardless of [Authorize] attributes when AddAuthorizationFilters() is not called. + Assert.Equal(2, prompts.Count); + var promptNames = prompts.Select(p => p.Name).OrderBy(n => n).ToList(); + Assert.Equal(["anonymous_prompt", "authorized_prompt"], promptNames); } [Fact] - public async Task GetPrompt_WithoutAuthFilters_ThrowsInvalidOperationException() + public async Task GetPrompt_WithoutAuthFilters_ReturnsForbidden() { + // Without AddAuthorizationFilters(), getting an [Authorize] prompt returns HTTP 403 (incremental consent model). await using var app = await StartServerWithoutAuthFilters(builder => builder.WithPrompts()); var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => + var exception = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( "authorized_prompt", new Dictionary { ["message"] = "test" }, cancellationToken: TestContext.Current.CancellationToken)); - Assert.Equal("Request failed (remote): An error occurred.", exception.Message); - Assert.Contains(MockLoggerProvider.LogMessages, log => - log.LogLevel == LogLevel.Warning && - log.Exception is InvalidOperationException && - log.Exception.Message.Contains("Authorization filter was not invoked for prompts/get operation") && - log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + Assert.Equal(System.Net.HttpStatusCode.Forbidden, exception.StatusCode); } [Fact] - public async Task ListResources_WithoutAuthFilters_ThrowsInvalidOperationException() + public async Task ListResources_WithoutAuthFilters_ReturnsAllResources() { + // Without AddAuthorizationFilters(), all resources are visible in listings (incremental consent model). await using var app = await StartServerWithoutAuthFilters(builder => builder.WithResources()); var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => - await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken)); + var resources = await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.Equal("Request failed (remote): An error occurred.", exception.Message); - Assert.Contains(MockLoggerProvider.LogMessages, log => - log.LogLevel == LogLevel.Warning && - log.Exception is InvalidOperationException && - log.Exception.Message.Contains("Authorization filter was not invoked for resources/list operation") && - log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + // All resources should be visible regardless of [Authorize] attributes when AddAuthorizationFilters() is not called. + Assert.Equal(2, resources.Count); + var uris = resources.Select(r => r.Uri).OrderBy(u => u).ToList(); + Assert.Equal(["resource://anonymous", "resource://authorized"], uris); } [Fact] - public async Task ReadResource_WithoutAuthFilters_ThrowsInvalidOperationException() + public async Task ReadResource_WithoutAuthFilters_ReturnsForbidden() { + // Without AddAuthorizationFilters(), reading an [Authorize] resource returns HTTP 403 (incremental consent model). await using var app = await StartServerWithoutAuthFilters(builder => builder.WithResources()); var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => + var exception = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( "resource://authorized", cancellationToken: TestContext.Current.CancellationToken)); - Assert.Equal("Request failed (remote): An error occurred.", exception.Message); - Assert.Contains(MockLoggerProvider.LogMessages, log => - log.LogLevel == LogLevel.Warning && - log.Exception is InvalidOperationException && - log.Exception.Message.Contains("Authorization filter was not invoked for resources/read operation") && - log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + Assert.Equal(System.Net.HttpStatusCode.Forbidden, exception.StatusCode); } [Fact] - public async Task ListResourceTemplates_WithoutAuthFilters_ThrowsInvalidOperationException() + public async Task ListResourceTemplates_WithoutAuthFilters_ReturnsAllTemplates() { + // Without AddAuthorizationFilters(), all resource templates are visible in listings (incremental consent model). await using var app = await StartServerWithoutAuthFilters(builder => builder.WithResources()); var client = await ConnectAsync(); - var exception = await Assert.ThrowsAsync(async () => - await client.ListResourceTemplatesAsync(cancellationToken: TestContext.Current.CancellationToken)); - - Assert.Equal("Request failed (remote): An error occurred.", exception.Message); - Assert.Contains(MockLoggerProvider.LogMessages, log => - log.LogLevel == LogLevel.Warning && - log.Exception is InvalidOperationException && - log.Exception.Message.Contains("Authorization filter was not invoked for resources/templates/list operation") && - log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + var templates = await client.ListResourceTemplatesAsync(cancellationToken: TestContext.Current.CancellationToken); + + // All resource templates should be visible regardless of [Authorize] attributes when AddAuthorizationFilters() is not called. + Assert.Single(templates); + Assert.Equal("resource://authorized/{id}", templates[0].UriTemplate); } [Fact] diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/IncrementalConsentTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/IncrementalConsentTests.cs new file mode 100644 index 000000000..bed7bd10b --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/IncrementalConsentTests.cs @@ -0,0 +1,377 @@ +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.ComponentModel; +using System.Net; +using System.Net.Http; +using System.Security.Claims; +using System.Text; + +namespace ModelContextProtocol.AspNetCore.Tests; + +/// +/// Tests for built-in incremental scope consent (SEP-835) support. +/// When AddAuthorizationFilters() is NOT called, the HTTP transport performs a pre-flight +/// authorization check that returns HTTP 403 with WWW-Authenticate: Bearer error="insufficient_scope" +/// to trigger client re-authentication with broader scopes. +/// +public class IncrementalConsentTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) +{ + private const string InitializeJson = """ + { + "jsonrpc": "2.0", + "id": 0, + "method": "initialize", + "params": { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": { "name": "test", "version": "0.1" } + } + } + """; + + private async Task ConnectAsync() + { + var transport = new HttpClientTransport(new HttpClientTransportOptions + { + Endpoint = new Uri("http://localhost:5000"), + }, HttpClient, LoggerFactory); + + return await McpClient.CreateAsync(transport, cancellationToken: TestContext.Current.CancellationToken, loggerFactory: LoggerFactory); + } + + // ------------------------- + // Listings: all primitives visible + // ------------------------- + + [Fact] + public async Task ListTools_WithoutAuthFilters_ReturnsAllToolsIncludingAuthorized() + { + await using var app = await StartServerAsync(builder => builder.WithTools()); + var client = await ConnectAsync(); + + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // All tools should be visible — [Authorize] does NOT hide tools when AddAuthorizationFilters() is not called. + Assert.Equal(2, tools.Count); + var toolNames = tools.Select(t => t.Name).OrderBy(n => n).ToList(); + Assert.Equal(["public_tool", "scoped_tool"], toolNames); + } + + [Fact] + public async Task ListPrompts_WithoutAuthFilters_ReturnsAllPromptsIncludingAuthorized() + { + await using var app = await StartServerAsync(builder => builder.WithPrompts()); + var client = await ConnectAsync(); + + var prompts = await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal(2, prompts.Count); + var promptNames = prompts.Select(p => p.Name).OrderBy(n => n).ToList(); + Assert.Equal(["public_prompt", "scoped_prompt"], promptNames); + } + + [Fact] + public async Task ListResources_WithoutAuthFilters_ReturnsAllResourcesIncludingAuthorized() + { + await using var app = await StartServerAsync(builder => builder.WithResources()); + var client = await ConnectAsync(); + + var resources = await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal(2, resources.Count); + var uris = resources.Select(r => r.Uri).OrderBy(u => u).ToList(); + Assert.Equal(["resource://public", "resource://scoped"], uris); + } + + [Fact] + public async Task ListResourceTemplates_WithoutAuthFilters_ReturnsAllTemplatesIncludingAuthorized() + { + await using var app = await StartServerAsync(builder => builder.WithResources()); + var client = await ConnectAsync(); + + var templates = await client.ListResourceTemplatesAsync(cancellationToken: TestContext.Current.CancellationToken); + + Assert.Single(templates); + Assert.Equal("resource://scoped/{id}", templates[0].UriTemplate); + } + + // ------------------------- + // Invocations: unauthorized → HTTP 403 with WWW-Authenticate header + // ------------------------- + + [Fact] + public async Task CallTool_Unauthorized_Returns403WithInsufficientScopeHeader() + { + await using var app = await StartServerAsync(builder => builder.WithTools()); + + var sessionId = await InitializeSessionAsync(); + + var callToolJson = """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { "name": "scoped_tool", "arguments": { "message": "test" } } + } + """; + + using var response = await PostJsonRpcAsync(callToolJson, sessionId); + + Assert.Equal(HttpStatusCode.Forbidden, response.StatusCode); + + // Verify WWW-Authenticate header contains the insufficient_scope error. + var wwwAuth = response.Headers.WwwAuthenticate.ToString(); + Assert.Contains("Bearer", wwwAuth); + Assert.Contains("insufficient_scope", wwwAuth); + Assert.Contains("read_data", wwwAuth); // scope from [Authorize(Roles = "read_data")] + Assert.Contains("resource_metadata", wwwAuth); + } + + [Fact] + public async Task GetPrompt_Unauthorized_Returns403WithInsufficientScopeHeader() + { + await using var app = await StartServerAsync(builder => builder.WithPrompts()); + + var sessionId = await InitializeSessionAsync(); + + var getPromptJson = """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "prompts/get", + "params": { "name": "scoped_prompt", "arguments": { "message": "test" } } + } + """; + + using var response = await PostJsonRpcAsync(getPromptJson, sessionId); + + Assert.Equal(HttpStatusCode.Forbidden, response.StatusCode); + var wwwAuth = response.Headers.WwwAuthenticate.ToString(); + Assert.Contains("insufficient_scope", wwwAuth); + } + + [Fact] + public async Task ReadResource_Unauthorized_Returns403WithInsufficientScopeHeader() + { + await using var app = await StartServerAsync(builder => builder.WithResources()); + + var sessionId = await InitializeSessionAsync(); + + var readResourceJson = """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "resources/read", + "params": { "uri": "resource://scoped" } + } + """; + + using var response = await PostJsonRpcAsync(readResourceJson, sessionId); + + Assert.Equal(HttpStatusCode.Forbidden, response.StatusCode); + var wwwAuth = response.Headers.WwwAuthenticate.ToString(); + Assert.Contains("insufficient_scope", wwwAuth); + } + + // ------------------------- + // Authorized user: invocation succeeds + // ------------------------- + + [Fact] + public async Task CallTool_AuthorizedUser_Succeeds() + { + await using var app = await StartServerAsync(builder => builder.WithTools(), userName: "authorized-user", roles: ["read_data"]); + var client = await ConnectAsync(); + + var result = await client.CallToolAsync( + "scoped_tool", + new Dictionary { ["message"] = "hello" }, + cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Scoped: hello", content.Text); + } + + [Fact] + public async Task CallTool_WrongRole_Returns403() + { + await using var app = await StartServerAsync(builder => builder.WithTools(), userName: "wrong-role-user", roles: ["wrong_scope"]); + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.CallToolAsync( + "scoped_tool", + new Dictionary { ["message"] = "hello" }, + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal(HttpStatusCode.Forbidden, exception.StatusCode); + } + + [Fact] + public async Task CallTool_PublicTool_AlwaysSucceeds() + { + await using var app = await StartServerAsync(builder => builder.WithTools()); + var client = await ConnectAsync(); + + // Public tool (no [Authorize]) should succeed without authentication. + var result = await client.CallToolAsync( + "public_tool", + new Dictionary { ["message"] = "hello" }, + cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Public: hello", content.Text); + } + + // ------------------------- + // AddAuthorizationFilters behavior is unchanged + // ------------------------- + + [Fact] + public async Task ListTools_WithAuthFilters_FiltersUnauthorizedTools() + { + await using var app = await StartServerWithAuthFiltersAsync(builder => builder.WithTools()); + var client = await ConnectAsync(); + + // With AddAuthorizationFilters(), unauthorized tools are hidden from listings. + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + Assert.Single(tools); + Assert.Equal("public_tool", tools[0].Name); + } + + [Fact] + public async Task CallTool_WithAuthFilters_ReturnsJsonRpcError() + { + await using var app = await StartServerWithAuthFiltersAsync(builder => builder.WithTools()); + var client = await ConnectAsync(); + + // With AddAuthorizationFilters(), unauthorized invocation returns a JSON-RPC error (not HTTP 403). + McpProtocolException? exception = null; + try + { + await client.CallToolAsync( + "scoped_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + } + catch (McpProtocolException ex) + { + exception = ex; + } + + Assert.NotNull(exception); + Assert.Contains("Access forbidden", exception.Message); + } + + // ------------------------- + // Helpers + // ------------------------- + + private async Task StartServerAsync(Action configure, string? userName = null, params string[] roles) + { + var mcpServerBuilder = Builder.Services.AddMcpServer().WithHttpTransport(); // No AddAuthorizationFilters() + configure(mcpServerBuilder); + + Builder.Services.AddAuthorization(); + + var app = Builder.Build(); + + if (userName is not null) + { + app.Use(next => async context => + { + context.User = CreateUser(userName, roles); + await next(context); + }); + } + + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + return app; + } + + private async Task StartServerWithAuthFiltersAsync(Action configure) + { + var mcpServerBuilder = Builder.Services.AddMcpServer().WithHttpTransport().AddAuthorizationFilters(); + configure(mcpServerBuilder); + + Builder.Services.AddAuthorization(); + + var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + return app; + } + + private async Task InitializeSessionAsync() + { + using var response = await PostJsonRpcAsync(InitializeJson, sessionId: null); + Assert.True(response.IsSuccessStatusCode, $"Initialize failed with {response.StatusCode}"); + return response.Headers.TryGetValues("Mcp-Session-Id", out var ids) ? ids.FirstOrDefault() : null; + } + + private async Task PostJsonRpcAsync(string json, string? sessionId) + { + using var request = new HttpRequestMessage(HttpMethod.Post, "http://localhost:5000/") + { + Content = new StringContent(json, Encoding.UTF8, "application/json"), + }; + request.Headers.Add("Accept", "application/json, text/event-stream"); + if (sessionId is not null) + { + request.Headers.Add("Mcp-Session-Id", sessionId); + } + + return await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + } + + private static ClaimsPrincipal CreateUser(string name, params string[] roles) + => new(new ClaimsIdentity( + [new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name), .. roles.Select(role => new Claim("role", role))], + "TestAuthType", "name", "role")); + + [McpServerToolType] + private class ScopedTools + { + [McpServerTool, Description("A public tool that requires no authorization.")] + public static string PublicTool(string message) => $"Public: {message}"; + + [McpServerTool, Description("A tool that requires the read_data scope.")] + [Authorize(Roles = "read_data")] + public static string ScopedTool(string message) => $"Scoped: {message}"; + } + + [McpServerPromptType] + private class ScopedPrompts + { + [McpServerPrompt, Description("A public prompt.")] + public static string PublicPrompt(string message) => $"Public prompt: {message}"; + + [McpServerPrompt, Description("A prompt that requires the read_data scope.")] + [Authorize(Roles = "read_data")] + public static string ScopedPrompt(string message) => $"Scoped prompt: {message}"; + } + + [McpServerResourceType] + private class ScopedResources + { + [McpServerResource(UriTemplate = "resource://public"), Description("A public resource.")] + public static string PublicResource() => "Public resource content"; + + [McpServerResource(UriTemplate = "resource://scoped"), Description("A scoped resource.")] + [Authorize(Roles = "read_data")] + public static string ScopedResource() => "Scoped resource content"; + + [McpServerResource(UriTemplate = "resource://scoped/{id}"), Description("A scoped resource template.")] + [Authorize(Roles = "read_data")] + public static string ScopedResourceTemplate(string id) => $"Scoped resource content: {id}"; + } +}