Skip to content
Closed
2 changes: 1 addition & 1 deletion global.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"sdk": {
"version": "9.0.100",
"version": "8.0.115",
"rollForward": "minor"
}
}
196 changes: 196 additions & 0 deletions src/ModelContextProtocol/Client/AutoDetectingClientSessionTransport.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using ModelContextProtocol.Protocol;
using System.Diagnostics;
using System.Threading.Channels;

namespace ModelContextProtocol.Client;

/// <summary>
/// A transport that automatically detects whether to use Streamable HTTP or SSE transport
/// by trying Streamable HTTP first and falling back to SSE if that fails.
/// </summary>
internal sealed partial class AutoDetectingClientSessionTransport : ITransport
{
private readonly SseClientTransportOptions _options;
private readonly HttpClient _httpClient;
private readonly ILoggerFactory? _loggerFactory;
private readonly ILogger _logger;
private readonly string _name;
private readonly DelegatingChannelReader<JsonRpcMessage> _delegatingChannelReader;

private StreamableHttpClientSessionTransport? _streamableHttpTransport;
private SseClientSessionTransport? _sseTransport;

public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName)
{
Throw.IfNull(transportOptions);
Throw.IfNull(httpClient);

_options = transportOptions;
_httpClient = httpClient;
_loggerFactory = loggerFactory;
_logger = (ILogger?)loggerFactory?.CreateLogger<AutoDetectingClientSessionTransport>() ?? NullLogger.Instance;
_name = endpointName;
_delegatingChannelReader = new DelegatingChannelReader<JsonRpcMessage>(this);
}

/// <summary>
/// Returns the active transport (either StreamableHttp or SSE)
/// </summary>
internal ITransport? ActiveTransport => _streamableHttpTransport != null ? (ITransport)_streamableHttpTransport : _sseTransport;

public ChannelReader<JsonRpcMessage> MessageReader => _delegatingChannelReader;

/// <inheritdoc/>
public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
if (_streamableHttpTransport == null && _sseTransport == null)
{
var rpcRequest = message as JsonRpcRequest;

// Try StreamableHttp first
_streamableHttpTransport = new StreamableHttpClientSessionTransport(_options, _httpClient, _loggerFactory, _name);

try
{
LogAttemptingStreamableHttp(_name);
var response = await _streamableHttpTransport.SendInitialRequestAsync(message, cancellationToken).ConfigureAwait(false);

// If the status code is not success, fall back to SSE
if (!response.IsSuccessStatusCode)
{
LogStreamableHttpFailed(_name, response.StatusCode);

try
{
await _streamableHttpTransport.DisposeAsync().ConfigureAwait(false);
}
finally
{
_streamableHttpTransport = null;
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't bother nulling out _streamableHttpTransport except when nothing throws. We only need to handle the common, non-200 status code failure. Any other failure we should just let the exception bubble up and dispose any transports, but we don't need to assign them to null except in the normal IsSuccessStatusCode case when ._streamableHttpTransport.DisposeAsync() doesn't throw.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the code to only null out _streamableHttpTransport in the successful non-exception case in commit bac8a6b.

await InitializeSseTransportAsync(message, cancellationToken).ConfigureAwait(false);
}
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
try
{
await _streamableHttpTransport.DisposeAsync().ConfigureAwait(false);
}
finally
{
_streamableHttpTransport = null;
await InitializeSseTransportAsync(message, cancellationToken).ConfigureAwait(false);
}
await _streamableHttpTransport.DisposeAsync().ConfigureAwait(false);
await InitializeSseTransportAsync(message, cancellationToken).ConfigureAwait(false);

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplified the code to remove unnecessary try/finally block in commit 6df9961.

return;
}

// Process the streamable HTTP response using the transport
await _streamableHttpTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);

// Signal that we have established a connection
LogUsingStreamableHttp(_name);
_delegatingChannelReader.SetConnected();
}
catch (Exception ex)
{
LogStreamableHttpException(_name, ex);

try
{
if (_streamableHttpTransport != null)
{
await _streamableHttpTransport.DisposeAsync().ConfigureAwait(false);
_streamableHttpTransport = null;
}
}
catch (Exception disposeEx)
{
LogDisposeFailed(_name, disposeEx);
}
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
LogStreamableHttpException(_name, ex);
try
{
if (_streamableHttpTransport != null)
{
await _streamableHttpTransport.DisposeAsync().ConfigureAwait(false);
_streamableHttpTransport = null;
}
}
catch (Exception disposeEx)
{
LogDisposeFailed(_name, disposeEx);
}
await _streamableHttpTransport.DisposeAsync().ConfigureAwait(false);

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the redundant code to simplify the exception handling in commit 6df9961.


// Propagate the original exception
throw;
}
}
else if (_streamableHttpTransport != null)
{
await _streamableHttpTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
}
else if (_sseTransport != null)
{
await _sseTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
}
}

private async Task InitializeSseTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken)
{
Exception? capturedEx = null;
try
{
LogAttemptingSSE(_name);
_sseTransport = new SseClientSessionTransport(_options, _httpClient, _loggerFactory, _name);
await _sseTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
await _sseTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);

// Signal that we have established a connection
LogUsingSSE(_name);
_delegatingChannelReader.SetConnected();
}
catch (Exception ex)
{
LogSSEConnectionFailed(_name, ex);
capturedEx = ex;

try
{
if (_sseTransport != null)
{
await _sseTransport.DisposeAsync().ConfigureAwait(false);
}
}
finally
{
// Set the error so the channel reader will propagate it
_delegatingChannelReader.SetError(ex);
}

throw;
}
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Exception? capturedEx = null;
try
{
LogAttemptingSSE(_name);
_sseTransport = new SseClientSessionTransport(_options, _httpClient, _loggerFactory, _name);
await _sseTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
await _sseTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
// Signal that we have established a connection
LogUsingSSE(_name);
_delegatingChannelReader.SetConnected();
}
catch (Exception ex)
{
LogSSEConnectionFailed(_name, ex);
capturedEx = ex;
try
{
if (_sseTransport != null)
{
await _sseTransport.DisposeAsync().ConfigureAwait(false);
}
}
finally
{
// Set the error so the channel reader will propagate it
_delegatingChannelReader.SetError(ex);
}
throw;
}
_sseTransport = new SseClientSessionTransport(_options, _httpClient, _loggerFactory, _name);
try
{
LogAttemptingSSE(_name);
await _sseTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
await _sseTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
// Signal that we have established a connection
LogUsingSSE(_name);
_delegatingChannelReader.SetConnected();
}
catch (Exception ex)
{
LogSSEConnectionFailed(_name, ex);
_delegatingChannelReader.SetError(ex);
await _sseTransport.DisposeAsync().ConfigureAwait(false);
throw;
}

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplified the SSE transport initialization in commit 6df9961.

}

public async ValueTask DisposeAsync()
{
try
{
if (_streamableHttpTransport != null)
{
await _streamableHttpTransport.DisposeAsync().ConfigureAwait(false);
_streamableHttpTransport = null;
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_streamableHttpTransport = null;

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the unnecessary null assignment in commit 6df9961.

}

if (_sseTransport != null)
{
await _sseTransport.DisposeAsync().ConfigureAwait(false);
_sseTransport = null;
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_sseTransport = null;

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the unnecessary null assignment in commit 6df9961.

}
}
catch (Exception ex)
{
LogDisposeFailed(_name, ex);
}
}

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Attempting to connect using Streamable HTTP transport")]
private partial void LogAttemptingStreamableHttp(string endpointName);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Streamable HTTP transport failed with status code {StatusCode}, falling back to SSE transport")]
private partial void LogStreamableHttpFailed(string endpointName, System.Net.HttpStatusCode statusCode);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Streamable HTTP transport failed with exception, falling back to SSE transport")]
private partial void LogStreamableHttpException(string endpointName, Exception exception);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Using Streamable HTTP transport")]
private partial void LogUsingStreamableHttp(string endpointName);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Attempting to connect using SSE transport")]
private partial void LogAttemptingSSE(string endpointName);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Using SSE transport")]
private partial void LogUsingSSE(string endpointName);

[LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName}: SSE transport connection failed")]
private partial void LogSSEConnectionFailed(string endpointName, Exception exception);
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
[LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName}: SSE transport connection failed")]
private partial void LogSSEConnectionFailed(string endpointName, Exception exception);
[LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName}: Failed to connect using both Streamable HTTP and SSE transports.")]
private partial void LogAutoDetectingConnectionFailed(string endpointName, Exception exception);

Also, add periods to the end of all the log messages.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the log message and added periods to all log messages in commit 6df9961.


[LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName}: Error disposing transport")]
private partial void LogDisposeFailed(string endpointName, Exception exception);
}
153 changes: 153 additions & 0 deletions src/ModelContextProtocol/Client/DelegatingChannelReader.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading.Channels;

namespace ModelContextProtocol.Client;

/// <summary>
/// A <see cref="ChannelReader{T}"/> implementation that delegates to another reader
/// after a connection has been established.
/// </summary>
/// <typeparam name="T">The type of data in the channel.</typeparam>
internal sealed class DelegatingChannelReader<T> : ChannelReader<T>
{
private readonly TaskCompletionSource<bool> _connectionEstablished;
private readonly AutoDetectingClientSessionTransport _parent;

public DelegatingChannelReader(AutoDetectingClientSessionTransport parent)
{
_parent = parent;
_connectionEstablished = new TaskCompletionSource<bool>();
}

/// <summary>
/// Signals that the transport has been established and operations can proceed.
/// </summary>
public void SetConnected()
{
_connectionEstablished.TrySetResult(true);
}

/// <summary>
/// Sets the error if connection couldn't be established.
/// </summary>
public void SetError(Exception exception)
{
_connectionEstablished.TrySetException(exception);
}

/// <summary>
/// Gets the channel reader to delegate to.
/// </summary>
private ChannelReader<T> GetReader()
{
if (_connectionEstablished.Task.Status != TaskStatus.RanToCompletion)
{
throw new InvalidOperationException("Transport connection not yet established.");
}

return (_parent.ActiveTransport?.MessageReader as ChannelReader<T>)!;
}

#if !NETSTANDARD2_0
/// <inheritdoc/>
public override bool CanCount => GetReader().CanCount;

/// <inheritdoc/>
public override bool CanPeek => GetReader().CanPeek;

/// <inheritdoc/>
public override int Count => GetReader().Count;
#endif

/// <inheritdoc/>
public override bool TryPeek(out T item)
{
try
{
return GetReader().TryPeek(out item!);
}
catch (InvalidOperationException)
{
item = default!;
return false;
}
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
try
{
return GetReader().TryPeek(out item!);
}
catch (InvalidOperationException)
{
item = default!;
return false;
}
if (_connectionEstablished.Task.Status != TaskStatus.RanToCompletion)
{
item = default!;
return false;
}
return GetReader().TryPeek(out item!);

We shouldn't catch an InvalidOperationException. Do this instead. Make the same change to TryRead

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the implementation to check task status directly instead of catching exceptions in commit 6df9961.

}

/// <inheritdoc/>
public override bool TryRead(out T item)
{
try
{
return GetReader().TryRead(out item!);
}
catch (InvalidOperationException)
{
item = default!;
return false;
}
}

/// <inheritdoc/>
public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default)
{
// First wait for the connection to be established
if (_connectionEstablished.Task.Status != TaskStatus.RanToCompletion)
{
return new ValueTask<bool>(WaitForConnectionAndThenReadAsync(cancellationToken));
}

// Then delegate to the active reader
return GetReader().WaitToReadAsync(cancellationToken);
}

private async Task<bool> WaitForConnectionAndThenReadAsync(CancellationToken cancellationToken)
{
await _connectionEstablished.Task.ConfigureAwait(false);
return await GetReader().WaitToReadAsync(cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc/>
public override ValueTask<T> ReadAsync(CancellationToken cancellationToken = default)
{
// First wait for the connection to be established
if (_connectionEstablished.Task.Status != TaskStatus.RanToCompletion)
{
return new ValueTask<T>(WaitForConnectionAndThenGetItemAsync(cancellationToken));
}

// Then delegate to the active reader
return GetReader().ReadAsync(cancellationToken);
}

private async Task<T> WaitForConnectionAndThenGetItemAsync(CancellationToken cancellationToken)
{
await _connectionEstablished.Task.ConfigureAwait(false);
return await GetReader().ReadAsync(cancellationToken).ConfigureAwait(false);
}

#if NETSTANDARD2_0
public IAsyncEnumerable<T> ReadAllAsync(CancellationToken cancellationToken = default)
{
// Create a simple async enumerable implementation
async IAsyncEnumerable<T> ReadAllAsyncImplementation()
{
while (await WaitToReadAsync(cancellationToken).ConfigureAwait(false))
{
while (TryRead(out var item))
{
yield return item;
}
}
}

return ReadAllAsyncImplementation();
}
#else
/// <inheritdoc/>
public override IAsyncEnumerable<T> ReadAllAsync(CancellationToken cancellationToken = default)
{
return base.ReadAllAsync(cancellationToken);
}
#endif
}
23 changes: 23 additions & 0 deletions src/ModelContextProtocol/Client/HttpTransportMode.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
namespace ModelContextProtocol.Client;

/// <summary>
/// Specifies the transport mode for HTTP client connections.
/// </summary>
public enum HttpTransportMode
{
/// <summary>
/// Automatically detect the appropriate transport by trying Streamable HTTP first, then falling back to SSE if that fails.
/// This is the recommended mode for maximum compatibility.
/// </summary>
AutoDetect,

/// <summary>
/// Use only the Streamable HTTP transport.
/// </summary>
StreamableHttp,

/// <summary>
/// Use only the HTTP with SSE transport.
/// </summary>
Sse
}
Loading
Loading