Skip to content
Closed
2 changes: 1 addition & 1 deletion global.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"sdk": {
"version": "9.0.100",
"version": "8.0.115",
"rollForward": "minor"
}
}
163 changes: 163 additions & 0 deletions src/ModelContextProtocol/Client/AutoDetectingClientSessionTransport.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using ModelContextProtocol.Protocol;
using System.Diagnostics;
using System.Threading.Channels;

namespace ModelContextProtocol.Client;

/// <summary>
/// A transport that automatically detects whether to use Streamable HTTP or SSE transport
/// by trying Streamable HTTP first and falling back to SSE if that fails.
/// </summary>
internal sealed partial class AutoDetectingClientSessionTransport : ITransport
{
private readonly SseClientTransportOptions _options;
private readonly HttpClient _httpClient;
private readonly ILoggerFactory? _loggerFactory;
private readonly ILogger _logger;
private readonly string _name;
private readonly DelegatingChannelReader<JsonRpcMessage> _delegatingChannelReader;

private StreamableHttpClientSessionTransport? _streamableHttpTransport;
private SseClientSessionTransport? _sseTransport;

public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName)
{
Throw.IfNull(transportOptions);
Throw.IfNull(httpClient);

_options = transportOptions;
_httpClient = httpClient;
_loggerFactory = loggerFactory;
_logger = (ILogger?)loggerFactory?.CreateLogger<AutoDetectingClientSessionTransport>() ?? NullLogger.Instance;
_name = endpointName;
_delegatingChannelReader = new DelegatingChannelReader<JsonRpcMessage>(this);
}

/// <summary>
/// Returns the active transport (either StreamableHttp or SSE)
/// </summary>
internal ITransport? ActiveTransport => _streamableHttpTransport != null ? (ITransport)_streamableHttpTransport : _sseTransport;

public ChannelReader<JsonRpcMessage> MessageReader => _delegatingChannelReader;

/// <inheritdoc/>
public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
if (_streamableHttpTransport == null && _sseTransport == null)
{
var rpcRequest = message as JsonRpcRequest;

// Try StreamableHttp first
_streamableHttpTransport = new StreamableHttpClientSessionTransport(_options, _httpClient, _loggerFactory, _name);

try
{
LogAttemptingStreamableHttp(_name);
var response = await _streamableHttpTransport.SendInitialRequestAsync(message, cancellationToken).ConfigureAwait(false);

// If the status code is not success, fall back to SSE
if (!response.IsSuccessStatusCode)
{
LogStreamableHttpFailed(_name, response.StatusCode);

await _streamableHttpTransport.DisposeAsync().ConfigureAwait(false);
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should dispose this in a finally block in case SendInitialRequestAsync throws.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added proper disposal in a finally block in commit 57be8df to ensure resources are cleaned up if SendInitialRequestAsync throws.

await InitializeSseTransportAsync(message, cancellationToken).ConfigureAwait(false);
return;
}

// Process the streamable HTTP response using the transport
await _streamableHttpTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);

// Signal that we have established a connection
LogUsingStreamableHttp(_name);
_delegatingChannelReader.SetConnected();
}
catch (Exception ex)
{
LogStreamableHttpException(_name, ex);

await _streamableHttpTransport.DisposeAsync().ConfigureAwait(false);

// Propagate the original exception
throw;
}
}
else if (_streamableHttpTransport != null)
{
await _streamableHttpTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
}
else if (_sseTransport != null)
{
await _sseTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
}
}

private async Task InitializeSseTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken)
{
_sseTransport = new SseClientSessionTransport(_options, _httpClient, _loggerFactory, _name);

try
{
LogAttemptingSSE(_name);
await _sseTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
await _sseTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);

// Signal that we have established a connection
LogUsingSSE(_name);
_delegatingChannelReader.SetConnected();
}
catch (Exception ex)
{
LogSSEConnectionFailed(_name, ex);
_delegatingChannelReader.SetError(ex);
await _sseTransport.DisposeAsync().ConfigureAwait(false);
throw;
}
}

public async ValueTask DisposeAsync()
{
try
{
if (_streamableHttpTransport != null)
{
await _streamableHttpTransport.DisposeAsync().ConfigureAwait(false);
}

if (_sseTransport != null)
{
await _sseTransport.DisposeAsync().ConfigureAwait(false);
}
}
catch (Exception ex)
{
LogDisposeFailed(_name, ex);
}
}

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Attempting to connect using Streamable HTTP transport.")]
private partial void LogAttemptingStreamableHttp(string endpointName);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Streamable HTTP transport failed with status code {StatusCode}, falling back to SSE transport.")]
private partial void LogStreamableHttpFailed(string endpointName, System.Net.HttpStatusCode statusCode);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Streamable HTTP transport failed with exception, falling back to SSE transport.")]
private partial void LogStreamableHttpException(string endpointName, Exception exception);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Using Streamable HTTP transport.")]
private partial void LogUsingStreamableHttp(string endpointName);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Attempting to connect using SSE transport.")]
private partial void LogAttemptingSSE(string endpointName);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Using SSE transport.")]
private partial void LogUsingSSE(string endpointName);

[LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName}: Failed to connect using both Streamable HTTP and SSE transports.")]
private partial void LogSSEConnectionFailed(string endpointName, Exception exception);

[LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName}: Error disposing transport.")]
private partial void LogDisposeFailed(string endpointName, Exception exception);
}
149 changes: 149 additions & 0 deletions src/ModelContextProtocol/Client/DelegatingChannelReader.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading.Channels;

namespace ModelContextProtocol.Client;

/// <summary>
/// A <see cref="ChannelReader{T}"/> implementation that delegates to another reader
/// after a connection has been established.
/// </summary>
/// <typeparam name="T">The type of data in the channel.</typeparam>
internal sealed class DelegatingChannelReader<T> : ChannelReader<T>
{
private readonly TaskCompletionSource<bool> _connectionEstablished;
private readonly AutoDetectingClientSessionTransport _parent;

public DelegatingChannelReader(AutoDetectingClientSessionTransport parent)
{
_parent = parent;
_connectionEstablished = new TaskCompletionSource<bool>();
}

/// <summary>
/// Signals that the transport has been established and operations can proceed.
/// </summary>
public void SetConnected()
{
_connectionEstablished.TrySetResult(true);
}

/// <summary>
/// Sets the error if connection couldn't be established.
/// </summary>
public void SetError(Exception exception)
{
_connectionEstablished.TrySetException(exception);
}

/// <summary>
/// Gets the channel reader to delegate to.
/// </summary>
private ChannelReader<T> GetReader()
{
if (_connectionEstablished.Task.Status != TaskStatus.RanToCompletion)
{
throw new InvalidOperationException("Transport connection not yet established.");
}

return (_parent.ActiveTransport?.MessageReader as ChannelReader<T>)!;
}

#if !NETSTANDARD2_0
/// <inheritdoc/>
public override bool CanCount => GetReader().CanCount;

/// <inheritdoc/>
public override bool CanPeek => GetReader().CanPeek;

/// <inheritdoc/>
public override int Count => GetReader().Count;
#endif

/// <inheritdoc/>
public override bool TryPeek(out T item)
{
if (_connectionEstablished.Task.Status != TaskStatus.RanToCompletion)
{
item = default!;
return false;
}

return GetReader().TryPeek(out item!);
}

/// <inheritdoc/>
public override bool TryRead(out T item)
{
if (_connectionEstablished.Task.Status != TaskStatus.RanToCompletion)
{
item = default!;
return false;
}

return GetReader().TryRead(out item!);
}

/// <inheritdoc/>
public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default)
{
// First wait for the connection to be established
if (_connectionEstablished.Task.Status != TaskStatus.RanToCompletion)
{
return new ValueTask<bool>(WaitForConnectionAndThenReadAsync(cancellationToken));
}

// Then delegate to the active reader
return GetReader().WaitToReadAsync(cancellationToken);
}

private async Task<bool> WaitForConnectionAndThenReadAsync(CancellationToken cancellationToken)
{
await _connectionEstablished.Task.ConfigureAwait(false);
return await GetReader().WaitToReadAsync(cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc/>
public override ValueTask<T> ReadAsync(CancellationToken cancellationToken = default)
{
// First wait for the connection to be established
if (_connectionEstablished.Task.Status != TaskStatus.RanToCompletion)
{
return new ValueTask<T>(WaitForConnectionAndThenGetItemAsync(cancellationToken));
}

// Then delegate to the active reader
return GetReader().ReadAsync(cancellationToken);
}

private async Task<T> WaitForConnectionAndThenGetItemAsync(CancellationToken cancellationToken)
{
await _connectionEstablished.Task.ConfigureAwait(false);
return await GetReader().ReadAsync(cancellationToken).ConfigureAwait(false);
}

#if NETSTANDARD2_0
public IAsyncEnumerable<T> ReadAllAsync(CancellationToken cancellationToken = default)
{
// Create a simple async enumerable implementation
async IAsyncEnumerable<T> ReadAllAsyncImplementation()
{
while (await WaitToReadAsync(cancellationToken).ConfigureAwait(false))
{
while (TryRead(out var item))
{
yield return item;
}
}
}

return ReadAllAsyncImplementation();
}
#else
/// <inheritdoc/>
public override IAsyncEnumerable<T> ReadAllAsync(CancellationToken cancellationToken = default)
{
return base.ReadAllAsync(cancellationToken);
}
#endif
}
23 changes: 23 additions & 0 deletions src/ModelContextProtocol/Client/HttpTransportMode.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
namespace ModelContextProtocol.Client;

/// <summary>
/// Specifies the transport mode for HTTP client connections.
/// </summary>
public enum HttpTransportMode
{
/// <summary>
/// Automatically detect the appropriate transport by trying Streamable HTTP first, then falling back to SSE if that fails.
/// This is the recommended mode for maximum compatibility.
/// </summary>
AutoDetect,

/// <summary>
/// Use only the Streamable HTTP transport.
/// </summary>
StreamableHttp,

/// <summary>
/// Use only the HTTP with SSE transport.
/// </summary>
Sse
}
17 changes: 15 additions & 2 deletions src/ModelContextProtocol/Client/SseClientTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,24 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient
/// <inheritdoc />
public async Task<ITransport> ConnectAsync(CancellationToken cancellationToken = default)
{
if (_options.UseStreamableHttp)
switch (_options.TransportMode)
{
return new StreamableHttpClientSessionTransport(_options, _httpClient, _loggerFactory, Name);
default:
throw new ArgumentException($"Unsupported transport mode: {_options.TransportMode}", nameof(_options.TransportMode));

case HttpTransportMode.AutoDetect:
return new AutoDetectingClientSessionTransport(_options, _httpClient, _loggerFactory, Name);

case HttpTransportMode.StreamableHttp:
return new StreamableHttpClientSessionTransport(_options, _httpClient, _loggerFactory, Name);

case HttpTransportMode.Sse:
return await ConnectSseTransportAsync(cancellationToken).ConfigureAwait(false);
}
}

private async Task<ITransport> ConnectSseTransportAsync(CancellationToken cancellationToken)
{
var sessionTransport = new SseClientSessionTransport(_options, _httpClient, _loggerFactory, Name);

try
Expand Down
Loading