-
Notifications
You must be signed in to change notification settings - Fork 670
Expand file tree
/
Copy pathStreamableHttpClientSessionTransport.cs
More file actions
558 lines (476 loc) · 22.3 KB
/
StreamableHttpClientSessionTransport.cs
File metadata and controls
558 lines (476 loc) · 22.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using System.Diagnostics;
using System.Net.Http.Headers;
using System.Net.ServerSentEvents;
using System.Text.Json;
using ModelContextProtocol.Protocol;
using System.Threading.Channels;
using System.Net;
namespace ModelContextProtocol.Client;
/// <summary>
/// The Streamable HTTP client transport implementation
/// </summary>
/// <remarks>
/// Outgoing messages are sent as HTTP POST requests to the configured endpoint. A successful POST response body is
/// either a single <c>application/json</c> message or a <c>text/event-stream</c> SSE stream; every JSON-RPC message read
/// from it is written to the transport's message channel. After a successful <c>initialize</c> (or immediately when a
/// known session ID is configured), a background HTTP GET SSE loop is started to receive unsolicited server messages.
/// On dispose, an HTTP DELETE is sent when the transport owns the session and has a session ID.
/// </remarks>
internal sealed partial class StreamableHttpClientSessionTransport : TransportBase
{
    private static readonly MediaTypeWithQualityHeaderValue s_applicationJsonMediaType = new("application/json");
    private static readonly MediaTypeWithQualityHeaderValue s_textEventStreamMediaType = new("text/event-stream");

    private readonly McpHttpClient _httpClient;
    private readonly HttpClientTransportOptions _options;

    // Canceled on dispose or on session expiry; linked into every POST and drives the background GET SSE loop.
    private readonly CancellationTokenSource _connectionCts = new();
    private readonly ILogger _logger;

    // Protocol version taken from the InitializeResult; sent as the MCP-Protocol-Version header afterwards.
    private string? _negotiatedProtocolVersion;

    // Background GET SSE receive loop; null until initialize succeeds or a KnownSessionId is supplied.
    private Task? _getReceiveTask;

    // Set by SetSessionExpired when the server answers a request carrying a session ID with 404.
    private volatile ClientTransportClosedException? _disconnectError;

    // Serializes DisposeAsync so the shutdown sequence runs only once.
    private readonly SemaphoreSlim _disposeLock = new(1, 1);
    private bool _disposed;

    /// <summary>
    /// Creates the transport and marks it connected. If <see cref="HttpClientTransportOptions.KnownSessionId"/> is set,
    /// the session ID is adopted and the background GET SSE receive loop is started immediately.
    /// </summary>
    /// <param name="endpointName">Name used to identify this endpoint in logs.</param>
    /// <param name="transportOptions">Transport configuration (endpoint, headers, reconnection settings, session ownership).</param>
    /// <param name="httpClient">Client used to send all HTTP requests.</param>
    /// <param name="messageChannel">Optional channel to write received messages to; passed through to the base transport.</param>
    /// <param name="loggerFactory">Optional logger factory; a null logger is used when absent.</param>
    public StreamableHttpClientSessionTransport(
        string endpointName,
        HttpClientTransportOptions transportOptions,
        McpHttpClient httpClient,
        Channel<JsonRpcMessage>? messageChannel,
        ILoggerFactory? loggerFactory)
        : base(endpointName, messageChannel, loggerFactory)
    {
        Throw.IfNull(transportOptions);
        Throw.IfNull(httpClient);

        _options = transportOptions;
        _httpClient = httpClient;
        _logger = (ILogger?)loggerFactory?.CreateLogger<HttpClientTransport>() ?? NullLogger.Instance;

        // We connect with the initialization request with the MCP transport. This means that any errors won't be observed
        // until the first call to SendMessageAsync. Fortunately, that happens internally in McpClient.ConnectAsync
        // so we still throw any connection-related Exceptions from there and never expose a pre-connected client to the user.
        SetConnected();

        if (_options.KnownSessionId is { } knownSessionId)
        {
            SessionId = knownSessionId;
            _getReceiveTask = ReceiveUnsolicitedMessagesAsync();
        }
    }

    /// <inheritdoc/>
    public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
    {
        // Immediately dispose the response. SendHttpRequestAsync only returns the response so the auto transport can look at it.
        using var response = await SendHttpRequestAsync(message, cancellationToken).ConfigureAwait(false);
        await response.EnsureSuccessStatusCodeWithResponseBodyAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// POSTs <paramref name="message"/> to the endpoint and processes the response body.
    /// </summary>
    /// <remarks>
    /// This is used by the auto transport so it can fall back and try SSE given a non-200 response without catching an exception.
    /// Non-success responses are returned unprocessed. The caller owns (and must dispose) the returned response; if this
    /// method throws after a response was received, the response is disposed here.
    /// </remarks>
    /// <exception cref="InvalidOperationException"><c>initialize</c> was sent while a known session ID is configured.</exception>
    /// <exception cref="McpException">A successful response to a request completed without the matching reply.</exception>
    internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage message, CancellationToken cancellationToken)
    {
        if (_options.KnownSessionId is not null &&
            message is JsonRpcRequest { Method: RequestMethods.Initialize })
        {
            throw new InvalidOperationException(
                $"Cannot send '{RequestMethods.Initialize}' when {nameof(HttpClientTransportOptions)}.{nameof(HttpClientTransportOptions.KnownSessionId)} is configured. " +
                $"Call {nameof(McpClient)}.{nameof(McpClient.ResumeSessionAsync)} to resume existing sessions.");
        }

        LogTransportSendingMessageSensitive(message);

        using var sendCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _connectionCts.Token);
        cancellationToken = sendCts.Token;

        using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _options.Endpoint)
        {
            Headers =
            {
                Accept = { s_applicationJsonMediaType, s_textEventStreamMediaType },
            },
        };

        CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion);

        HttpResponseMessage response;
        try
        {
            response = await _httpClient.SendAsync(httpRequestMessage, message, cancellationToken).ConfigureAwait(false);
        }
        catch (Exception ex) when (ex is not OperationCanceledException)
        {
            LogHttpPostRequestFailed(Name, ex);
            throw;
        }

        // We'll let the caller decide whether to throw or fall back given an unsuccessful response.
        if (!response.IsSuccessStatusCode)
        {
            LogHttpPostNonSuccessStatusCode(Name, (int)response.StatusCode);

            // Per the MCP spec, a 404 response to a request containing an Mcp-Session-Id
            // indicates the session has ended. Signal completion so McpClient.Completion resolves.
            if (response.StatusCode == HttpStatusCode.NotFound && SessionId is not null)
            {
                SetSessionExpired();
            }

            return response;
        }

        try
        {
            await ProcessPostResponseAsync(response, message as JsonRpcRequest, cancellationToken).ConfigureAwait(false);
        }
        catch
        {
            // The caller only receives (and disposes) the response when we return it, so release it here
            // before propagating; otherwise the underlying connection/stream would leak.
            response.Dispose();
            throw;
        }

        return response;
    }

    /// <summary>
    /// Reads the body of a successful POST response, forwarding each contained JSON-RPC message to the message channel.
    /// For requests, verifies a matching reply arrived and, for <c>initialize</c>, captures the session ID and
    /// negotiated protocol version and starts the background GET SSE loop.
    /// </summary>
    /// <param name="response">The successful POST response.</param>
    /// <param name="rpcRequest">The request that was sent, or null if the POSTed message was not a request.</param>
    /// <param name="cancellationToken">Token linked to the transport's connection lifetime.</param>
    /// <exception cref="McpException"><paramref name="rpcRequest"/> is non-null and no matching reply was received.</exception>
    private async Task ProcessPostResponseAsync(HttpResponseMessage response, JsonRpcRequest? rpcRequest, CancellationToken cancellationToken)
    {
        JsonRpcMessageWithId? rpcResponseOrError = null;

        if (HasMediaType(response, "application/json"))
        {
            var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
            rpcResponseOrError = await ProcessMessageAsync(responseContent, rpcRequest, cancellationToken).ConfigureAwait(false);
        }
        else if (HasMediaType(response, "text/event-stream"))
        {
            var sseState = new SseStreamState();
            using var responseBodyStream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
            var sseResponse = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, sseState, cancellationToken).ConfigureAwait(false);
            rpcResponseOrError = sseResponse.Response;

            // Resumability: If POST SSE stream ended without a response but we have a Last-Event-ID (from priming),
            // attempt to resume by sending a GET request with Last-Event-ID header. The server will replay
            // events from the event store, allowing us to receive the pending response.
            if (rpcResponseOrError is null && rpcRequest is not null && sseState.LastEventId is not null)
            {
                rpcResponseOrError = await SendGetSseRequestWithRetriesAsync(rpcRequest, sseState, cancellationToken).ConfigureAwait(false);
            }
        }

        // Only requests wait for a reply; other message kinds are done once the body has been processed.
        if (rpcRequest is null)
        {
            return;
        }

        if (rpcResponseOrError is null)
        {
            throw new McpException($"Streamable HTTP POST response completed without a reply to request with ID: {rpcRequest.Id}");
        }

        if (rpcRequest.Method == RequestMethods.Initialize && rpcResponseOrError is JsonRpcResponse initResponse)
        {
            // We've successfully initialized! Copy session-id and protocol version, then start GET request if any.
            if (response.Headers.TryGetValues("Mcp-Session-Id", out var sessionIdValues))
            {
                SessionId = sessionIdValues.FirstOrDefault();
            }

            var initializeResult = JsonSerializer.Deserialize(initResponse.Result, McpJsonUtilities.JsonContext.Default.InitializeResult);
            _negotiatedProtocolVersion = initializeResult?.ProtocolVersion;

            _getReceiveTask ??= ReceiveUnsolicitedMessagesAsync();
        }
    }

    /// <summary>
    /// Returns whether the response content's media type equals <paramref name="mediaType"/>.
    /// Media types are case-insensitive, so the comparison ignores case.
    /// </summary>
    private static bool HasMediaType(HttpResponseMessage response, string mediaType) =>
        string.Equals(response.Content.Headers.ContentType?.MediaType, mediaType, StringComparison.OrdinalIgnoreCase);

    /// <summary>
    /// Shuts the transport down once. It cancels in-flight work and sends a DELETE when the session is owned and has
    /// an ID. It then waits for the background GET SSE loop and finally marks the transport disconnected.
    /// </summary>
    /// <remarks>
    /// Cancellation during shutdown is ignored; other shutdown failures are logged rather than thrown.
    /// </remarks>
    public override async ValueTask DisposeAsync()
    {
        using var _ = await _disposeLock.LockAsync().ConfigureAwait(false);

        if (_disposed)
        {
            return;
        }
        _disposed = true;

        try
        {
            await _connectionCts.CancelAsync().ConfigureAwait(false);

            try
            {
                // Send DELETE request to terminate the session. Only send if we have a session ID, per MCP spec.
                if (_options.OwnsSession && !string.IsNullOrEmpty(SessionId))
                {
                    await SendDeleteRequest().ConfigureAwait(false);
                }

                if (_getReceiveTask != null)
                {
                    await _getReceiveTask.ConfigureAwait(false);
                }
            }
            catch (OperationCanceledException)
            {
            }
            catch (Exception ex)
            {
                LogTransportShutdownFailed(Name, ex);
            }
        }
        finally
        {
            // If we're auto-detecting the transport and failed to connect, leave the message Channel open for the SSE transport.
            // This class isn't directly exposed to public callers, so we don't have to worry about changing the _state in this case.
            if (_options.TransportMode is not HttpTransportMode.AutoDetect || _getReceiveTask is not null)
            {
                // _disconnectError is set when the server returns 404 indicating session expiry.
                // When null, this is a graceful client-initiated closure (no error).
                SetDisconnected(_disconnectError ?? new ClientTransportClosedException(new HttpClientCompletionDetails()));
            }
        }
    }

    /// <summary>
    /// Background loop that keeps a GET SSE stream open to receive server-initiated messages. It runs until the
    /// connection is canceled or disconnected, or until a round of attempts ends without any event ID ever having
    /// been seen.
    /// </summary>
    private async Task ReceiveUnsolicitedMessagesAsync()
    {
        var state = new SseStreamState();

        // Continuously receive unsolicited messages until canceled or disconnected
        while (!_connectionCts.Token.IsCancellationRequested && IsConnected)
        {
            await SendGetSseRequestWithRetriesAsync(
                relatedRpcRequest: null,
                state,
                _connectionCts.Token).ConfigureAwait(false);

            // If we exhausted retries without receiving any events, stop trying
            if (state.LastEventId is null)
            {
                return;
            }
        }
    }

    /// <summary>
    /// Sends a GET request for SSE with retry logic and resumability support.
    /// </summary>
    /// <param name="relatedRpcRequest">Request whose reply is awaited, or null when only receiving unsolicited messages.</param>
    /// <param name="state">Event ID / retry state shared across reconnections; updated in place.</param>
    /// <param name="cancellationToken">Token that aborts delays and requests.</param>
    /// <returns>The reply matching <paramref name="relatedRpcRequest"/>, or null if none was received.</returns>
    private async Task<JsonRpcMessageWithId?> SendGetSseRequestWithRetriesAsync(
        JsonRpcRequest? relatedRpcRequest,
        SseStreamState state,
        CancellationToken cancellationToken)
    {
        // When LastEventId is null, the first attempt is the initial GET SSE connection (not a reconnection),
        // so we start at -1 to avoid counting it against MaxReconnectionAttempts.
        // When LastEventId is already set, all attempts are true reconnections, so we start at 0.
        int attempt = state.LastEventId is null ? -1 : 0;

        // Delay before first attempt if we're reconnecting (have a Last-Event-ID)
        bool shouldDelay = state.LastEventId is not null;

        while (attempt < _options.MaxReconnectionAttempts)
        {
            cancellationToken.ThrowIfCancellationRequested();

            if (shouldDelay)
            {
                var delay = state.RetryInterval ?? _options.DefaultReconnectionInterval;

                // Subtract time already elapsed since the SSE stream ended to more accurately
                // honor the retry interval. Without this, processing overhead (HTTP response
                // disposal, condition checks, etc.) inflates the observed reconnection delay.
                // The timestamp is consumed here: if this attempt fails without opening a new stream
                // (connection error or 5xx), the next retry must wait the full interval instead of
                // subtracting the ever-growing time since the old stream ended, which would collapse
                // every later delay to zero and exhaust the retries with no backoff.
                if (state.StreamEndedTimestamp != 0)
                {
                    delay -= ElapsedSince(state.StreamEndedTimestamp);
                    state.StreamEndedTimestamp = 0;
                }

                if (delay > TimeSpan.Zero)
                {
                    await Task.Delay(delay, cancellationToken).ConfigureAwait(false);
                }
            }
            shouldDelay = true;

            using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint);
            request.Headers.Accept.Add(s_textEventStreamMediaType);
            CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion, state.LastEventId);

            HttpResponseMessage response;
            try
            {
                response = await _httpClient.SendAsync(request, message: null, cancellationToken).ConfigureAwait(false);
            }
            catch (HttpRequestException ex)
            {
                LogHttpGetSseRequestFailed(Name, ex);
                attempt++;
                continue;
            }

            using (response)
            {
                if (response.StatusCode >= HttpStatusCode.InternalServerError)
                {
                    // Server error; retry.
                    LogHttpGetSseNonSuccessStatusCode(Name, (int)response.StatusCode);
                    attempt++;
                    continue;
                }

                if (!response.IsSuccessStatusCode)
                {
                    LogHttpGetSseNonSuccessStatusCode(Name, (int)response.StatusCode);

                    // Per the MCP spec, a 404 response to a request containing an Mcp-Session-Id
                    // indicates the session has ended. Signal completion so McpClient.Completion resolves.
                    if (response.StatusCode == HttpStatusCode.NotFound && SessionId is not null)
                    {
                        SetSessionExpired();
                    }

                    // If the server could be reached but returned a non-success status code,
                    // retrying likely won't change that.
                    return null;
                }

                using var responseStream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
                var sseResponse = await ProcessSseResponseAsync(responseStream, relatedRpcRequest, state, cancellationToken).ConfigureAwait(false);

                if (sseResponse.Response is { } rpcResponseOrError)
                {
                    return rpcResponseOrError;
                }

                // If we reach here, then the stream closed without the response.

                if (sseResponse.IsNetworkError || state.LastEventId is null)
                {
                    // No event ID means server may not support resumability; don't retry indefinitely.
                    attempt++;
                }
                else
                {
                    // We have an event ID, so we continue polling to receive more events.
                    // The server should eventually send a response or return an error.
                    attempt = 0;
                }
            }
        }

        return null;
    }

    /// <summary>
    /// Reads SSE events from <paramref name="responseStream"/>, recording event IDs and retry intervals in
    /// <paramref name="state"/> and forwarding each non-empty event's data as a JSON-RPC message.
    /// </summary>
    /// <returns>
    /// The reply matching <paramref name="relatedRpcRequest"/> as soon as it is seen; otherwise an empty result when the
    /// stream ends, flagged with <see cref="SseResponse.IsNetworkError"/> if it ended with an I/O or HTTP error.
    /// </returns>
    private async Task<SseResponse> ProcessSseResponseAsync(
        Stream responseStream,
        JsonRpcRequest? relatedRpcRequest,
        SseStreamState state,
        CancellationToken cancellationToken)
    {
        try
        {
            await foreach (SseItem<string> sseEvent in SseParser.Create(responseStream).EnumerateAsync(cancellationToken).ConfigureAwait(false))
            {
                // Track event ID and retry interval for resumability
                if (!string.IsNullOrEmpty(sseEvent.EventId))
                {
                    state.LastEventId = sseEvent.EventId;
                }
                if (sseEvent.ReconnectionInterval.HasValue)
                {
                    state.RetryInterval = sseEvent.ReconnectionInterval.Value;
                }

                // Skip events with empty data
                if (string.IsNullOrEmpty(sseEvent.Data))
                {
                    continue;
                }

                var rpcResponseOrError = await ProcessMessageAsync(sseEvent.Data, relatedRpcRequest, cancellationToken).ConfigureAwait(false);
                if (rpcResponseOrError is not null)
                {
                    return new() { Response = rpcResponseOrError };
                }
            }
        }
        catch (Exception ex) when (ex is IOException or HttpRequestException)
        {
            state.StreamEndedTimestamp = Stopwatch.GetTimestamp();
            return new() { IsNetworkError = true };
        }

        state.StreamEndedTimestamp = Stopwatch.GetTimestamp();
        return default;
    }

    /// <summary>
    /// Deserializes <paramref name="data"/> as a JSON-RPC message and writes it to the message channel.
    /// </summary>
    /// <returns>
    /// The message if it is a response or error whose ID matches <paramref name="relatedRpcRequest"/>; otherwise null.
    /// Malformed JSON is logged and yields null.
    /// </returns>
    private async Task<JsonRpcMessageWithId?> ProcessMessageAsync(string data, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken)
    {
        LogTransportReceivedMessageSensitive(Name, data);

        try
        {
            var message = JsonSerializer.Deserialize(data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage);
            if (message is null)
            {
                LogTransportMessageParseUnexpectedTypeSensitive(Name, data);
                return null;
            }

            await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);

            if (message is JsonRpcResponse or JsonRpcError &&
                message is JsonRpcMessageWithId rpcResponseOrError &&
                rpcResponseOrError.Id == relatedRpcRequest?.Id)
            {
                return rpcResponseOrError;
            }
        }
        catch (JsonException ex)
        {
            LogJsonException(ex, data);
        }

        return null;
    }

    /// <summary>
    /// Sends an HTTP DELETE to terminate the session. Failures and non-success status codes are logged, not thrown.
    /// </summary>
    private async Task SendDeleteRequest()
    {
        using var deleteRequest = new HttpRequestMessage(HttpMethod.Delete, _options.Endpoint);
        CopyAdditionalHeaders(deleteRequest.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion);

        HttpResponseMessage response;
        try
        {
            // CancellationToken.None: this runs during dispose, after _connectionCts has already been canceled.
            response = await _httpClient.SendAsync(deleteRequest, message: null, CancellationToken.None).ConfigureAwait(false);
        }
        catch (Exception ex) when (ex is not OperationCanceledException)
        {
            LogHttpDeleteRequestFailed(Name, ex);
            return;
        }

        using (response)
        {
            // Server support for the DELETE request is optional, so a 405 Method Not Allowed is expected.
            if (!response.IsSuccessStatusCode)
            {
                LogHttpDeleteNonSuccessStatusCode(Name, (int)response.StatusCode);
            }
        }
    }

    /// <summary>
    /// Logs a JSON parse failure, including the raw payload only when Trace logging is enabled.
    /// </summary>
    private void LogJsonException(JsonException ex, string data)
    {
        if (_logger.IsEnabled(LogLevel.Trace))
        {
            LogTransportMessageParseFailedSensitive(Name, data, ex);
        }
        else
        {
            LogTransportMessageParseFailed(Name, ex);
        }
    }

    /// <summary>
    /// Adds the MCP session, protocol-version and resumption headers (when non-null) followed by any user-supplied
    /// additional headers to <paramref name="headers"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">An additional header could not be added.</exception>
    internal static void CopyAdditionalHeaders(
        HttpRequestHeaders headers,
        IDictionary<string, string>? additionalHeaders,
        string? sessionId,
        string? protocolVersion,
        string? lastEventId = null)
    {
        if (sessionId is not null)
        {
            headers.Add("Mcp-Session-Id", sessionId);
        }

        if (protocolVersion is not null)
        {
            headers.Add("MCP-Protocol-Version", protocolVersion);
        }

        if (lastEventId is not null)
        {
            headers.Add("Last-Event-ID", lastEventId);
        }

        if (additionalHeaders is null)
        {
            return;
        }

        foreach (var header in additionalHeaders)
        {
            if (!headers.TryAddWithoutValidation(header.Key, header.Value))
            {
                throw new InvalidOperationException($"Failed to add header '{header.Key}' with value '{header.Value}' from {nameof(HttpClientTransportOptions.AdditionalHeaders)}.");
            }
        }
    }

    /// <summary>
    /// Tracks state across SSE stream connections.
    /// </summary>
    private sealed class SseStreamState
    {
        /// <summary>Most recent non-empty SSE event ID; sent as Last-Event-ID when reconnecting.</summary>
        public string? LastEventId { get; set; }

        /// <summary>Most recent server-provided SSE retry interval, if any.</summary>
        public TimeSpan? RetryInterval { get; set; }

        /// <summary>Timestamp (via Stopwatch.GetTimestamp()) when the last SSE stream ended, used to discount processing overhead from the retry delay.</summary>
        public long StreamEndedTimestamp { get; set; }
    }

    /// <summary>
    /// Represents the result of processing an SSE response.
    /// </summary>
    private readonly struct SseResponse
    {
        /// <summary>The reply matching the related request, if one was received.</summary>
        public JsonRpcMessageWithId? Response { get; init; }

        /// <summary>True when the stream ended because of an I/O or HTTP error.</summary>
        public bool IsNetworkError { get; init; }
    }

    /// <summary>
    /// Returns the time elapsed since a <see cref="Stopwatch.GetTimestamp"/> value.
    /// </summary>
    private static TimeSpan ElapsedSince(long stopwatchTimestamp)
    {
#if NET
        return Stopwatch.GetElapsedTime(stopwatchTimestamp);
#else
        return TimeSpan.FromSeconds((double)(Stopwatch.GetTimestamp() - stopwatchTimestamp) / Stopwatch.Frequency);
#endif
    }

    /// <summary>
    /// Records a session-expired error, cancels in-flight operations and marks the transport disconnected.
    /// </summary>
    private void SetSessionExpired()
    {
        // Store the error before canceling so DisposeAsync can use it if it races us, especially
        // after the call to Cancel below, to invoke SetDisconnected.
        _disconnectError = new ClientTransportClosedException(new HttpClientCompletionDetails
        {
            HttpStatusCode = HttpStatusCode.NotFound,
            Exception = new McpException(
                "The server returned HTTP 404 for a request with an Mcp-Session-Id, indicating the session has expired. " +
                "To continue, create a new client session or call ResumeSessionAsync with a new connection."),
        });

        // Cancel to unblock any in-flight operations (e.g., SSE stream reads in
        // SendGetSseRequestWithRetriesAsync) that are waiting on _connectionCts.Token.
        _connectionCts.Cancel();

        SetDisconnected(_disconnectError);
    }

    [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} HTTP POST request failed.")]
    private partial void LogHttpPostRequestFailed(string endpointName, Exception exception);

    [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} HTTP POST received non-success status code {StatusCode}.")]
    private partial void LogHttpPostNonSuccessStatusCode(string endpointName, int statusCode);

    [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} HTTP GET SSE request failed.")]
    private partial void LogHttpGetSseRequestFailed(string endpointName, Exception exception);

    [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} HTTP GET SSE received non-success status code {StatusCode}.")]
    private partial void LogHttpGetSseNonSuccessStatusCode(string endpointName, int statusCode);

    [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} HTTP DELETE request failed.")]
    private partial void LogHttpDeleteRequestFailed(string endpointName, Exception exception);

    [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} HTTP DELETE received non-success status code {StatusCode}.")]
    private partial void LogHttpDeleteNonSuccessStatusCode(string endpointName, int statusCode);
}