Skip to content

Commit d9e737c

Browse files
committed
Remove MaxReconnectAttempts and ReconnectDelay from SseClientTransportOptions
- Add proper AdditionalHeaders support
1 parent 3167bdc commit d9e737c

4 files changed

Lines changed: 142 additions & 112 deletions

File tree

src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs

Lines changed: 56 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,12 @@ public override async Task SendMessageAsync(
101101
messageId = messageWithId.Id.ToString();
102102
}
103103

104-
var response = await _httpClient.PostAsync(
105-
_messageEndpoint,
106-
content,
107-
cancellationToken
108-
).ConfigureAwait(false);
104+
var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint)
105+
{
106+
Content = content,
107+
};
108+
CopyAdditionalHeaders(httpRequestMessage.Headers);
109+
var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);
109110

110111
response.EnsureSuccessStatusCode();
111112

@@ -182,72 +183,52 @@ public override async ValueTask DisposeAsync()
182183

183184
private async Task ReceiveMessagesAsync(CancellationToken cancellationToken)
184185
{
185-
int reconnectAttempts = 0;
186-
187-
while (!cancellationToken.IsCancellationRequested && !IsConnected)
186+
try
188187
{
189-
try
190-
{
191-
using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint);
192-
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
188+
using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint);
189+
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
190+
CopyAdditionalHeaders(request.Headers);
193191

194-
if (_options.AdditionalHeaders != null)
195-
{
196-
foreach (var header in _options.AdditionalHeaders)
197-
{
198-
request.Headers.Add(header.Key, header.Value);
199-
}
200-
}
201-
202-
using var response = await _httpClient.SendAsync(
203-
request,
204-
HttpCompletionOption.ResponseHeadersRead,
205-
cancellationToken
206-
).ConfigureAwait(false);
192+
using var response = await _httpClient.SendAsync(
193+
request,
194+
HttpCompletionOption.ResponseHeadersRead,
195+
cancellationToken
196+
).ConfigureAwait(false);
207197

208-
response.EnsureSuccessStatusCode();
198+
response.EnsureSuccessStatusCode();
209199

210-
using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
200+
using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
211201

212-
await foreach (SseItem<string> sseEvent in SseParser.Create(stream).EnumerateAsync(cancellationToken).ConfigureAwait(false))
213-
{
214-
switch (sseEvent.EventType)
215-
{
216-
case "endpoint":
217-
HandleEndpointEvent(sseEvent.Data);
218-
break;
219-
220-
case "message":
221-
await ProcessSseMessage(sseEvent.Data, cancellationToken).ConfigureAwait(false);
222-
break;
223-
}
224-
}
225-
}
226-
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
227-
{
228-
_logger.TransportReadMessagesCancelled(_endpointName);
229-
// Normal shutdown
230-
}
231-
catch (IOException) when (cancellationToken.IsCancellationRequested)
232-
{
233-
_logger.TransportReadMessagesCancelled(_endpointName);
234-
// Normal shutdown
235-
}
236-
catch (Exception ex) when (!cancellationToken.IsCancellationRequested)
202+
await foreach (SseItem<string> sseEvent in SseParser.Create(stream).EnumerateAsync(cancellationToken).ConfigureAwait(false))
237203
{
238-
_logger.TransportConnectionError(_endpointName, ex);
239-
240-
reconnectAttempts++;
241-
if (reconnectAttempts >= _options.MaxReconnectAttempts)
204+
switch (sseEvent.EventType)
242205
{
243-
throw new McpTransportException("Exceeded reconnect limit", ex);
244-
}
206+
case "endpoint":
207+
HandleEndpointEvent(sseEvent.Data);
208+
break;
245209

246-
await Task.Delay(_options.ReconnectDelay, cancellationToken).ConfigureAwait(false);
210+
case "message":
211+
await ProcessSseMessage(sseEvent.Data, cancellationToken).ConfigureAwait(false);
212+
break;
213+
}
247214
}
248215
}
249-
250-
SetConnected(false);
216+
catch when (cancellationToken.IsCancellationRequested)
217+
{
218+
// Normal shutdown
219+
_connectionEstablished.TrySetCanceled(cancellationToken);
220+
_logger.TransportReadMessagesCancelled(_endpointName);
221+
}
222+
catch (Exception ex) when (!cancellationToken.IsCancellationRequested)
223+
{
224+
_connectionEstablished.TrySetException(ex);
225+
_logger.TransportConnectionError(_endpointName, ex);
226+
throw;
227+
}
228+
finally
229+
{
230+
SetConnected(false);
231+
}
251232
}
252233

253234
private async Task ProcessSseMessage(string data, CancellationToken cancellationToken)
@@ -306,4 +287,18 @@ private void HandleEndpointEvent(string data)
306287
throw new McpTransportException("Failed to parse endpoint event", ex);
307288
}
308289
}
290+
291+
private void CopyAdditionalHeaders(HttpRequestHeaders headers)
292+
{
293+
if (_options.AdditionalHeaders is not null)
294+
{
295+
foreach (var header in _options.AdditionalHeaders)
296+
{
297+
if (!headers.TryAddWithoutValidation(header.Key, header.Value))
298+
{
299+
throw new InvalidOperationException($"Failed to add header '{header.Key}' with value '{header.Value}' from {nameof(SseClientTransportOptions.AdditionalHeaders)}.");
300+
}
301+
}
302+
}
303+
}
309304
}

src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -48,37 +48,6 @@ public required Uri Endpoint
4848
/// </remarks>
4949
public TimeSpan ConnectionTimeout { get; init; } = TimeSpan.FromSeconds(30);
5050

51-
/// <summary>
52-
/// Gets or sets the maximum number of reconnection attempts for the SSE connection before giving up.
53-
/// </summary>
54-
/// <remarks>
55-
/// <para>
56-
/// This property controls how many times the client will attempt to reconnect to the SSE server
57-
/// after a connection failure occurs. If all reconnection attempts fail, a
58-
/// <see cref="McpTransportException"/> with the message "Exceeded reconnect limit" will be thrown.
59-
/// </para>
60-
/// <para>
61-
/// Between each reconnection attempt, the client will wait for the duration specified by <see cref="ReconnectDelay"/>.
62-
/// </para>
63-
/// </remarks>
64-
public int MaxReconnectAttempts { get; init; } = 3;
65-
66-
/// <summary>
67-
/// Gets or sets the delay to employ between reconnection attempts when the SSE connection fails.
68-
/// </summary>
69-
/// <remarks>
70-
/// <para>
71-
/// When a connection to the SSE server is lost or fails, the client will wait for this duration
72-
/// before attempting to reconnect. This helps prevent excessive reconnection attempts in quick succession
73-
/// which could overload the server or network.
74-
/// </para>
75-
/// <para>
76-
/// The reconnection process continues until either a successful connection is established or
77-
/// the maximum number of reconnection attempts (<see cref="MaxReconnectAttempts"/>) is reached.
78-
/// </para>
79-
/// </remarks>
80-
public TimeSpan ReconnectDelay { get; init; } = TimeSpan.FromSeconds(5);
81-
8251
/// <summary>
8352
/// Gets custom HTTP headers to include in requests to the SSE server.
8453
/// </summary>

tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,11 @@ public partial class SseIntegrationTests(ITestOutputHelper outputHelper) : Kestr
2323
Name = "In-memory Test Server",
2424
};
2525

26-
private Task<IMcpClient> ConnectMcpClient(HttpClient httpClient, McpClientOptions? clientOptions = null)
26+
private Task<IMcpClient> ConnectMcpClient(HttpClient? httpClient = null, SseClientTransportOptions? transportOptions = null)
2727
=> McpClientFactory.CreateAsync(
28-
new SseClientTransport(DefaultTransportOptions, httpClient, LoggerFactory),
29-
clientOptions,
30-
LoggerFactory,
31-
TestContext.Current.CancellationToken);
28+
new SseClientTransport(transportOptions ?? DefaultTransportOptions, httpClient ?? HttpClient, LoggerFactory),
29+
loggerFactory: LoggerFactory,
30+
cancellationToken: TestContext.Current.CancellationToken);
3231

3332
[Fact]
3433
public async Task ConnectAndReceiveMessage_InMemoryServer()
@@ -38,7 +37,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer()
3837
app.MapMcp();
3938
await app.StartAsync(TestContext.Current.CancellationToken);
4039

41-
await using var mcpClient = await ConnectMcpClient(HttpClient);
40+
await using var mcpClient = await ConnectMcpClient();
4241

4342
// Send a test message through POST endpoint
4443
await mcpClient.SendNotificationAsync("test/message", new Envelope { Message = "Hello, SSE!" }, serializerOptions: JsonContext.Default.Options, cancellationToken: TestContext.Current.CancellationToken);
@@ -53,7 +52,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU
5352
MapAbsoluteEndpointUriMcp(app);
5453
await app.StartAsync(TestContext.Current.CancellationToken);
5554

56-
await using var mcpClient = await ConnectMcpClient(HttpClient);
55+
await using var mcpClient = await ConnectMcpClient();
5756

5857
// Send a test message through POST endpoint
5958
await mcpClient.SendNotificationAsync("test/message", new Envelope { Message = "Hello, SSE!" }, serializerOptions: JsonContext.Default.Options, cancellationToken: TestContext.Current.CancellationToken);
@@ -85,7 +84,7 @@ public async Task ConnectAndReceiveNotification_InMemoryServer()
8584
app.MapMcp();
8685
await app.StartAsync(TestContext.Current.CancellationToken);
8786

88-
await using var mcpClient = await ConnectMcpClient(HttpClient);
87+
await using var mcpClient = await ConnectMcpClient();
8988

9089
mcpClient.RegisterNotificationHandler("test/notification", (args, ca) =>
9190
{
@@ -109,14 +108,14 @@ public async Task AddMcpServer_CanBeCalled_MultipleTimes()
109108

110109
Builder.Services.AddMcpServer(options =>
111110
{
112-
Interlocked.Increment(ref firstOptionsCallbackCallCount);
111+
firstOptionsCallbackCallCount++;
113112
})
114113
.WithHttpTransport()
115114
.WithTools<EchoTool>();
116115

117116
Builder.Services.AddMcpServer(options =>
118117
{
119-
Interlocked.Increment(ref secondOptionsCallbackCallCount);
118+
secondOptionsCallbackCallCount++;
120119
})
121120
.WithTools<SampleLlmTool>();
122121

@@ -125,7 +124,7 @@ public async Task AddMcpServer_CanBeCalled_MultipleTimes()
125124
app.MapMcp();
126125
await app.StartAsync(TestContext.Current.CancellationToken);
127126

128-
await using var mcpClient = await ConnectMcpClient(HttpClient);
127+
await using var mcpClient = await ConnectMcpClient();
129128

130129
// Options can be lazily initialized, but they must be instantiated by the time an MCP client can finish connecting.
131130
// Callbacks can be called multiple times if configureOptionsAsync is configured, because that uses the IOptionsFactory,
@@ -151,6 +150,78 @@ public async Task AddMcpServer_CanBeCalled_MultipleTimes()
151150
Assert.Equal("hello from client!", textContent.Text);
152151
}
153152

153+
[Fact]
154+
public async Task AdditionalHeaders_AreSent_InGetAndPostRequests()
155+
{
156+
Builder.Services.AddMcpServer()
157+
.WithHttpTransport();
158+
159+
await using var app = Builder.Build();
160+
161+
bool wasGetRequest = false;
162+
bool wasPostRequest = false;
163+
164+
app.Use(next =>
165+
{
166+
return async context =>
167+
{
168+
Assert.Equal("Bearer testToken", context.Request.Headers["Authorize"]);
169+
if (context.Request.Method == HttpMethods.Get)
170+
{
171+
wasGetRequest = true;
172+
}
173+
else if (context.Request.Method == HttpMethods.Post)
174+
{
175+
wasPostRequest = true;
176+
}
177+
await next(context);
178+
};
179+
});
180+
181+
app.MapMcp();
182+
await app.StartAsync(TestContext.Current.CancellationToken);
183+
184+
var sseOptions = new SseClientTransportOptions()
185+
{
186+
Endpoint = new Uri("http://localhost/sse"),
187+
Name = "In-memory Test Server",
188+
AdditionalHeaders = new()
189+
{
190+
["Authorize"] = "Bearer testToken"
191+
},
192+
};
193+
194+
await using var mcpClient = await ConnectMcpClient(transportOptions: sseOptions);
195+
196+
Assert.True(wasGetRequest);
197+
Assert.True(wasPostRequest);
198+
}
199+
200+
[Fact]
201+
public async Task EmptyAdditionalHeadersKey_Throws_InvalidOpearionException()
202+
{
203+
Builder.Services.AddMcpServer()
204+
.WithHttpTransport();
205+
206+
await using var app = Builder.Build();
207+
208+
app.MapMcp();
209+
await app.StartAsync(TestContext.Current.CancellationToken);
210+
211+
var sseOptions = new SseClientTransportOptions()
212+
{
213+
Endpoint = new Uri("http://localhost/sse"),
214+
Name = "In-memory Test Server",
215+
AdditionalHeaders = new()
216+
{
217+
[""] = ""
218+
},
219+
};
220+
221+
var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => ConnectMcpClient(transportOptions: sseOptions));
222+
Assert.Equal("Failed to add header '' with value '' from AdditionalHeaders.", ex.Message);
223+
}
224+
154225
private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints)
155226
{
156227
var loggerFactory = endpoints.ServiceProvider.GetRequiredService<ILoggerFactory>();

tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ public SseClientTransportTests(ITestOutputHelper testOutputHelper)
1717
{
1818
Endpoint = new Uri("http://localhost:8080"),
1919
ConnectionTimeout = TimeSpan.FromSeconds(2),
20-
MaxReconnectAttempts = 3,
21-
ReconnectDelay = TimeSpan.FromMilliseconds(50),
2220
Name = "Test Server",
2321
AdditionalHeaders = new Dictionary<string, string>
2422
{
@@ -76,15 +74,12 @@ public async Task ConnectAsync_Throws_Exception_On_Failure()
7674
mockHttpHandler.RequestHandler = (request) =>
7775
{
7876
retries++;
79-
throw new InvalidOperationException("Test exception");
77+
throw new Exception("Test exception");
8078
};
8179

82-
var action = async () => await transport.ConnectAsync();
83-
84-
var exception = await Assert.ThrowsAsync<McpTransportException>(action);
85-
Assert.Equal("Exceeded reconnect limit", exception.Message);
86-
87-
Assert.Equal(_transportOptions.MaxReconnectAttempts, retries);
80+
var exception = await Assert.ThrowsAsync<Exception>(() => transport.ConnectAsync(TestContext.Current.CancellationToken));
81+
Assert.Equal("Test exception", exception.Message);
82+
Assert.Equal(1, retries);
8883
}
8984

9085
[Fact]

0 commit comments

Comments
 (0)