Skip to content

Commit 27a3196

Browse files
Add ISessionMigrationHandler (#1270)
1 parent 411cef6 commit 27a3196

File tree

9 files changed

+589
-26
lines changed

9 files changed

+589
-26
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
using Microsoft.AspNetCore.Http;
2+
using ModelContextProtocol.Protocol;
3+
4+
namespace ModelContextProtocol.AspNetCore;
5+
6+
/// <summary>
7+
/// Provides hooks for persisting and restoring MCP session initialization data,
8+
/// enabling session migration across server instances.
9+
/// </summary>
10+
/// <remarks>
11+
/// <para>
12+
/// When an MCP server is horizontally scaled, stateful sessions are bound to a single process.
13+
/// If that process restarts or scales down, the session is lost. By implementing this interface
14+
/// and registering it with DI, you can persist the initialization handshake data and restore it
15+
/// when a client reconnects to a different server instance with its existing <c>Mcp-Session-Id</c>.
16+
/// </para>
17+
/// <para>
18+
/// This does <strong>not</strong> solve the session-affinity problem for in-flight server-to-client
19+
/// requests (such as sampling or elicitation). Responses to those requests must still be routed to
20+
/// the process that created the request. This interface only enables migration of idle sessions
21+
/// by persisting the data established during the initialization handshake.
22+
/// </para>
23+
/// </remarks>
24+
public interface ISessionMigrationHandler
25+
{
26+
/// <summary>
27+
/// Called after a session has been successfully initialized via the MCP initialization handshake.
28+
/// </summary>
29+
/// <remarks>
30+
/// Use this to persist the <paramref name="initializeParams"/> (which includes client capabilities,
31+
/// client info, and protocol version) to an external store so the session can be migrated to
32+
/// another server instance later via <see cref="AllowSessionMigrationAsync"/>.
33+
/// </remarks>
34+
/// <param name="context">The <see cref="HttpContext"/> for the initialization request.</param>
35+
/// <param name="sessionId">The unique identifier for the session.</param>
36+
/// <param name="initializeParams">The initialization parameters sent by the client during the handshake.</param>
37+
/// <param name="cancellationToken">A cancellation token.</param>
38+
/// <returns>A <see cref="ValueTask"/> representing the asynchronous operation.</returns>
39+
ValueTask OnSessionInitializedAsync(HttpContext context, string sessionId, InitializeRequestParams initializeParams, CancellationToken cancellationToken);
40+
41+
/// <summary>
42+
/// Called when a request arrives with an <c>Mcp-Session-Id</c> that the current server doesn't recognize.
43+
/// </summary>
44+
/// <remarks>
45+
/// <para>
46+
/// Return the original <see cref="InitializeRequestParams"/> to allow the session to be migrated
47+
/// to this server instance, or <see langword="null"/> to reject the request (returning a 404 to the client).
48+
/// </para>
49+
/// <para>
50+
/// Implementations should validate that the request is authorized, for example by checking
51+
/// <see cref="HttpContext.User"/>, to ensure the caller is permitted to migrate the session.
52+
/// </para>
53+
/// </remarks>
54+
/// <param name="context">The <see cref="HttpContext"/> for the request with the unrecognized session ID.</param>
55+
/// <param name="sessionId">The session ID from the request that was not found on this server.</param>
56+
/// <param name="cancellationToken">A cancellation token.</param>
57+
/// <returns>
58+
/// The original <see cref="InitializeRequestParams"/> if migration is allowed,
59+
/// or <see langword="null"/> to reject the request.
60+
/// </returns>
61+
ValueTask<InitializeRequestParams?> AllowSessionMigrationAsync(HttpContext context, string sessionId, CancellationToken cancellationToken);
62+
}

src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs

Lines changed: 111 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using Microsoft.Net.Http.Headers;
88
using ModelContextProtocol.Protocol;
99
using ModelContextProtocol.Server;
10+
using System.Collections.Concurrent;
1011
using System.Security.Claims;
1112
using System.Security.Cryptography;
1213
using System.Text.Json.Serialization.Metadata;
@@ -20,7 +21,8 @@ internal sealed class StreamableHttpHandler(
2021
StatefulSessionManager sessionManager,
2122
IHostApplicationLifetime hostApplicationLifetime,
2223
IServiceProvider applicationServices,
23-
ILoggerFactory loggerFactory)
24+
ILoggerFactory loggerFactory,
25+
ISessionMigrationHandler? sessionMigrationHandler = null)
2426
{
2527
private const string McpSessionIdHeaderName = "Mcp-Session-Id";
2628
private const string McpProtocolVersionHeaderName = "MCP-Protocol-Version";
@@ -41,6 +43,11 @@ internal sealed class StreamableHttpHandler(
4143
private static readonly JsonTypeInfo<JsonRpcMessage> s_messageTypeInfo = GetRequiredJsonTypeInfo<JsonRpcMessage>();
4244
private static readonly JsonTypeInfo<JsonRpcError> s_errorTypeInfo = GetRequiredJsonTypeInfo<JsonRpcError>();
4345

46+
private static bool AllowNewSessionForNonInitializeRequests { get; } =
47+
AppContext.TryGetSwitch("ModelContextProtocol.AspNetCore.AllowNewSessionForNonInitializeRequests", out var enabled) && enabled;
48+
49+
private readonly ConcurrentDictionary<string, SemaphoreSlim> _migrationLocks = new(StringComparer.Ordinal);
50+
4451
public HttpServerTransportOptions HttpServerTransportOptions => httpServerTransportOptions.Value;
4552

4653
public async Task HandlePostRequestAsync(HttpContext context)
@@ -64,14 +71,6 @@ await WriteJsonRpcErrorAsync(context,
6471
return;
6572
}
6673

67-
var session = await GetOrCreateSessionAsync(context);
68-
if (session is null)
69-
{
70-
return;
71-
}
72-
73-
await using var _ = await session.AcquireReferenceAsync(context.RequestAborted);
74-
7574
var message = await ReadJsonRpcMessageAsync(context);
7675
if (message is null)
7776
{
@@ -81,6 +80,14 @@ await WriteJsonRpcErrorAsync(context,
8180
return;
8281
}
8382

83+
var session = await GetOrCreateSessionAsync(context, message);
84+
if (session is null)
85+
{
86+
return;
87+
}
88+
89+
await using var _ = await session.AcquireReferenceAsync(context.RequestAborted);
90+
8491
InitializeSseResponse(context);
8592
var wroteResponse = await session.Transport.HandlePostRequestAsync(message, context.Response.Body, context.RequestAborted);
8693
if (!wroteResponse)
@@ -219,12 +226,18 @@ public async Task HandleDeleteRequestAsync(HttpContext context)
219226

220227
if (!sessionManager.TryGetValue(sessionId, out var session))
221228
{
222-
// -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does.
223-
// One of the few other usages I found was from some Ethereum JSON-RPC documentation and this
224-
// JSON-RPC library from Microsoft called StreamJsonRpc where it's called JsonRpcErrorCode.NoMarshaledObjectFound
225-
// https://learn.microsoft.com/dotnet/api/streamjsonrpc.protocol.jsonrpcerrorcode?view=streamjsonrpc-2.9#fields
226-
await WriteJsonRpcErrorAsync(context, "Session not found", StatusCodes.Status404NotFound, -32001);
227-
return null;
229+
// Session not found locally. Attempt migration if a handler is registered.
230+
session = await TryMigrateSessionAsync(context, sessionId);
231+
232+
if (session is null)
233+
{
234+
// -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does.
235+
// One of the few other usages I found was from some Ethereum JSON-RPC documentation and this
236+
// JSON-RPC library from Microsoft called StreamJsonRpc where it's called JsonRpcErrorCode.NoMarshaledObjectFound
237+
// https://learn.microsoft.com/dotnet/api/streamjsonrpc.protocol.jsonrpcerrorcode?view=streamjsonrpc-2.9#fields
238+
await WriteJsonRpcErrorAsync(context, "Session not found", StatusCodes.Status404NotFound, -32001);
239+
return null;
240+
}
228241
}
229242

230243
if (!session.HasSameUserId(context.User))
@@ -240,12 +253,61 @@ await WriteJsonRpcErrorAsync(context,
240253
return session;
241254
}
242255

243-
private async ValueTask<StreamableHttpSession?> GetOrCreateSessionAsync(HttpContext context)
256+
private async ValueTask<StreamableHttpSession?> TryMigrateSessionAsync(HttpContext context, string sessionId)
257+
{
258+
if (sessionMigrationHandler is not { } handler)
259+
{
260+
return null;
261+
}
262+
263+
var migrationLock = _migrationLocks.GetOrAdd(sessionId, static _ => new SemaphoreSlim(1, 1));
264+
await migrationLock.WaitAsync(context.RequestAborted);
265+
try
266+
{
267+
// Re-check after acquiring the lock - another thread may have already completed migration.
268+
if (sessionManager.TryGetValue(sessionId, out var session))
269+
{
270+
return session;
271+
}
272+
273+
var initParams = await handler.AllowSessionMigrationAsync(context, sessionId, context.RequestAborted);
274+
if (initParams is null)
275+
{
276+
return null;
277+
}
278+
279+
var migratedSession = await MigrateSessionAsync(context, sessionId, initParams);
280+
281+
// Register the session with the session manager while still holding the lock
282+
// so concurrent requests for the same session ID find it via sessionManager.TryGetValue.
283+
await migratedSession.EnsureStartedAsync(context.RequestAborted);
284+
285+
return migratedSession;
286+
}
287+
finally
288+
{
289+
migrationLock.Release();
290+
_migrationLocks.TryRemove(sessionId, out _);
291+
}
292+
}
293+
294+
private async ValueTask<StreamableHttpSession?> GetOrCreateSessionAsync(HttpContext context, JsonRpcMessage message)
244295
{
245296
var sessionId = context.Request.Headers[McpSessionIdHeaderName].ToString();
246297

247298
if (string.IsNullOrEmpty(sessionId))
248299
{
300+
// In stateful mode, only allow creating new sessions for initialize requests.
301+
// In stateless mode, every request is independent, so we always create a new session.
302+
if (!HttpServerTransportOptions.Stateless && !AllowNewSessionForNonInitializeRequests
303+
&& message is not JsonRpcRequest { Method: RequestMethods.Initialize })
304+
{
305+
await WriteJsonRpcErrorAsync(context,
306+
"Bad Request: A new session can only be created by an initialize request. Include a valid Mcp-Session-Id header for non-initialize requests.",
307+
StatusCodes.Status400BadRequest);
308+
return null;
309+
}
310+
249311
return await StartNewSessionAsync(context);
250312
}
251313
else if (HttpServerTransportOptions.Stateless)
@@ -274,7 +336,11 @@ private async ValueTask<StreamableHttpSession> StartNewSessionAsync(HttpContext
274336
SessionId = sessionId,
275337
FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext,
276338
EventStreamStore = HttpServerTransportOptions.EventStreamStore,
339+
OnSessionInitialized = sessionMigrationHandler is { } handler
340+
? (initParams, ct) => handler.OnSessionInitializedAsync(context, sessionId, initParams, ct)
341+
: null,
277342
};
343+
278344
context.Response.Headers[McpSessionIdHeaderName] = sessionId;
279345
}
280346
else
@@ -295,11 +361,12 @@ private async ValueTask<StreamableHttpSession> StartNewSessionAsync(HttpContext
295361
private async ValueTask<StreamableHttpSession> CreateSessionAsync(
296362
HttpContext context,
297363
StreamableHttpServerTransport transport,
298-
string sessionId)
364+
string sessionId,
365+
Action<McpServerOptions>? configureOptions = null)
299366
{
300367
var mcpServerServices = applicationServices;
301368
var mcpServerOptions = mcpServerOptionsSnapshot.Value;
302-
if (HttpServerTransportOptions.Stateless || HttpServerTransportOptions.ConfigureSessionOptions is not null)
369+
if (HttpServerTransportOptions.Stateless || HttpServerTransportOptions.ConfigureSessionOptions is not null || configureOptions is not null)
303370
{
304371
mcpServerOptions = mcpServerOptionsFactory.Create(Options.DefaultName);
305372

@@ -310,6 +377,8 @@ private async ValueTask<StreamableHttpSession> CreateSessionAsync(
310377
mcpServerOptions.ScopeRequests = false;
311378
}
312379

380+
configureOptions?.Invoke(mcpServerOptions);
381+
313382
if (HttpServerTransportOptions.ConfigureSessionOptions is { } configureSessionOptions)
314383
{
315384
await configureSessionOptions(context, mcpServerOptions, context.RequestAborted);
@@ -328,6 +397,30 @@ private async ValueTask<StreamableHttpSession> CreateSessionAsync(
328397
return session;
329398
}
330399

400+
private async ValueTask<StreamableHttpSession> MigrateSessionAsync(
401+
HttpContext context,
402+
string sessionId,
403+
InitializeRequestParams initializeParams)
404+
{
405+
var transport = new StreamableHttpServerTransport(loggerFactory)
406+
{
407+
SessionId = sessionId,
408+
FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext,
409+
EventStreamStore = HttpServerTransportOptions.EventStreamStore,
410+
};
411+
412+
// Initialize the transport with the migrated session's init params.
413+
await transport.HandleInitializeRequestAsync(initializeParams);
414+
415+
context.Response.Headers[McpSessionIdHeaderName] = sessionId;
416+
417+
return await CreateSessionAsync(context, transport, sessionId, options =>
418+
{
419+
options.KnownClientInfo = initializeParams.ClientInfo;
420+
options.KnownClientCapabilities = initializeParams.Capabilities;
421+
});
422+
}
423+
331424
private async ValueTask<ISseEventStreamReader?> GetEventStreamReaderAsync(HttpContext context, string lastEventId)
332425
{
333426
if (HttpServerTransportOptions.EventStreamStore is not { } eventStreamStore)

src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,31 @@ public async ValueTask<IAsyncDisposable> AcquireReferenceAsync(CancellationToken
7474
return new UnreferenceDisposable(this);
7575
}
7676

77+
/// <summary>
78+
/// Ensures the session is registered with the session manager without acquiring a reference.
79+
/// No-ops if the session is already started.
80+
/// </summary>
81+
public async ValueTask EnsureStartedAsync(CancellationToken cancellationToken)
82+
{
83+
bool needsStart;
84+
lock (_stateLock)
85+
{
86+
needsStart = _state == SessionState.Uninitialized;
87+
if (needsStart)
88+
{
89+
_state = SessionState.Started;
90+
}
91+
}
92+
93+
if (needsStart)
94+
{
95+
await sessionManager.StartNewSessionAsync(this, cancellationToken);
96+
97+
// Session is registered with 0 references (idle), so reflect that in the idle count.
98+
sessionManager.IncrementIdleSessionCount();
99+
}
100+
}
101+
77102
public bool TryStartGetRequest() => Interlocked.Exchange(ref _getRequestStarted, 1) == 0;
78103
public bool HasSameUserId(ClaimsPrincipal user) => userId == StreamableHttpHandler.GetUserIdClaim(user);
79104

src/ModelContextProtocol.Core/Server/McpServerImpl.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ public McpServerImpl(ITransport transport, McpServerOptions options, ILoggerFact
7575
}
7676

7777
_clientInfo = options.KnownClientInfo;
78+
_clientCapabilities = options.KnownClientCapabilities;
7879
UpdateEndpointNameWithClientInfo();
7980

8081
_notificationHandlers = new();

src/ModelContextProtocol.Core/Server/McpServerOptions.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ public sealed class McpServerOptions
8282
/// </remarks>
8383
public Implementation? KnownClientInfo { get; set; }
8484

85+
/// <summary>
86+
/// Gets or sets preexisting knowledge about the client's capabilities to support session migration
87+
/// scenarios where the client will not re-send the initialize request.
88+
/// </summary>
89+
/// <remarks>
90+
/// <para>
91+
/// When not specified, this information is sourced from the client's initialize request.
92+
/// This is typically set during session migration in conjunction with <see cref="KnownClientInfo"/>.
93+
/// </para>
94+
/// </remarks>
95+
public ClientCapabilities? KnownClientCapabilities { get; set; }
96+
8597
/// <summary>
8698
/// Gets the filter collections for MCP server handlers.
8799
/// </summary>

src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public async ValueTask<bool> HandlePostAsync(JsonRpcMessage message, Cancellatio
5252
if (request.Method == RequestMethods.Initialize)
5353
{
5454
var initializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams);
55-
await parentTransport.HandleInitRequestAsync(initializeRequest).ConfigureAwait(false);
55+
await parentTransport.HandleInitializeRequestAsync(initializeRequest).ConfigureAwait(false);
5656
}
5757
}
5858

0 commit comments

Comments
 (0)