Skip to content

Commit 4b7aabb

Browse files
committed
Refactor transports to enable graceful shutdown
- This also paves the way for better multi-session support - We should definitely rethink names for the transport API - For now, I kept the names similar as possible, so we can focus on the API shape
1 parent fbe1dbe commit 4b7aabb

40 files changed

+1728
-1532
lines changed

src/ModelContextProtocol/Client/McpClient.cs

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,37 @@
1-
using ModelContextProtocol.Configuration;
1+
using Microsoft.Extensions.Logging;
2+
using ModelContextProtocol.Configuration;
23
using ModelContextProtocol.Logging;
34
using ModelContextProtocol.Protocol.Messages;
45
using ModelContextProtocol.Protocol.Transport;
56
using ModelContextProtocol.Protocol.Types;
67
using ModelContextProtocol.Shared;
78
using ModelContextProtocol.Utils.Json;
8-
using Microsoft.Extensions.Logging;
99
using System.Text.Json;
1010

1111
namespace ModelContextProtocol.Client;
1212

1313
/// <inheritdoc/>
1414
internal sealed class McpClient : McpJsonRpcEndpoint, IMcpClient
1515
{
16-
private readonly McpClientOptions _options;
1716
private readonly IClientTransport _clientTransport;
17+
private readonly McpClientOptions _options;
1818

19-
private int _connecting;
19+
private ITransport? _sessionTransport;
20+
private CancellationTokenSource? _connectCts;
21+
private int _disposed;
2022

2123
/// <summary>
2224
/// Initializes a new instance of the <see cref="McpClient"/> class.
2325
/// </summary>
24-
/// <param name="transport">The transport to use for communication with the server.</param>
26+
/// <param name="clientTransport">The transport to use for communication with the server.</param>
2527
/// <param name="options">Options for the client, defining protocol version and capabilities.</param>
2628
/// <param name="serverConfig">The server configuration.</param>
2729
/// <param name="loggerFactory">The logger factory.</param>
28-
public McpClient(IClientTransport transport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
29-
: base(transport, loggerFactory)
30+
public McpClient(IClientTransport clientTransport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
31+
: base(loggerFactory)
3032
{
33+
_clientTransport = clientTransport;
3134
_options = options;
32-
_clientTransport = transport;
3335

3436
EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})";
3537

@@ -70,25 +72,19 @@ public McpClient(IClientTransport transport, McpClientOptions options, McpServer
7072
/// <inheritdoc/>
7173
public override string EndpointName { get; }
7274

73-
/// <inheritdoc/>
7475
public async Task ConnectAsync(CancellationToken cancellationToken = default)
7576
{
76-
if (Interlocked.Exchange(ref _connecting, 1) != 0)
77-
{
78-
_logger.ClientAlreadyInitializing(EndpointName);
79-
throw new InvalidOperationException("Client is already in use.");
80-
}
81-
82-
CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
83-
cancellationToken = CancellationTokenSource.Token;
77+
_connectCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
78+
cancellationToken = _connectCts.Token;
8479

8580
try
8681
{
8782
// Connect transport
88-
await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
89-
90-
// Start processing messages
91-
MessageProcessingTask = ProcessMessagesAsync(cancellationToken);
83+
_sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
84+
InitializeSession(_sessionTransport);
85+
// We don't want the ConnectAsync token to cancel the session after we've successfully connected.
86+
// The base class handles cleaning up the session in DisposeAsync without our help.
87+
StartSession(fullSessionCancellationToken: CancellationToken.None);
9288

9389
// Perform initialization sequence
9490
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
@@ -140,8 +136,27 @@ await SendMessageAsync(
140136
catch (Exception e)
141137
{
142138
_logger.ClientInitializationError(EndpointName, e);
143-
await CleanupAsync().ConfigureAwait(false);
139+
await DisposeAsync().ConfigureAwait(false);
144140
throw;
145141
}
146142
}
143+
144+
/// <inheritdoc/>
145+
public override async ValueTask DisposeAsync()
146+
{
147+
if (Interlocked.Exchange(ref _disposed, 1) != 0)
148+
{
149+
// TODO: It's more correct to await the last DisposeAsync before returning if it's still ongoing.
150+
return;
151+
}
152+
153+
if (_connectCts is not null)
154+
{
155+
await _connectCts.CancelAsync().ConfigureAwait(false);
156+
}
157+
158+
await base.DisposeAsync().ConfigureAwait(false);
159+
160+
_connectCts?.Dispose();
161+
}
147162
}

src/ModelContextProtocol/Client/McpClientFactory.cs

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,24 +65,16 @@ public static async Task<IMcpClient> CreateAsync(
6565
createTransportFunc(serverConfig, loggerFactory) ??
6666
throw new InvalidOperationException($"{nameof(createTransportFunc)} returned a null transport.");
6767

68+
McpClient client = new(transport, clientOptions, serverConfig, loggerFactory);
6869
try
6970
{
70-
McpClient client = new(transport, clientOptions, serverConfig, loggerFactory);
71-
try
72-
{
73-
await client.ConnectAsync(cancellationToken).ConfigureAwait(false);
74-
logger.ClientCreated(endpointName);
75-
return client;
76-
}
77-
catch
78-
{
79-
await client.DisposeAsync().ConfigureAwait(false);
80-
throw;
81-
}
71+
await client.ConnectAsync(cancellationToken).ConfigureAwait(false);
72+
logger.ClientCreated(endpointName);
73+
return client;
8274
}
8375
catch
8476
{
85-
await transport.DisposeAsync().ConfigureAwait(false);
77+
await client.DisposeAsync().ConfigureAwait(false);
8678
throw;
8779
}
8880
}

src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
using ModelContextProtocol.Protocol.Transport;
44
using ModelContextProtocol.Utils;
55
using Microsoft.Extensions.DependencyInjection;
6+
using Microsoft.Extensions.Logging;
7+
using Microsoft.Extensions.Options;
8+
using ModelContextProtocol.Server;
69

710
namespace ModelContextProtocol;
811

@@ -19,8 +22,18 @@ public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder
1922
{
2023
Throw.IfNull(builder);
2124

22-
builder.Services.AddSingleton<IServerTransport, StdioServerTransport>();
23-
builder.Services.AddHostedService<McpServerHostedService>();
25+
builder.Services.AddSingleton<ITransport, StdioServerTransport>();
26+
builder.Services.AddHostedService<McpServerSingleSessionHostedService>();
27+
28+
builder.Services.AddSingleton(services =>
29+
{
30+
ITransport serverTransport = services.GetRequiredService<ITransport>();
31+
IOptions<McpServerOptions> options = services.GetRequiredService<IOptions<McpServerOptions>>();
32+
ILoggerFactory? loggerFactory = services.GetService<ILoggerFactory>();
33+
34+
return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services);
35+
});
36+
2437
return builder;
2538
}
2639

@@ -33,7 +46,7 @@ public static IMcpServerBuilder WithHttpListenerSseServerTransport(this IMcpServ
3346
Throw.IfNull(builder);
3447

3548
builder.Services.AddSingleton<IServerTransport, HttpListenerSseServerTransport>();
36-
builder.Services.AddHostedService<McpServerHostedService>();
49+
builder.Services.AddHostedService<McpServerMultiSessionHostedService>();
3750
return builder;
3851
}
3952
}

src/ModelContextProtocol/Configuration/McpServerServiceCollectionExtension.cs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
namespace ModelContextProtocol;
99

1010
/// <summary>
11-
/// Extension to host the MCP server
11+
/// Extension to host an MCP server
1212
/// </summary>
1313
public static class McpServerServiceCollectionExtension
1414
{
@@ -20,15 +20,6 @@ public static class McpServerServiceCollectionExtension
2020
/// <returns></returns>
2121
public static IMcpServerBuilder AddMcpServer(this IServiceCollection services, Action<McpServerOptions>? configureOptions = null)
2222
{
23-
services.AddSingleton(services =>
24-
{
25-
IServerTransport serverTransport = services.GetRequiredService<IServerTransport>();
26-
IOptions<McpServerOptions> options = services.GetRequiredService<IOptions<McpServerOptions>>();
27-
ILoggerFactory? loggerFactory = services.GetService<ILoggerFactory>();
28-
29-
return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services);
30-
});
31-
3223
services.AddOptions();
3324
services.AddTransient<IConfigureOptions<McpServerOptions>, McpServerOptionsSetup>();
3425
if (configureOptions is not null)

src/ModelContextProtocol/Hosting/McpServerHostedService.cs

Lines changed: 0 additions & 31 deletions
This file was deleted.
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
using Microsoft.Extensions.Hosting;
2+
using Microsoft.Extensions.Logging;
3+
using Microsoft.Extensions.Options;
4+
using ModelContextProtocol.Protocol.Transport;
5+
using ModelContextProtocol.Server;
6+
7+
namespace ModelContextProtocol.Hosting;
8+
9+
/// <summary>
10+
/// Hosted service for a multi-session (i.e. HTTP) MCP server.
11+
/// </summary>
12+
internal class McpServerMultiSessionHostedService : BackgroundService
13+
{
14+
private readonly IServerTransport _serverTransport;
15+
private readonly McpServerOptions _serverOptions;
16+
private readonly ILoggerFactory _loggerFactory;
17+
private readonly IServiceProvider _serviceProvider;
18+
19+
public McpServerMultiSessionHostedService(
20+
IServerTransport serverTransport,
21+
IOptions<McpServerOptions> serverOptions,
22+
ILoggerFactory loggerFactory,
23+
IServiceProvider serviceProvider)
24+
{
25+
_serverTransport = serverTransport;
26+
_serverOptions = serverOptions.Value;
27+
_loggerFactory = loggerFactory;
28+
_serviceProvider = serviceProvider;
29+
}
30+
31+
/// <inheritdoc />
32+
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
33+
{
34+
while (await AcceptSessionAsync(stoppingToken).ConfigureAwait(false) is { } server)
35+
{
36+
// TODO: Track all running sessions and wait for all sessions to complete for graceful shutdown.
37+
_ = server.RunAsync(stoppingToken);
38+
}
39+
}
40+
41+
private Task<IMcpServer> AcceptSessionAsync(CancellationToken cancellationToken)
42+
=> McpServerFactory.AcceptAsync(_serverTransport, _serverOptions, _loggerFactory, _serviceProvider, cancellationToken);
43+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using Microsoft.Extensions.Hosting;
2+
using ModelContextProtocol.Server;
3+
4+
namespace ModelContextProtocol.Hosting;
5+
6+
/// <summary>
7+
/// Hosted service for a single-session (i.e stdio) MCP server.
8+
/// </summary>
9+
internal class McpServerSingleSessionHostedService(IMcpServer session) : BackgroundService
10+
{
11+
/// <inheritdoc />
12+
protected override Task ExecuteAsync(CancellationToken stoppingToken) => session.RunAsync(stoppingToken);
13+
}

src/ModelContextProtocol/ModelContextProtocol.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
</PropertyGroup>
1515

1616
<ItemGroup>
17-
<PackageReference Include="Microsoft.Extensions.AI.Abstractions"/>
17+
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" />
1818
<PackageReference Include="Microsoft.Extensions.AI" />
1919
<PackageReference Include="Microsoft.Extensions.Hosting.Abstractions" />
2020
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />

src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ public HttpListenerServerProvider(int port)
4444
public required Func<Stream, CancellationToken, Task> OnSseConnectionAsync { get; set; }
4545
public required Func<Stream, CancellationToken, Task<bool>> OnMessageAsync { get; set; }
4646

47-
/// <inheritdoc/>
48-
public async Task StartAsync(CancellationToken cancellationToken = default)
47+
public void Start()
4948
{
5049
if (Interlocked.CompareExchange(ref _state, StateRunning, StateNotStarted) != StateNotStarted)
5150
{
@@ -60,7 +59,7 @@ public async Task StartAsync(CancellationToken cancellationToken = default)
6059
{
6160
try
6261
{
63-
using var cts = CancellationTokenSource.CreateLinkedTokenSource(_shutdownTokenSource.Token, cancellationToken);
62+
using var cts = CancellationTokenSource.CreateLinkedTokenSource(_shutdownTokenSource.Token);
6463
cts.Token.Register(_listener.Stop);
6564
while (!cts.IsCancellationRequested)
6665
{
@@ -100,6 +99,7 @@ public async ValueTask DisposeAsync()
10099
await _shutdownTokenSource.CancelAsync().ConfigureAwait(false);
101100
_listener.Stop();
102101
await _listeningTask.ConfigureAwait(false);
102+
await _completed.Task.ConfigureAwait(false);
103103
}
104104

105105
/// <summary>Gets a <see cref="Task"/> that completes when the server has finished its work.</summary>

0 commit comments

Comments
 (0)