Skip to content

Commit 4ae26c0

Browse files
committed
Enable graceful shutdown of servers
1 parent 259f11a commit 4ae26c0

25 files changed

Lines changed: 391 additions & 493 deletions

README.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,7 @@ McpServerOptions options = new()
198198
};
199199

200200
await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("MyServer"), options);
201-
202-
await server.StartAsync();
203-
204-
// Run until process is stopped by the client (parent process)
205-
await Task.Delay(Timeout.Infinite);
201+
await server.RunAsync();
206202
```
207203

208204
## Acknowledgements

samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ public static class McpEndpointRouteBuilderExtensions
1010
{
1111
public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder endpoints)
1212
{
13-
IMcpServer? server = null;
1413
SseResponseStreamTransport? transport = null;
1514
var loggerFactory = endpoints.ServiceProvider.GetRequiredService<ILoggerFactory>();
1615
var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();
@@ -19,17 +18,15 @@ public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder en
1918

2019
routeGroup.MapGet("/sse", async (HttpResponse response, CancellationToken requestAborted) =>
2120
{
22-
await using var localTransport = transport = new SseResponseStreamTransport(response.Body);
23-
await using var localServer = server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);
24-
25-
await localServer.StartAsync(requestAborted);
26-
2721
response.Headers.ContentType = "text/event-stream";
2822
response.Headers.CacheControl = "no-cache";
2923

24+
await using var localTransport = transport = new SseResponseStreamTransport(response.Body);
25+
await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);
26+
3027
try
3128
{
32-
await transport.RunAsync(requestAborted);
29+
await transport.RunAsync(cancellationToken: requestAborted);
3330
}
3431
catch (OperationCanceledException) when (requestAborted.IsCancellationRequested)
3532
{

src/ModelContextProtocol/Client/McpClient.cs

Lines changed: 52 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
using ModelContextProtocol.Shared;
77
using ModelContextProtocol.Utils.Json;
88
using Microsoft.Extensions.Logging;
9-
using Microsoft.Extensions.Logging.Abstractions;
109
using System.Text.Json;
1110

1211
namespace ModelContextProtocol.Client;
@@ -17,7 +16,7 @@ internal sealed class McpClient : McpJsonRpcEndpoint, IMcpClient
1716
private readonly McpClientOptions _options;
1817
private readonly IClientTransport _clientTransport;
1918

20-
private volatile bool _isInitializing;
19+
private int _connecting;
2120

2221
/// <summary>
2322
/// Initializes a new instance of the <see cref="McpClient"/> class.
@@ -74,92 +73,75 @@ public McpClient(IClientTransport transport, McpClientOptions options, McpServer
7473
/// <inheritdoc/>
7574
public async Task ConnectAsync(CancellationToken cancellationToken = default)
7675
{
77-
if (IsInitialized)
78-
{
79-
_logger.ClientAlreadyInitialized(EndpointName);
80-
return;
81-
}
82-
83-
if (_isInitializing)
76+
if (Interlocked.Exchange(ref _connecting, 1) != 0)
8477
{
8578
_logger.ClientAlreadyInitializing(EndpointName);
86-
throw new InvalidOperationException("Client is already initializing");
79+
throw new InvalidOperationException("Client is already in use.");
8780
}
8881

89-
_isInitializing = true;
82+
CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
83+
cancellationToken = CancellationTokenSource.Token;
84+
9085
try
9186
{
92-
CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
93-
9487
// Connect transport
95-
await _clientTransport.ConnectAsync(CancellationTokenSource.Token).ConfigureAwait(false);
88+
await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
9689

9790
// Start processing messages
98-
MessageProcessingTask = ProcessMessagesAsync(CancellationTokenSource.Token);
91+
MessageProcessingTask = ProcessMessagesAsync(cancellationToken);
9992

10093
// Perform initialization sequence
101-
await InitializeAsync(CancellationTokenSource.Token).ConfigureAwait(false);
94+
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
95+
initializationCts.CancelAfter(_options.InitializationTimeout);
10296

103-
IsInitialized = true;
97+
try
98+
{
99+
// Send initialize request
100+
var initializeResponse = await SendRequestAsync<InitializeResult>(
101+
new JsonRpcRequest
102+
{
103+
Method = "initialize",
104+
Params = new
105+
{
106+
protocolVersion = _options.ProtocolVersion,
107+
capabilities = _options.Capabilities ?? new ClientCapabilities(),
108+
clientInfo = _options.ClientInfo
109+
}
110+
},
111+
initializationCts.Token).ConfigureAwait(false);
112+
113+
// Store server information
114+
_logger.ServerCapabilitiesReceived(EndpointName,
115+
capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities),
116+
serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation));
117+
118+
ServerCapabilities = initializeResponse.Capabilities;
119+
ServerInfo = initializeResponse.ServerInfo;
120+
ServerInstructions = initializeResponse.Instructions;
121+
122+
// Validate protocol version
123+
if (initializeResponse.ProtocolVersion != _options.ProtocolVersion)
124+
{
125+
_logger.ServerProtocolVersionMismatch(EndpointName, _options.ProtocolVersion, initializeResponse.ProtocolVersion);
126+
throw new McpClientException($"Server protocol version mismatch. Expected {_options.ProtocolVersion}, got {initializeResponse.ProtocolVersion}");
127+
}
128+
129+
// Send initialized notification
130+
await SendMessageAsync(
131+
new JsonRpcNotification { Method = "notifications/initialized" },
132+
initializationCts.Token).ConfigureAwait(false);
133+
}
134+
catch (OperationCanceledException) when (initializationCts.IsCancellationRequested)
135+
{
136+
_logger.ClientInitializationTimeout(EndpointName);
137+
throw new McpClientException("Initialization timed out");
138+
}
104139
}
105140
catch (Exception e)
106141
{
107142
_logger.ClientInitializationError(EndpointName, e);
108143
await CleanupAsync().ConfigureAwait(false);
109144
throw;
110145
}
111-
finally
112-
{
113-
_isInitializing = false;
114-
}
115-
}
116-
117-
private async Task InitializeAsync(CancellationToken cancellationToken)
118-
{
119-
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
120-
initializationCts.CancelAfter(_options.InitializationTimeout);
121-
122-
try
123-
{
124-
// Send initialize request
125-
var initializeResponse = await SendRequestAsync<InitializeResult>(
126-
new JsonRpcRequest
127-
{
128-
Method = "initialize",
129-
Params = new
130-
{
131-
protocolVersion = _options.ProtocolVersion,
132-
capabilities = _options.Capabilities ?? new ClientCapabilities(),
133-
clientInfo = _options.ClientInfo
134-
}
135-
},
136-
initializationCts.Token).ConfigureAwait(false);
137-
138-
// Store server information
139-
_logger.ServerCapabilitiesReceived(EndpointName,
140-
capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities),
141-
serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation));
142-
143-
ServerCapabilities = initializeResponse.Capabilities;
144-
ServerInfo = initializeResponse.ServerInfo;
145-
ServerInstructions = initializeResponse.Instructions;
146-
147-
// Validate protocol version
148-
if (initializeResponse.ProtocolVersion != _options.ProtocolVersion)
149-
{
150-
_logger.ServerProtocolVersionMismatch(EndpointName, _options.ProtocolVersion, initializeResponse.ProtocolVersion);
151-
throw new McpClientException($"Server protocol version mismatch. Expected {_options.ProtocolVersion}, got {initializeResponse.ProtocolVersion}");
152-
}
153-
154-
// Send initialized notification
155-
await SendMessageAsync(
156-
new JsonRpcNotification { Method = "notifications/initialized" },
157-
initializationCts.Token).ConfigureAwait(false);
158-
}
159-
catch (OperationCanceledException) when (initializationCts.IsCancellationRequested)
160-
{
161-
_logger.ClientInitializationTimeout(EndpointName);
162-
throw new McpClientException("Initialization timed out");
163-
}
164146
}
165147
}

src/ModelContextProtocol/Hosting/McpServerHostedService.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,6 @@ public McpServerHostedService(IMcpServer server)
2626
/// <inheritdoc />
2727
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
2828
{
29-
await _server.StartAsync(stoppingToken).ConfigureAwait(false);
29+
await _server.RunAsync(cancellationToken: stoppingToken).ConfigureAwait(false);
3030
}
3131
}

0 commit comments

Comments
 (0)