77using Microsoft . Net . Http . Headers ;
88using ModelContextProtocol . Protocol ;
99using ModelContextProtocol . Server ;
10+ using System . Collections . Concurrent ;
1011using System . Security . Claims ;
1112using System . Security . Cryptography ;
1213using System . Text . Json . Serialization . Metadata ;
@@ -20,7 +21,8 @@ internal sealed class StreamableHttpHandler(
2021 StatefulSessionManager sessionManager ,
2122 IHostApplicationLifetime hostApplicationLifetime ,
2223 IServiceProvider applicationServices ,
23- ILoggerFactory loggerFactory )
24+ ILoggerFactory loggerFactory ,
25+ ISessionMigrationHandler ? sessionMigrationHandler = null )
2426{
2527 private const string McpSessionIdHeaderName = "Mcp-Session-Id" ;
2628 private const string McpProtocolVersionHeaderName = "MCP-Protocol-Version" ;
@@ -41,6 +43,11 @@ internal sealed class StreamableHttpHandler(
4143 private static readonly JsonTypeInfo < JsonRpcMessage > s_messageTypeInfo = GetRequiredJsonTypeInfo < JsonRpcMessage > ( ) ;
4244 private static readonly JsonTypeInfo < JsonRpcError > s_errorTypeInfo = GetRequiredJsonTypeInfo < JsonRpcError > ( ) ;
4345
46+ private static bool AllowNewSessionForNonInitializeRequests { get ; } =
47+ AppContext . TryGetSwitch ( "ModelContextProtocol.AspNetCore.AllowNewSessionForNonInitializeRequests" , out var enabled ) && enabled ;
48+
49+ private readonly ConcurrentDictionary < string , SemaphoreSlim > _migrationLocks = new ( StringComparer . Ordinal ) ;
50+
4451 public HttpServerTransportOptions HttpServerTransportOptions => httpServerTransportOptions . Value ;
4552
4653 public async Task HandlePostRequestAsync ( HttpContext context )
@@ -64,14 +71,6 @@ await WriteJsonRpcErrorAsync(context,
6471 return ;
6572 }
6673
67- var session = await GetOrCreateSessionAsync ( context ) ;
68- if ( session is null )
69- {
70- return ;
71- }
72-
73- await using var _ = await session . AcquireReferenceAsync ( context . RequestAborted ) ;
74-
7574 var message = await ReadJsonRpcMessageAsync ( context ) ;
7675 if ( message is null )
7776 {
@@ -81,6 +80,14 @@ await WriteJsonRpcErrorAsync(context,
8180 return ;
8281 }
8382
83+ var session = await GetOrCreateSessionAsync ( context , message ) ;
84+ if ( session is null )
85+ {
86+ return ;
87+ }
88+
89+ await using var _ = await session . AcquireReferenceAsync ( context . RequestAborted ) ;
90+
8491 InitializeSseResponse ( context ) ;
8592 var wroteResponse = await session . Transport . HandlePostRequestAsync ( message , context . Response . Body , context . RequestAborted ) ;
8693 if ( ! wroteResponse )
@@ -219,12 +226,18 @@ public async Task HandleDeleteRequestAsync(HttpContext context)
219226
220227 if ( ! sessionManager . TryGetValue ( sessionId , out var session ) )
221228 {
222- // -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does.
223- // One of the few other usages I found was from some Ethereum JSON-RPC documentation and this
224- // JSON-RPC library from Microsoft called StreamJsonRpc where it's called JsonRpcErrorCode.NoMarshaledObjectFound
225- // https://learn.microsoft.com/dotnet/api/streamjsonrpc.protocol.jsonrpcerrorcode?view=streamjsonrpc-2.9#fields
226- await WriteJsonRpcErrorAsync ( context , "Session not found" , StatusCodes . Status404NotFound , - 32001 ) ;
227- return null ;
229+ // Session not found locally. Attempt migration if a handler is registered.
230+ session = await TryMigrateSessionAsync ( context , sessionId ) ;
231+
232+ if ( session is null )
233+ {
234+ // -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does.
235+ // One of the few other usages I found was from some Ethereum JSON-RPC documentation and this
236+ // JSON-RPC library from Microsoft called StreamJsonRpc where it's called JsonRpcErrorCode.NoMarshaledObjectFound
237+ // https://learn.microsoft.com/dotnet/api/streamjsonrpc.protocol.jsonrpcerrorcode?view=streamjsonrpc-2.9#fields
238+ await WriteJsonRpcErrorAsync ( context , "Session not found" , StatusCodes . Status404NotFound , - 32001 ) ;
239+ return null ;
240+ }
228241 }
229242
230243 if ( ! session . HasSameUserId ( context . User ) )
@@ -240,12 +253,61 @@ await WriteJsonRpcErrorAsync(context,
240253 return session ;
241254 }
242255
243- private async ValueTask < StreamableHttpSession ? > GetOrCreateSessionAsync ( HttpContext context )
256+ private async ValueTask < StreamableHttpSession ? > TryMigrateSessionAsync ( HttpContext context , string sessionId )
257+ {
258+ if ( sessionMigrationHandler is not { } handler )
259+ {
260+ return null ;
261+ }
262+
263+ var migrationLock = _migrationLocks . GetOrAdd ( sessionId , static _ => new SemaphoreSlim ( 1 , 1 ) ) ;
264+ await migrationLock . WaitAsync ( context . RequestAborted ) ;
265+ try
266+ {
267+ // Re-check after acquiring the lock - another thread may have already completed migration.
268+ if ( sessionManager . TryGetValue ( sessionId , out var session ) )
269+ {
270+ return session ;
271+ }
272+
273+ var initParams = await handler . AllowSessionMigrationAsync ( context , sessionId , context . RequestAborted ) ;
274+ if ( initParams is null )
275+ {
276+ return null ;
277+ }
278+
279+ var migratedSession = await MigrateSessionAsync ( context , sessionId , initParams ) ;
280+
281+ // Register the session with the session manager while still holding the lock
282+ // so concurrent requests for the same session ID find it via sessionManager.TryGetValue.
283+ await migratedSession . EnsureStartedAsync ( context . RequestAborted ) ;
284+
285+ return migratedSession ;
286+ }
287+ finally
288+ {
289+ migrationLock . Release ( ) ;
290+ _migrationLocks . TryRemove ( sessionId , out _ ) ;
291+ }
292+ }
293+
294+ private async ValueTask < StreamableHttpSession ? > GetOrCreateSessionAsync ( HttpContext context , JsonRpcMessage message )
244295 {
245296 var sessionId = context . Request . Headers [ McpSessionIdHeaderName ] . ToString ( ) ;
246297
247298 if ( string . IsNullOrEmpty ( sessionId ) )
248299 {
300+ // In stateful mode, only allow creating new sessions for initialize requests.
301+ // In stateless mode, every request is independent, so we always create a new session.
302+ if ( ! HttpServerTransportOptions . Stateless && ! AllowNewSessionForNonInitializeRequests
303+ && message is not JsonRpcRequest { Method : RequestMethods . Initialize } )
304+ {
305+ await WriteJsonRpcErrorAsync ( context ,
306+ "Bad Request: A new session can only be created by an initialize request. Include a valid Mcp-Session-Id header for non-initialize requests." ,
307+ StatusCodes . Status400BadRequest ) ;
308+ return null ;
309+ }
310+
249311 return await StartNewSessionAsync ( context ) ;
250312 }
251313 else if ( HttpServerTransportOptions . Stateless )
@@ -274,7 +336,11 @@ private async ValueTask<StreamableHttpSession> StartNewSessionAsync(HttpContext
274336 SessionId = sessionId ,
275337 FlowExecutionContextFromRequests = ! HttpServerTransportOptions . PerSessionExecutionContext ,
276338 EventStreamStore = HttpServerTransportOptions . EventStreamStore ,
339+ OnSessionInitialized = sessionMigrationHandler is { } handler
340+ ? ( initParams , ct ) => handler . OnSessionInitializedAsync ( context , sessionId , initParams , ct )
341+ : null ,
277342 } ;
343+
278344 context . Response . Headers [ McpSessionIdHeaderName ] = sessionId ;
279345 }
280346 else
@@ -295,11 +361,12 @@ private async ValueTask<StreamableHttpSession> StartNewSessionAsync(HttpContext
295361 private async ValueTask < StreamableHttpSession > CreateSessionAsync (
296362 HttpContext context ,
297363 StreamableHttpServerTransport transport ,
298- string sessionId )
364+ string sessionId ,
365+ Action < McpServerOptions > ? configureOptions = null )
299366 {
300367 var mcpServerServices = applicationServices ;
301368 var mcpServerOptions = mcpServerOptionsSnapshot . Value ;
302- if ( HttpServerTransportOptions . Stateless || HttpServerTransportOptions . ConfigureSessionOptions is not null )
369+ if ( HttpServerTransportOptions . Stateless || HttpServerTransportOptions . ConfigureSessionOptions is not null || configureOptions is not null )
303370 {
304371 mcpServerOptions = mcpServerOptionsFactory . Create ( Options . DefaultName ) ;
305372
@@ -310,6 +377,8 @@ private async ValueTask<StreamableHttpSession> CreateSessionAsync(
310377 mcpServerOptions . ScopeRequests = false ;
311378 }
312379
380+ configureOptions ? . Invoke ( mcpServerOptions ) ;
381+
313382 if ( HttpServerTransportOptions . ConfigureSessionOptions is { } configureSessionOptions )
314383 {
315384 await configureSessionOptions ( context , mcpServerOptions , context . RequestAborted ) ;
@@ -328,6 +397,30 @@ private async ValueTask<StreamableHttpSession> CreateSessionAsync(
328397 return session ;
329398 }
330399
400+ private async ValueTask < StreamableHttpSession > MigrateSessionAsync (
401+ HttpContext context ,
402+ string sessionId ,
403+ InitializeRequestParams initializeParams )
404+ {
405+ var transport = new StreamableHttpServerTransport ( loggerFactory )
406+ {
407+ SessionId = sessionId ,
408+ FlowExecutionContextFromRequests = ! HttpServerTransportOptions . PerSessionExecutionContext ,
409+ EventStreamStore = HttpServerTransportOptions . EventStreamStore ,
410+ } ;
411+
412+ // Initialize the transport with the migrated session's init params.
413+ await transport . HandleInitializeRequestAsync ( initializeParams ) ;
414+
415+ context . Response . Headers [ McpSessionIdHeaderName ] = sessionId ;
416+
417+ return await CreateSessionAsync ( context , transport , sessionId , options =>
418+ {
419+ options . KnownClientInfo = initializeParams . ClientInfo ;
420+ options . KnownClientCapabilities = initializeParams . Capabilities ;
421+ } ) ;
422+ }
423+
331424 private async ValueTask < ISseEventStreamReader ? > GetEventStreamReaderAsync ( HttpContext context , string lastEventId )
332425 {
333426 if ( HttpServerTransportOptions . EventStreamStore is not { } eventStreamStore )
0 commit comments