forked from modelcontextprotocol/csharp-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMcpEndpointRouteBuilderExtensions.cs
More file actions
159 lines (139 loc) · 7.36 KB
/
McpEndpointRouteBuilderExtensions.cs
File metadata and controls
159 lines (139 loc) · 7.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Routing;
using Microsoft.AspNetCore.Routing.Patterns;
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils.Json;
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Security.Cryptography;
namespace Microsoft.AspNetCore.Builder;
/// <summary>
/// Extension methods for <see cref="IEndpointRouteBuilder"/> to add MCP endpoints.
/// </summary>
public static class McpEndpointRouteBuilderExtensions
{
    /// <summary>
    /// Maps the MCP Server-Sent Events (SSE) endpoints under the given route prefix:
    /// <c>GET {pattern}/sse</c> opens a session and streams the response as <c>text/event-stream</c>, and
    /// <c>POST {pattern}/message?sessionId=...</c> delivers a JSON-RPC message to an open session.
    /// </summary>
    /// <param name="endpoints">The web application to attach MCP HTTP endpoints.</param>
    /// <param name="pattern">The route pattern prefix to map to.</param>
    /// <param name="configureOptionsAsync">
    /// Configure per-session options. When supplied, a fresh <see cref="McpServerOptions"/> instance is created
    /// for every SSE connection and passed to this callback; otherwise the shared default options are used.
    /// </param>
    /// <param name="runSessionAsync">
    /// Provides an optional asynchronous callback for handling new MCP sessions.
    /// When omitted, the session simply runs the server via <c>IMcpServer.RunAsync</c>.
    /// </param>
    /// <returns>Returns a builder for configuring additional endpoint conventions like authorization policies.</returns>
    public static IEndpointConventionBuilder MapMcp(
        this IEndpointRouteBuilder endpoints,
        [StringSyntax("Route")] string pattern = "",
        Func<HttpContext, McpServerOptions, CancellationToken, Task>? configureOptionsAsync = null,
        Func<HttpContext, IMcpServer, CancellationToken, Task>? runSessionAsync = null)
        => endpoints.MapMcp(RoutePatternFactory.Parse(pattern), configureOptionsAsync, runSessionAsync);

    /// <summary>
    /// Maps the MCP Server-Sent Events (SSE) endpoints under the given route prefix:
    /// <c>GET {pattern}/sse</c> opens a session and streams the response as <c>text/event-stream</c>, and
    /// <c>POST {pattern}/message?sessionId=...</c> delivers a JSON-RPC message to an open session.
    /// </summary>
    /// <param name="endpoints">The web application to attach MCP HTTP endpoints.</param>
    /// <param name="pattern">The route pattern prefix to map to.</param>
    /// <param name="configureOptionsAsync">
    /// Configure per-session options. When supplied, a fresh <see cref="McpServerOptions"/> instance is created
    /// for every SSE connection and passed to this callback; otherwise the shared default options are used.
    /// </param>
    /// <param name="runSessionAsync">
    /// Provides an optional asynchronous callback for handling new MCP sessions.
    /// When omitted, the session simply runs the server via <c>IMcpServer.RunAsync</c>.
    /// </param>
    /// <returns>Returns a builder for configuring additional endpoint conventions like authorization policies.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="endpoints"/> or <paramref name="pattern"/> is <see langword="null"/>.</exception>
    public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints,
        RoutePattern pattern,
        Func<HttpContext, McpServerOptions, CancellationToken, Task>? configureOptionsAsync = null,
        Func<HttpContext, IMcpServer, CancellationToken, Task>? runSessionAsync = null)
    {
        ArgumentNullException.ThrowIfNull(endpoints);
        ArgumentNullException.ThrowIfNull(pattern);

        // Open SSE sessions keyed by session ID; shared between the GET /sse and POST /message handlers below.
        ConcurrentDictionary<string, SseResponseStreamTransport> sessions = new(StringComparer.Ordinal);

        var loggerFactory = endpoints.ServiceProvider.GetRequiredService<ILoggerFactory>();
        var defaultOptions = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();
        var optionsFactory = endpoints.ServiceProvider.GetRequiredService<IOptionsFactory<McpServerOptions>>();
        var hostApplicationLifetime = endpoints.ServiceProvider.GetRequiredService<IHostApplicationLifetime>();

        // Resolve the session callback once at map time instead of reassigning the captured parameter
        // from inside every request.
        Func<HttpContext, IMcpServer, CancellationToken, Task> runSession = runSessionAsync ?? RunSession;

        var routeGroup = endpoints.MapGroup(pattern);

        routeGroup.MapGet("/sse", async context =>
        {
            // If the server is shutting down, we need to cancel all SSE connections immediately without waiting for HostOptions.ShutdownTimeout
            // which defaults to 30 seconds.
            using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping);
            var cancellationToken = sseCts.Token;

            var response = context.Response;
            response.Headers.ContentType = "text/event-stream";
            response.Headers.CacheControl = "no-cache,no-store";

            // Make sure we disable all response buffering for SSE
            response.Headers.ContentEncoding = "identity";
            context.Features.GetRequiredFeature<IHttpResponseBodyFeature>().DisableBuffering();

            // Resolve the options BEFORE the session is registered. If the user callback throws,
            // nothing has been added to the session table yet, so no stale entry is left behind.
            var options = defaultOptions.Value;
            if (configureOptionsAsync is not null)
            {
                // Per-session configuration must not mutate the shared instance, so start from a fresh one.
                options = optionsFactory.Create(Options.DefaultName);
                await configureOptionsAsync.Invoke(context, options, cancellationToken);
            }

            var sessionId = MakeNewSessionId();

            // Build the message endpoint relative to where this request was actually routed, so it stays
            // correct when the group is mapped under a non-empty pattern or the app runs behind a PathBase.
            // For the default empty pattern this yields "/message?sessionId=...".
            // NOTE(review): the transport presumably advertises this URI to the client as the place to
            // POST messages - confirm against SseResponseStreamTransport.
            var requestPath = (context.Request.PathBase + context.Request.Path).ToString().TrimEnd('/');
            var endpointPrefix = requestPath[..(requestPath.LastIndexOf('/') + 1)];

            await using var transport = new SseResponseStreamTransport(response.Body, $"{endpointPrefix}message?sessionId={sessionId}");
            if (!sessions.TryAdd(sessionId, transport))
            {
                throw new InvalidOperationException($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created.");
            }

            try
            {
                var transportTask = transport.RunAsync(cancellationToken);

                try
                {
                    await using var mcpServer = McpServerFactory.Create(transport, options, loggerFactory, endpoints.ServiceProvider);
                    // Expose the server through the feature collection for the remainder of this request.
                    context.Features.Set(mcpServer);

                    await runSession(context, mcpServer, cancellationToken);
                }
                finally
                {
                    // Dispose the transport before awaiting its run task, then observe the task so any
                    // failure from it is not lost.
                    // NOTE(review): presumably disposal is what lets RunAsync complete - confirm.
                    await transport.DisposeAsync();
                    await transportTask;
                }
            }
            catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
            {
                // RequestAborted always triggers when the client disconnects before a complete response body is written,
                // but this is how SSE connections are typically closed.
            }
            finally
            {
                sessions.TryRemove(sessionId, out _);
            }
        });

        routeGroup.MapPost("/message", async context =>
        {
            if (!context.Request.Query.TryGetValue("sessionId", out var sessionId))
            {
                await Results.BadRequest("Missing sessionId query parameter.").ExecuteAsync(context);
                return;
            }

            if (!sessions.TryGetValue(sessionId.ToString(), out var transport))
            {
                await Results.BadRequest("Session ID not found.").ExecuteAsync(context);
                return;
            }

            IJsonRpcMessage? message;
            try
            {
                message = (IJsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), context.RequestAborted);
            }
            catch (System.Text.Json.JsonException)
            {
                // A malformed body is a client error; report it as 400 like the other bad-input cases
                // rather than letting it surface as an unhandled exception.
                await Results.BadRequest("Request body is not a valid JSON-RPC message.").ExecuteAsync(context);
                return;
            }

            if (message is null)
            {
                await Results.BadRequest("No message in request body.").ExecuteAsync(context);
                return;
            }

            await transport.OnMessageReceivedAsync(message, context.RequestAborted);
            context.Response.StatusCode = StatusCodes.Status202Accepted;
            await context.Response.WriteAsync("Accepted");
        });

        return routeGroup;
    }

    /// <summary>
    /// Default session callback used when the caller supplies none: runs the server with the given token.
    /// </summary>
    private static Task RunSession(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted)
        => session.RunAsync(requestAborted);

    /// <summary>
    /// Creates a new session ID from 16 cryptographically random bytes, encoded as a base64url string.
    /// </summary>
    private static string MakeNewSessionId()
    {
        // 128 bits
        Span<byte> buffer = stackalloc byte[16];
        RandomNumberGenerator.Fill(buffer);
        return WebEncoders.Base64UrlEncode(buffer);
    }
}