diff --git a/src/ModelContextProtocol.AspNetCore/DistributedCacheEventStreamStoreOptionsSetup.cs b/src/ModelContextProtocol.AspNetCore/DistributedCacheEventStreamStoreOptionsSetup.cs new file mode 100644 index 000000000..9433eea7e --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/DistributedCacheEventStreamStoreOptionsSetup.cs @@ -0,0 +1,17 @@ +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Configures by resolving +/// the from DI when not explicitly set. +/// +internal sealed class DistributedCacheEventStreamStoreOptionsSetup(IDistributedCache? cache = null) : IConfigureOptions +{ + public void Configure(DistributedCacheEventStreamStoreOptions options) + { + options.Cache ??= cache; + } +} diff --git a/src/ModelContextProtocol.AspNetCore/DistributedCacheEventStreamStoreOptionsValidator.cs b/src/ModelContextProtocol.AspNetCore/DistributedCacheEventStreamStoreOptionsValidator.cs new file mode 100644 index 000000000..1b4786163 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/DistributedCacheEventStreamStoreOptionsValidator.cs @@ -0,0 +1,25 @@ +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Validates that is set. +/// +internal sealed class DistributedCacheEventStreamStoreOptionsValidator : IValidateOptions +{ + public ValidateOptionsResult Validate(string? name, DistributedCacheEventStreamStoreOptions options) + { + if (options.Cache is null) + { + return ValidateOptionsResult.Fail( + $"The '{nameof(DistributedCacheEventStreamStoreOptions)}.{nameof(DistributedCacheEventStreamStoreOptions.Cache)}' property must be set. " + + $"Register an {nameof(IDistributedCache)} in DI or set the {nameof(DistributedCacheEventStreamStoreOptions.Cache)} property " + + $"in the '{nameof(HttpMcpServerBuilderExtensions.WithDistributedCacheEventStreamStore)}' configure callback."); + } + + return ValidateOptionsResult.Success; + } +} diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs index 3f8043808..556aac1e1 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs @@ -1,4 +1,5 @@ using Microsoft.AspNetCore.Authorization; +using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Options; using ModelContextProtocol.AspNetCore; @@ -33,6 +34,7 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder builder.Services.AddHostedService(); builder.Services.TryAddEnumerable(ServiceDescriptor.Transient, AuthorizationFilterSetup>()); + builder.Services.TryAddEnumerable(ServiceDescriptor.Transient, HttpServerTransportOptionsSetup>()); if (configureOptions is not null) { @@ -64,4 +66,37 @@ public static IMcpServerBuilder AddAuthorizationFilters(this IMcpServerBuilder b return builder; } + + /// + /// Registers a as the for SSE resumability. + /// + /// The builder instance. + /// An optional action to configure . + /// The builder provided in . + /// is . + /// + /// + /// An implementation must be registered in the service collection before calling this method. + /// The registered cache is automatically assigned to . + /// + /// + /// To use a specific instance instead of the one registered in DI, + /// set the property in the callback. + /// + /// + public static IMcpServerBuilder WithDistributedCacheEventStreamStore(this IMcpServerBuilder builder, Action? configureOptions = null) + { + ArgumentNullException.ThrowIfNull(builder); + + builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton, DistributedCacheEventStreamStoreOptionsSetup>()); + builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton, DistributedCacheEventStreamStoreOptionsValidator>()); + builder.Services.AddSingleton(); + + if (configureOptions is not null) + { + builder.Services.Configure(configureOptions); + } + + return builder; + } } diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs index b1391e714..ce57af4b6 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -55,9 +55,29 @@ public class HttpServerTransportOptions /// Replay missed events when a client reconnects with a Last-Event-ID header /// Send priming events to establish resumability before any actual messages /// + /// + /// This can be set directly, or an can be registered in DI. + /// If this property is not set, the server will attempt to resolve an from DI. + /// /// public ISseEventStreamStore? EventStreamStore { get; set; } + /// + /// Gets or sets the session migration handler for cross-instance session migration. + /// + /// + /// + /// When configured, the server will support session migration between instances. + /// If a request arrives with a session ID that is not found locally, the handler + /// is consulted to determine if the session can be migrated from another instance. + /// + /// + /// This can be set directly, or an can be registered in DI. + /// If this property is not set, the server will attempt to resolve an from DI. + /// + /// + public ISessionMigrationHandler? SessionMigrationHandler { get; set; } + /// /// Gets or sets a value that indicates whether the server uses a single execution context for the entire session. /// diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptionsSetup.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptionsSetup.cs new file mode 100644 index 000000000..b4ce545f8 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptionsSetup.cs @@ -0,0 +1,18 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Post-configures by resolving services from DI +/// when they haven't been explicitly set on the options. +/// +internal sealed class HttpServerTransportOptionsSetup(IServiceProvider serviceProvider) : IConfigureOptions +{ + public void Configure(HttpServerTransportOptions options) + { + options.EventStreamStore ??= serviceProvider.GetService(); + options.SessionMigrationHandler ??= serviceProvider.GetService(); + } +} diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index b3a51957b..b02333ae8 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -21,8 +21,7 @@ internal sealed class StreamableHttpHandler( StatefulSessionManager sessionManager, IHostApplicationLifetime hostApplicationLifetime, IServiceProvider applicationServices, - ILoggerFactory loggerFactory, - ISessionMigrationHandler? sessionMigrationHandler = null) + ILoggerFactory loggerFactory) { private const string McpSessionIdHeaderName = "Mcp-Session-Id"; private const string McpProtocolVersionHeaderName = "MCP-Protocol-Version"; @@ -255,7 +254,7 @@ await WriteJsonRpcErrorAsync(context, private async ValueTask TryMigrateSessionAsync(HttpContext context, string sessionId) { - if (sessionMigrationHandler is not { } handler) + if (HttpServerTransportOptions.SessionMigrationHandler is not { } handler) { return null; } @@ -336,7 +335,7 @@ private async ValueTask StartNewSessionAsync(HttpContext SessionId = sessionId, FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext, EventStreamStore = HttpServerTransportOptions.EventStreamStore, - OnSessionInitialized = sessionMigrationHandler is { } handler + OnSessionInitialized = HttpServerTransportOptions.SessionMigrationHandler is { } handler ? (initParams, ct) => handler.OnSessionInitializedAsync(context, sessionId, initParams, ct) : null, }; diff --git a/src/ModelContextProtocol/McpServerOptionsSetup.cs b/src/ModelContextProtocol/McpServerOptionsSetup.cs index 2042ead53..5977fae7e 100644 --- a/src/ModelContextProtocol/McpServerOptionsSetup.cs +++ b/src/ModelContextProtocol/McpServerOptionsSetup.cs @@ -4,15 +4,17 @@ namespace ModelContextProtocol; /// -/// Configures the McpServerOptions using addition services from DI. +/// Configures the McpServerOptions using additional services from DI. /// /// The individually registered tools. /// The individually registered prompts. /// The individually registered resources. +/// The optional task store registered in DI. internal sealed class McpServerOptionsSetup( IEnumerable serverTools, IEnumerable serverPrompts, - IEnumerable serverResources) : IConfigureOptions + IEnumerable serverResources, + IMcpTaskStore? taskStore = null) : IConfigureOptions { /// /// Configures the given McpServerOptions instance by setting server information @@ -23,6 +25,8 @@ public void Configure(McpServerOptions options) { Throw.IfNull(options); + options.TaskStore ??= taskStore; + // Collect all of the provided tools into a tools collection. If the options already has // a collection, add to it, otherwise create a new one. We want to maintain the identity // of an existing collection in case someone has provided their own derived type, wants diff --git a/src/ModelContextProtocol/McpServerServiceCollectionExtensions.cs b/src/ModelContextProtocol/McpServerServiceCollectionExtensions.cs index 8ead2ce08..7cc893ba1 100644 --- a/src/ModelContextProtocol/McpServerServiceCollectionExtensions.cs +++ b/src/ModelContextProtocol/McpServerServiceCollectionExtensions.cs @@ -26,17 +26,6 @@ public static IMcpServerBuilder AddMcpServer(this IServiceCollection services, A services.Configure(configureOptions); } - // Register IMcpTaskStore from options if not already registered. - // This allows users to either: - // 1. Register IMcpTaskStore directly in DI (takes precedence) - // 2. Set options.TaskStore in the configuration callback (used as fallback) - // If neither is done, resolving IMcpTaskStore will throw. - services.TryAddSingleton(sp => - { - var options = sp.GetRequiredService>().Value; - return options.TaskStore ?? throw new InvalidOperationException("No IMcpTaskStore has been configured. Either register an IMcpTaskStore in the service collection or set McpServerOptions.TaskStore when configuring the MCP server."); - }); - return new DefaultMcpServerBuilder(services); } } diff --git a/src/ModelContextProtocol/Server/DistributedCacheEventStreamStore.cs b/src/ModelContextProtocol/Server/DistributedCacheEventStreamStore.cs index 4a71b9448..d0a315666 100644 --- a/src/ModelContextProtocol/Server/DistributedCacheEventStreamStore.cs +++ b/src/ModelContextProtocol/Server/DistributedCacheEventStreamStore.cs @@ -1,6 +1,7 @@ using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; using ModelContextProtocol.Protocol; using System.Net.ServerSentEvents; using System.Runtime.CompilerServices; @@ -31,14 +32,16 @@ public sealed partial class DistributedCacheEventStreamStore : ISseEventStreamSt /// /// Initializes a new instance of the class. /// - /// The distributed cache to use for storage. - /// Optional configuration options for the store. + /// Configuration options for the store, including the to use. /// Optional logger for diagnostic output. - public DistributedCacheEventStreamStore(IDistributedCache cache, DistributedCacheEventStreamStoreOptions? options = null, ILogger? logger = null) + public DistributedCacheEventStreamStore(IOptions options, ILogger? logger = null) { - Throw.IfNull(cache); - _cache = cache; - _options = options ?? new(); + Throw.IfNull(options); + + var optionsValue = options.Value; + _cache = optionsValue.Cache ?? throw new InvalidOperationException( + $"The '{nameof(DistributedCacheEventStreamStoreOptions)}.{nameof(DistributedCacheEventStreamStoreOptions.Cache)}' property must be set."); + _options = optionsValue; _logger = logger ?? NullLogger.Instance; } diff --git a/src/ModelContextProtocol/Server/DistributedCacheEventStreamStoreOptions.cs b/src/ModelContextProtocol/Server/DistributedCacheEventStreamStoreOptions.cs index e1542ca62..f434e12c3 100644 --- a/src/ModelContextProtocol/Server/DistributedCacheEventStreamStoreOptions.cs +++ b/src/ModelContextProtocol/Server/DistributedCacheEventStreamStoreOptions.cs @@ -1,3 +1,5 @@ +using Microsoft.Extensions.Caching.Distributed; + namespace ModelContextProtocol.Server; /// @@ -5,6 +7,16 @@ namespace ModelContextProtocol.Server; /// public sealed class DistributedCacheEventStreamStoreOptions { + /// + /// Gets or sets the to use for event storage. + /// + /// + /// When using dependency injection with WithDistributedCacheEventStreamStore(), this is + /// automatically populated from the registered in DI. + /// Set this property explicitly to use a specific cache instance. + /// + public IDistributedCache? Cache { get; set; } + /// /// Gets or sets the sliding expiration for individual events in the cache. /// diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/DistributedCacheResumabilityIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/DistributedCacheResumabilityIntegrationTests.cs index 8ec31ea73..66e06c0f7 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/DistributedCacheResumabilityIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/DistributedCacheResumabilityIntegrationTests.cs @@ -22,17 +22,18 @@ namespace ModelContextProtocol.AspNetCore.Tests; /// public class DistributedCacheResumabilityIntegrationTests(ITestOutputHelper testOutputHelper) : ResumabilityIntegrationTestsBase(testOutputHelper) { - private MemoryDistributedCache? _cache; - /// protected override ValueTask CreateEventStreamStoreAsync() { // Create a new in-memory distributed cache for each test - _cache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + var cache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); // Configure the store with shorter expiration times suitable for testing var options = new DistributedCacheEventStreamStoreOptions { + // Use the in-memory distributed cache + Cache = cache, + // Use shorter polling interval for faster test execution StreamReaderPollingInterval = TimeSpan.FromMilliseconds(50), @@ -43,7 +44,7 @@ protected override ValueTask CreateEventStreamStoreAsync() MetadataAbsoluteExpiration = TimeSpan.FromMinutes(10), }; - var store = new DistributedCacheEventStreamStore(_cache, options, LoggerFactory.CreateLogger()); + var store = new DistributedCacheEventStreamStore(Options.Create(options), LoggerFactory.CreateLogger()); return new ValueTask(store); } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpMcpServerBuilderExtensionsTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpMcpServerBuilderExtensionsTests.cs new file mode 100644 index 000000000..cc6ff0b13 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpMcpServerBuilderExtensionsTests.cs @@ -0,0 +1,197 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class HttpMcpServerBuilderExtensionsTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) +{ + [Fact] + public void WithDistributedCacheEventStreamStore_RegistersStoreInDI() + { + Builder.Services.AddDistributedMemoryCache(); + Builder.Services + .AddMcpServer() + .WithHttpTransport() + .WithDistributedCacheEventStreamStore(); + + using var app = Builder.Build(); + + var store = app.Services.GetService(); + Assert.IsType(store); + } + + [Fact] + public void WithDistributedCacheEventStreamStore_ConfigureCallbackIsInvoked() + { + DistributedCacheEventStreamStoreOptions? capturedOptions = null; + + Builder.Services.AddDistributedMemoryCache(); + Builder.Services + .AddMcpServer() + .WithHttpTransport() + .WithDistributedCacheEventStreamStore(options => capturedOptions = options); + + using var app = Builder.Build(); + + // Force options resolution to trigger the configure callback. + _ = app.Services.GetRequiredService>().Value; + + Assert.NotNull(capturedOptions); + } + + [Fact] + public void WithDistributedCacheEventStreamStore_WorksWithoutDICache_WhenCacheSetViaCallback() + { + var explicitCache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + + Builder.Services + .AddMcpServer() + .WithHttpTransport() + .WithDistributedCacheEventStreamStore(options => options.Cache = explicitCache); + + using var app = Builder.Build(); + + var store = app.Services.GetService(); + Assert.IsType(store); + } + + [Fact] + public void WithDistributedCacheEventStreamStore_ThrowsOptionsValidationException_WhenNoCacheConfigured() + { + Builder.Services + .AddMcpServer() + .WithHttpTransport() + .WithDistributedCacheEventStreamStore(); + + using var app = Builder.Build(); + + var ex = Assert.Throws( + () => app.Services.GetRequiredService()); + Assert.StartsWith($"The '{nameof(DistributedCacheEventStreamStoreOptions)}.{nameof(DistributedCacheEventStreamStoreOptions.Cache)}'", ex.Message); + } + + [Fact] + public void EventStreamStore_IsPopulatedFromDI_ViaPostConfigure() + { + Builder.Services.AddDistributedMemoryCache(); + Builder.Services + .AddMcpServer() + .WithHttpTransport() + .WithDistributedCacheEventStreamStore(); + + using var app = Builder.Build(); + + var options = app.Services.GetRequiredService>().Value; + Assert.IsType(options.EventStreamStore); + } + + [Fact] + public void EventStreamStore_ExplicitOption_TakesPrecedenceOverDI() + { + var explicitStore = new TestSseEventStreamStore(); + + Builder.Services.AddDistributedMemoryCache(); + Builder.Services + .AddMcpServer() + .WithHttpTransport(options => options.EventStreamStore = explicitStore) + .WithDistributedCacheEventStreamStore(); + + using var app = Builder.Build(); + + var options = app.Services.GetRequiredService>().Value; + Assert.Same(explicitStore, options.EventStreamStore); + } + + [Fact] + public void EventStreamStore_RemainsNull_WhenNothingIsRegistered() + { + Builder.Services + .AddMcpServer() + .WithHttpTransport(); + + using var app = Builder.Build(); + + var options = app.Services.GetRequiredService>().Value; + Assert.Null(options.EventStreamStore); + } + + [Fact] + public void EventStreamStore_CanBeOverriddenToNull_AfterDIRegistration() + { + Builder.Services.AddDistributedMemoryCache(); + Builder.Services + .AddMcpServer() + .WithHttpTransport() + .WithDistributedCacheEventStreamStore(); + + Builder.Services.Configure(options => options.EventStreamStore = null); + + using var app = Builder.Build(); + + var options = app.Services.GetRequiredService>().Value; + Assert.Null(options.EventStreamStore); + } + + [Fact] + public void SessionMigrationHandler_IsPopulatedFromDI_ViaPostConfigure() + { + var handler = new StubSessionMigrationHandler(); + + Builder.Services.AddSingleton(handler); + Builder.Services + .AddMcpServer() + .WithHttpTransport(); + + using var app = Builder.Build(); + + var options = app.Services.GetRequiredService>().Value; + Assert.Same(handler, options.SessionMigrationHandler); + } + + [Fact] + public void SessionMigrationHandler_ExplicitOption_TakesPrecedenceOverDI() + { + var diHandler = new StubSessionMigrationHandler(); + var explicitHandler = new StubSessionMigrationHandler(); + + Builder.Services.AddSingleton(diHandler); + Builder.Services + .AddMcpServer() + .WithHttpTransport(options => options.SessionMigrationHandler = explicitHandler); + + using var app = Builder.Build(); + + var options = app.Services.GetRequiredService>().Value; + Assert.Same(explicitHandler, options.SessionMigrationHandler); + } + + [Fact] + public void SessionMigrationHandler_RemainsNull_WhenNothingIsRegistered() + { + Builder.Services + .AddMcpServer() + .WithHttpTransport(); + + using var app = Builder.Build(); + + var options = app.Services.GetRequiredService>().Value; + Assert.Null(options.SessionMigrationHandler); + } + + private sealed class StubSessionMigrationHandler : ISessionMigrationHandler + { + public ValueTask AllowSessionMigrationAsync( + HttpContext context, string sessionId, CancellationToken cancellationToken = default) + => new((InitializeRequestParams?)null); + + public ValueTask OnSessionInitializedAsync( + HttpContext context, string sessionId, InitializeRequestParams initializeParams, CancellationToken cancellationToken = default) + => default; + } +} diff --git a/tests/ModelContextProtocol.ConformanceServer/Program.cs b/tests/ModelContextProtocol.ConformanceServer/Program.cs index f216a34a2..94c62727b 100644 --- a/tests/ModelContextProtocol.ConformanceServer/Program.cs +++ b/tests/ModelContextProtocol.ConformanceServer/Program.cs @@ -1,9 +1,6 @@ using ConformanceServer.Prompts; using ConformanceServer.Resources; using ConformanceServer.Tools; -using Microsoft.Extensions.Caching.Distributed; -using Microsoft.Extensions.Caching.Memory; -using Microsoft.Extensions.Options; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using System.Collections.Concurrent; @@ -28,14 +25,11 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide // because .NET does not have a built-in concurrent HashSet ConcurrentDictionary> subscriptions = new(); + builder.Services.AddDistributedMemoryCache(); builder.Services .AddMcpServer() - .WithHttpTransport(options => - { - // Enable resumability for SSE polling conformance test - options.EventStreamStore = new DistributedCacheEventStreamStore( - new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions()))); - }) + .WithHttpTransport() + .WithDistributedCacheEventStreamStore() .WithTools() .WithTools([ConformanceTools.CreateJsonSchema202012Tool()]) .WithRequestFilters(filters => filters.AddCallToolFilter(next => async (request, cancellationToken) => diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerOptionsSetupTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerOptionsSetupTests.cs index 6f770bf5d..40165d58c 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerOptionsSetupTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerOptionsSetupTests.cs @@ -283,4 +283,57 @@ public void Configure_WithCompleteHandler_CreatesCompletionsCapability() Assert.NotNull(options.Capabilities?.Completions); } #endregion + + #region TaskStore Tests + [Fact] + public void TaskStore_IsPopulatedFromDI_WhenNotExplicitlySet() + { + var services = new ServiceCollection(); + services.AddMcpServer(); + services.AddSingleton(); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.IsType(options.TaskStore); + } + + [Fact] + public void TaskStore_ExplicitOption_TakesPrecedenceOverDI() + { + var explicitStore = new InMemoryMcpTaskStore(); + + var services = new ServiceCollection(); + services.AddMcpServer(options => options.TaskStore = explicitStore); + services.AddSingleton(); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.Same(explicitStore, options.TaskStore); + } + + [Fact] + public void TaskStore_RemainsNull_WhenNothingIsRegistered() + { + var services = new ServiceCollection(); + services.AddMcpServer(); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.Null(options.TaskStore); + } + + [Fact] + public void TaskStore_CanBeOverriddenToNull_AfterDIRegistration() + { + var services = new ServiceCollection(); + services.AddMcpServer(); + services.AddSingleton(); + + services.Configure(options => options.TaskStore = null); + + var options = services.BuildServiceProvider().GetRequiredService>().Value; + + Assert.Null(options.TaskStore); + } + #endregion } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs index 1ec3cac99..0983e6ad9 100644 --- a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs @@ -21,10 +21,25 @@ private static IDistributedCache CreateMemoryCache() return new MemoryDistributedCache(options); } + private static DistributedCacheEventStreamStore CreateStore(IDistributedCache? cache = null, DistributedCacheEventStreamStoreOptions? storeOptions = null) + { + storeOptions ??= new(); + storeOptions.Cache ??= cache ?? CreateMemoryCache(); + return new DistributedCacheEventStreamStore(Options.Create(storeOptions)); + } + + [Fact] + public void Constructor_ThrowsArgumentNullException_WhenOptionsIsNull() + { + Assert.Throws("options", () => new DistributedCacheEventStreamStore(null!)); + } + [Fact] - public void Constructor_ThrowsArgumentNullException_WhenCacheIsNull() + public void Constructor_ThrowsInvalidOperationException_WhenCacheIsNull() { - Assert.Throws("cache", () => new DistributedCacheEventStreamStore(null!)); + var options = Options.Create(new DistributedCacheEventStreamStoreOptions()); + var ex = Assert.Throws(() => new DistributedCacheEventStreamStore(options)); + Assert.StartsWith($"The '{nameof(DistributedCacheEventStreamStoreOptions)}.{nameof(DistributedCacheEventStreamStoreOptions.Cache)}'", ex.Message); } [Fact] @@ -32,7 +47,7 @@ public async Task CreateStreamAsync_ThrowsArgumentNullException_WhenOptionsIsNul { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); // Act & Assert await Assert.ThrowsAsync("options", @@ -44,7 +59,7 @@ public async Task WriteEventAsync_AssignsUniqueEventId_WhenItemHasNoEventId() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -67,7 +82,7 @@ public async Task WriteEventAsync_SkipsAssigningEventId_WhenItemAlreadyHasEventI { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -90,7 +105,7 @@ public async Task WriteEventAsync_PreservesDataProperty_InReturnedItem() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -113,7 +128,7 @@ public async Task WriteEventAsync_PreservesEventTypeProperty_InReturnedItem() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -135,7 +150,7 @@ public async Task WriteEventAsync_PreservesReconnectionIntervalProperty_InStored { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -187,7 +202,7 @@ public async Task WriteEventAsync_HandlesNullReconnectionInterval_InStoredEvent( { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -224,7 +239,7 @@ public async Task WriteEventAsync_HandlesNullData_AssignsEventIdAndStoresEvent() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -254,7 +269,7 @@ public async Task WriteEventAsync_StoresEventWithCorrectSlidingExpiration() { EventSlidingExpiration = TimeSpan.FromMinutes(15) }; - var store = new DistributedCacheEventStreamStore(mockCache, customOptions); + var store = CreateStore(mockCache, customOptions); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -282,7 +297,7 @@ public async Task WriteEventAsync_StoresEventWithCorrectAbsoluteExpiration() { EventAbsoluteExpiration = TimeSpan.FromHours(3) }; - var store = new DistributedCacheEventStreamStore(mockCache, customOptions); + var store = CreateStore(mockCache, customOptions); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -306,7 +321,7 @@ public async Task WriteEventAsync_UpdatesStreamMetadata_AfterEachWrite() { // Arrange var mockCache = new TestDistributedCache(); - var store = new DistributedCacheEventStreamStore(mockCache); + var store = CreateStore(mockCache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -328,7 +343,7 @@ public async Task SetModeAsync_PersistsModeChangeToMetadata() { // Arrange var mockCache = new TestDistributedCache(); - var store = new DistributedCacheEventStreamStore(mockCache); + var store = CreateStore(mockCache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -354,7 +369,7 @@ public async Task SetModeAsync_ModeChangeReflectedInReader() { StreamReaderPollingInterval = TimeSpan.FromMilliseconds(10) }; - var store = new DistributedCacheEventStreamStore(cache, customOptions); + var store = CreateStore(cache, customOptions); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -390,7 +405,7 @@ public async Task DisposeAsync_MarksStreamAsCompleted() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -425,7 +440,7 @@ public async Task DisposeAsync_IsIdempotent() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -447,7 +462,7 @@ public async Task DisposeAsync_UpdatesMetadata_WithIsCompletedFlag() { // Arrange var mockCache = new TestDistributedCache(); - var store = new DistributedCacheEventStreamStore(mockCache); + var store = CreateStore(mockCache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -469,7 +484,7 @@ public async Task GetStreamReaderAsync_ThrowsArgumentNullException_WhenLastEvent { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); // Act & Assert await Assert.ThrowsAsync("lastEventId", @@ -481,7 +496,7 @@ public async Task GetStreamReaderAsync_ReturnsNull_WhenEventIdIsUnparseable() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); // Act - Try various invalid event ID formats var result1 = await store.GetStreamReaderAsync("invalid-format", CancellationToken); @@ -499,7 +514,7 @@ public async Task GetStreamReaderAsync_ReturnsNull_WhenStreamMetadataDoesNotExis { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); // Create a valid-looking event ID for a stream that doesn't exist var fakeEventId = DistributedCacheEventIdFormatter.Format("nonexistent-session", "nonexistent-stream", 1); @@ -516,7 +531,7 @@ public async Task GetStreamReaderAsync_ReturnsReaderWithCorrectSessionIdAndStrea { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "my-session", @@ -542,7 +557,7 @@ public async Task ReadEventsAsync_ReturnsEventsInOrder() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -579,7 +594,7 @@ public async Task ReadEventsAsync_ReturnsEmpty_WhenNoNewEventsExist() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -610,7 +625,7 @@ public async Task ReadEventsAsync_PreservesCorrectDataEventTypeAndEventId() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -648,7 +663,7 @@ public async Task ReadEventsAsync_HandlesNullData() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -681,7 +696,7 @@ public async Task ReadEventsAsync_InPollingMode_CompletesImmediatelyAfterReturni { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -717,7 +732,7 @@ public async Task ReadEventsAsync_InPollingMode_ReturnsOnlyEventsAfterLastEventI { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -751,7 +766,7 @@ public async Task ReadEventsAsync_InPollingMode_ReturnsEmptyIfNoNewEvents() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -780,7 +795,7 @@ public async Task ReadEventsAsync_InPollingMode_DoesNotWaitForNewEvents() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -812,7 +827,7 @@ public async Task ReadEventsAsync_InStreamingMode_WaitsForNewEvents() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + var store = CreateStore(cache, new DistributedCacheEventStreamStoreOptions { StreamReaderPollingInterval = TimeSpan.FromMilliseconds(50) }); @@ -869,7 +884,7 @@ public async Task ReadEventsAsync_InStreamingMode_YieldsNewlyWrittenEvents() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + var store = CreateStore(cache, new DistributedCacheEventStreamStoreOptions { StreamReaderPollingInterval = TimeSpan.FromMilliseconds(50) }); @@ -927,7 +942,7 @@ public async Task ReadEventsAsync_InStreamingMode_CompletesWhenStreamIsDisposed( { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + var store = CreateStore(cache, new DistributedCacheEventStreamStoreOptions { StreamReaderPollingInterval = TimeSpan.FromMilliseconds(50) }); @@ -965,7 +980,7 @@ public async Task ReadEventsAsync_InStreamingMode_RespectsCancellation() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + var store = CreateStore(cache, new DistributedCacheEventStreamStoreOptions { StreamReaderPollingInterval = TimeSpan.FromMilliseconds(50) }); @@ -1029,7 +1044,7 @@ public async Task ReadEventsAsync_RespectsModeSwitchFromStreamingToPolling() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + var store = CreateStore(cache, new DistributedCacheEventStreamStoreOptions { StreamReaderPollingInterval = TimeSpan.FromMilliseconds(50) }); @@ -1074,7 +1089,7 @@ public async Task ReadEventsAsync_PollingModeReturnsEventsThenCompletes() { // Arrange - Start in default mode, write some events, switch to polling, reader should return remaining events var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + var store = CreateStore(cache, new DistributedCacheEventStreamStoreOptions { StreamReaderPollingInterval = TimeSpan.FromMilliseconds(50) }); @@ -1120,7 +1135,7 @@ public async Task MultipleStreams_AreIsolated_EventsDoNotLeakBetweenStreams() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); // Create two streams with different session/stream IDs var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions @@ -1178,7 +1193,7 @@ public async Task MultipleStreams_SameSession_DifferentStreamIds_AreIsolated() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); // Create two streams with same session but different stream IDs var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions @@ -1231,7 +1246,7 @@ public async Task EventIds_AreGloballyUnique_AcrossStreams() { // Arrange var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); + var store = CreateStore(cache); var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions { @@ -1267,7 +1282,7 @@ public async Task WriteEventAsync_UsesConfiguredSlidingExpiration() { EventSlidingExpiration = TimeSpan.FromMinutes(30) }; - var store = new DistributedCacheEventStreamStore(mockCache, customOptions); + var store = CreateStore(mockCache, customOptions); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -1295,7 +1310,7 @@ public async Task WriteEventAsync_UsesConfiguredAbsoluteExpiration() { EventAbsoluteExpiration = TimeSpan.FromHours(6) }; - var store = new DistributedCacheEventStreamStore(mockCache, customOptions); + var store = CreateStore(mockCache, customOptions); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -1325,7 +1340,7 @@ public async Task WriteEventAsync_UsesConfiguredMetadataExpiration() MetadataSlidingExpiration = TimeSpan.FromMinutes(45), MetadataAbsoluteExpiration = TimeSpan.FromHours(12) }; - var store = new DistributedCacheEventStreamStore(mockCache, customOptions); + var store = CreateStore(mockCache, customOptions); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", @@ -1366,7 +1381,7 @@ public async Task ReadEventsAsync_ThrowsMcpException_WhenMetadataExpires() { StreamReaderPollingInterval = TimeSpan.FromMilliseconds(10) // Fast polling to detect the bug quickly }; - var store = new DistributedCacheEventStreamStore(trackingCache, customOptions); + var store = CreateStore(trackingCache, customOptions); // Create a stream and write an event var writer = await store.CreateStreamAsync(new SseEventStreamOptions @@ -1405,7 +1420,7 @@ public async Task ReadEventsAsync_ThrowsMcpException_WhenEventExpires() { // Arrange - Use a cache that allows us to simulate event expiration var trackingCache = new TestDistributedCache(); - var store = new DistributedCacheEventStreamStore(trackingCache); + var store = CreateStore(trackingCache); // Create a stream and write multiple events var writer = await store.CreateStreamAsync(new SseEventStreamOptions @@ -1450,7 +1465,7 @@ public async Task ReadEventsAsync_DoesNotReadMetadata_InPollingMode() { StreamReaderPollingInterval = TimeSpan.FromMilliseconds(10) }; - var store = new DistributedCacheEventStreamStore(trackingCache, customOptions); + var store = CreateStore(trackingCache, customOptions); // Create a stream in POLLING mode - this allows the reader to exit after reading available events var writer = await store.CreateStreamAsync(new SseEventStreamOptions