diff --git a/global.json b/global.json index 7edec3849..fdaac1bc6 100644 --- a/global.json +++ b/global.json @@ -1,6 +1,6 @@ { "sdk": { - "version": "9.0.100", + "version": "8.0.115", "rollForward": "minor" } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Client/AutoDetectingClientSessionTransport.cs b/src/ModelContextProtocol/Client/AutoDetectingClientSessionTransport.cs new file mode 100644 index 000000000..2d29c61b6 --- /dev/null +++ b/src/ModelContextProtocol/Client/AutoDetectingClientSessionTransport.cs @@ -0,0 +1,163 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Protocol; +using System.Diagnostics; +using System.Threading.Channels; + +namespace ModelContextProtocol.Client; + +/// +/// A transport that automatically detects whether to use Streamable HTTP or SSE transport +/// by trying Streamable HTTP first and falling back to SSE if that fails. +/// +internal sealed partial class AutoDetectingClientSessionTransport : ITransport +{ + private readonly SseClientTransportOptions _options; + private readonly HttpClient _httpClient; + private readonly ILoggerFactory? _loggerFactory; + private readonly ILogger _logger; + private readonly string _name; + private readonly DelegatingChannelReader _delegatingChannelReader; + + private StreamableHttpClientSessionTransport? _streamableHttpTransport; + private SseClientSessionTransport? _sseTransport; + + public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName) + { + Throw.IfNull(transportOptions); + Throw.IfNull(httpClient); + + _options = transportOptions; + _httpClient = httpClient; + _loggerFactory = loggerFactory; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _name = endpointName; + _delegatingChannelReader = new DelegatingChannelReader(this); + } + + /// + /// Returns the active transport (either StreamableHttp or SSE) + /// + internal ITransport? ActiveTransport => _streamableHttpTransport != null ? (ITransport)_streamableHttpTransport : _sseTransport; + + public ChannelReader MessageReader => _delegatingChannelReader; + + /// + public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + { + if (_streamableHttpTransport == null && _sseTransport == null) + { + var rpcRequest = message as JsonRpcRequest; + + // Try StreamableHttp first + _streamableHttpTransport = new StreamableHttpClientSessionTransport(_options, _httpClient, _loggerFactory, _name); + + try + { + LogAttemptingStreamableHttp(_name); + var response = await _streamableHttpTransport.SendInitialRequestAsync(message, cancellationToken).ConfigureAwait(false); + + // If the status code is not success, fall back to SSE + if (!response.IsSuccessStatusCode) + { + LogStreamableHttpFailed(_name, response.StatusCode); + + await _streamableHttpTransport.DisposeAsync().ConfigureAwait(false); + await InitializeSseTransportAsync(message, cancellationToken).ConfigureAwait(false); + return; + } + + // Process the streamable HTTP response using the transport + await _streamableHttpTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + + // Signal that we have established a connection + LogUsingStreamableHttp(_name); + _delegatingChannelReader.SetConnected(); + } + catch (Exception ex) + { + LogStreamableHttpException(_name, ex); + + await _streamableHttpTransport.DisposeAsync().ConfigureAwait(false); + + // Propagate the original exception + throw; + } + } + else if (_streamableHttpTransport != null) + { + await _streamableHttpTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + } + else if (_sseTransport != null) + { + await _sseTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + } + } + + private async Task InitializeSseTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken) + { + _sseTransport = new SseClientSessionTransport(_options, _httpClient, _loggerFactory, _name); + + try + { + LogAttemptingSSE(_name); + await _sseTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); + await _sseTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + + // Signal that we have established a connection + LogUsingSSE(_name); + _delegatingChannelReader.SetConnected(); + } + catch (Exception ex) + { + LogSSEConnectionFailed(_name, ex); + _delegatingChannelReader.SetError(ex); + await _sseTransport.DisposeAsync().ConfigureAwait(false); + throw; + } + } + + public async ValueTask DisposeAsync() + { + try + { + if (_streamableHttpTransport != null) + { + await _streamableHttpTransport.DisposeAsync().ConfigureAwait(false); + } + + if (_sseTransport != null) + { + await _sseTransport.DisposeAsync().ConfigureAwait(false); + } + } + catch (Exception ex) + { + LogDisposeFailed(_name, ex); + } + } + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Attempting to connect using Streamable HTTP transport.")] + private partial void LogAttemptingStreamableHttp(string endpointName); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Streamable HTTP transport failed with status code {StatusCode}, falling back to SSE transport.")] + private partial void LogStreamableHttpFailed(string endpointName, System.Net.HttpStatusCode statusCode); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Streamable HTTP transport failed with exception, falling back to SSE transport.")] + private partial void LogStreamableHttpException(string endpointName, Exception exception); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Using Streamable HTTP transport.")] + private partial void LogUsingStreamableHttp(string endpointName); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Attempting to connect using SSE transport.")] + private partial void LogAttemptingSSE(string endpointName); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName}: Using SSE transport.")] + private partial void LogUsingSSE(string endpointName); + + [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName}: Failed to connect using both Streamable HTTP and SSE transports.")] + private partial void LogSSEConnectionFailed(string endpointName, Exception exception); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName}: Error disposing transport.")] + private partial void LogDisposeFailed(string endpointName, Exception exception); +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Client/DelegatingChannelReader.cs b/src/ModelContextProtocol/Client/DelegatingChannelReader.cs new file mode 100644 index 000000000..1aa8427fa --- /dev/null +++ b/src/ModelContextProtocol/Client/DelegatingChannelReader.cs @@ -0,0 +1,149 @@ +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading.Channels; + +namespace ModelContextProtocol.Client; + +/// +/// A implementation that delegates to another reader +/// after a connection has been established. +/// +/// The type of data in the channel. +internal sealed class DelegatingChannelReader : ChannelReader +{ + private readonly TaskCompletionSource _connectionEstablished; + private readonly AutoDetectingClientSessionTransport _parent; + + public DelegatingChannelReader(AutoDetectingClientSessionTransport parent) + { + _parent = parent; + _connectionEstablished = new TaskCompletionSource(); + } + + /// + /// Signals that the transport has been established and operations can proceed. + /// + public void SetConnected() + { + _connectionEstablished.TrySetResult(true); + } + + /// + /// Sets the error if connection couldn't be established. + /// + public void SetError(Exception exception) + { + _connectionEstablished.TrySetException(exception); + } + + /// + /// Gets the channel reader to delegate to. + /// + private ChannelReader GetReader() + { + if (_connectionEstablished.Task.Status != TaskStatus.RanToCompletion) + { + throw new InvalidOperationException("Transport connection not yet established."); + } + + return (_parent.ActiveTransport?.MessageReader as ChannelReader)!; + } + +#if !NETSTANDARD2_0 + /// + public override bool CanCount => GetReader().CanCount; + + /// + public override bool CanPeek => GetReader().CanPeek; + + /// + public override int Count => GetReader().Count; +#endif + + /// + public override bool TryPeek(out T item) + { + if (_connectionEstablished.Task.Status != TaskStatus.RanToCompletion) + { + item = default!; + return false; + } + + return GetReader().TryPeek(out item!); + } + + /// + public override bool TryRead(out T item) + { + if (_connectionEstablished.Task.Status != TaskStatus.RanToCompletion) + { + item = default!; + return false; + } + + return GetReader().TryRead(out item!); + } + + /// + public override ValueTask WaitToReadAsync(CancellationToken cancellationToken = default) + { + // First wait for the connection to be established + if (_connectionEstablished.Task.Status != TaskStatus.RanToCompletion) + { + return new ValueTask(WaitForConnectionAndThenReadAsync(cancellationToken)); + } + + // Then delegate to the active reader + return GetReader().WaitToReadAsync(cancellationToken); + } + + private async Task WaitForConnectionAndThenReadAsync(CancellationToken cancellationToken) + { + await _connectionEstablished.Task.ConfigureAwait(false); + return await GetReader().WaitToReadAsync(cancellationToken).ConfigureAwait(false); + } + + /// + public override ValueTask ReadAsync(CancellationToken cancellationToken = default) + { + // First wait for the connection to be established + if (_connectionEstablished.Task.Status != TaskStatus.RanToCompletion) + { + return new ValueTask(WaitForConnectionAndThenGetItemAsync(cancellationToken)); + } + + // Then delegate to the active reader + return GetReader().ReadAsync(cancellationToken); + } + + private async Task WaitForConnectionAndThenGetItemAsync(CancellationToken cancellationToken) + { + await _connectionEstablished.Task.ConfigureAwait(false); + return await GetReader().ReadAsync(cancellationToken).ConfigureAwait(false); + } + +#if NETSTANDARD2_0 + public IAsyncEnumerable ReadAllAsync(CancellationToken cancellationToken = default) + { + // Create a simple async enumerable implementation + async IAsyncEnumerable ReadAllAsyncImplementation() + { + while (await WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + { + while (TryRead(out var item)) + { + yield return item; + } + } + } + + return ReadAllAsyncImplementation(); + } +#else + /// + public override IAsyncEnumerable ReadAllAsync(CancellationToken cancellationToken = default) + { + return base.ReadAllAsync(cancellationToken); + } +#endif +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Client/HttpTransportMode.cs b/src/ModelContextProtocol/Client/HttpTransportMode.cs new file mode 100644 index 000000000..f2d46c302 --- /dev/null +++ b/src/ModelContextProtocol/Client/HttpTransportMode.cs @@ -0,0 +1,23 @@ +namespace ModelContextProtocol.Client; + +/// +/// Specifies the transport mode for HTTP client connections. +/// +public enum HttpTransportMode +{ + /// + /// Automatically detect the appropriate transport by trying Streamable HTTP first, then falling back to SSE if that fails. + /// This is the recommended mode for maximum compatibility. + /// + AutoDetect, + + /// + /// Use only the Streamable HTTP transport. + /// + StreamableHttp, + + /// + /// Use only the HTTP with SSE transport. + /// + Sse +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Client/SseClientTransport.cs b/src/ModelContextProtocol/Client/SseClientTransport.cs index df1cdac6c..fbbb42e49 100644 --- a/src/ModelContextProtocol/Client/SseClientTransport.cs +++ b/src/ModelContextProtocol/Client/SseClientTransport.cs @@ -57,11 +57,24 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { - if (_options.UseStreamableHttp) + switch (_options.TransportMode) { - return new StreamableHttpClientSessionTransport(_options, _httpClient, _loggerFactory, Name); + default: + throw new ArgumentException($"Unsupported transport mode: {_options.TransportMode}", nameof(_options.TransportMode)); + + case HttpTransportMode.AutoDetect: + return new AutoDetectingClientSessionTransport(_options, _httpClient, _loggerFactory, Name); + + case HttpTransportMode.StreamableHttp: + return new StreamableHttpClientSessionTransport(_options, _httpClient, _loggerFactory, Name); + + case HttpTransportMode.Sse: + return await ConnectSseTransportAsync(cancellationToken).ConfigureAwait(false); } + } + private async Task ConnectSseTransportAsync(CancellationToken cancellationToken) + { var sessionTransport = new SseClientSessionTransport(_options, _httpClient, _loggerFactory, Name); try diff --git a/src/ModelContextProtocol/Client/SseClientTransportOptions.cs b/src/ModelContextProtocol/Client/SseClientTransportOptions.cs index f67f6f07d..b15797260 100644 --- a/src/ModelContextProtocol/Client/SseClientTransportOptions.cs +++ b/src/ModelContextProtocol/Client/SseClientTransportOptions.cs @@ -10,7 +10,6 @@ public record SseClientTransportOptions /// public required Uri Endpoint { - get; init { if (value is null) @@ -26,16 +25,25 @@ public required Uri Endpoint throw new ArgumentException("Endpoint must use HTTP or HTTPS scheme.", nameof(value)); } - field = value; + _endpoint = value; } + get => _endpoint; } + private Uri _endpoint = null!; + + /// - /// Gets or sets a value indicating whether to use "Streamable HTTP" for the transport rather than "HTTP with SSE". Defaults to false. - /// Streamable HTTP transport specification. - /// HTTP with SSE transport specification. + /// Gets or sets the transport mode to use for the connection. Defaults to . /// - public bool UseStreamableHttp { get; init; } + /// + /// + /// When set to (the default), the client will first attempt to use + /// Streamable HTTP transport and automatically fall back to SSE transport if the server doesn't support it. + /// This provides the best compatibility and matches the behavior of VS Code. + /// + /// + public HttpTransportMode TransportMode { get; init; } = HttpTransportMode.AutoDetect; /// /// Gets a transport identifier used for logging purposes. diff --git a/src/ModelContextProtocol/Client/StdioClientTransportOptions.cs b/src/ModelContextProtocol/Client/StdioClientTransportOptions.cs index afacd3594..0872024bf 100644 --- a/src/ModelContextProtocol/Client/StdioClientTransportOptions.cs +++ b/src/ModelContextProtocol/Client/StdioClientTransportOptions.cs @@ -10,17 +10,18 @@ public record StdioClientTransportOptions /// public required string Command { - get; - set + init { if (string.IsNullOrWhiteSpace(value)) { throw new ArgumentException("Command cannot be null or empty.", nameof(value)); } - field = value; + _command = value; } + get => _command; } + private string _command = null!; /// /// Gets or sets the arguments to pass to the server process when it is started. diff --git a/src/ModelContextProtocol/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol/Client/StreamableHttpClientSessionTransport.cs index 55ecb9630..e711f1be8 100644 --- a/src/ModelContextProtocol/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol/Client/StreamableHttpClientSessionTransport.cs @@ -50,32 +50,7 @@ public override async Task SendMessageAsync( JsonRpcMessage message, CancellationToken cancellationToken = default) { - using var sendCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _connectionCts.Token); - cancellationToken = sendCts.Token; - -#if NET - using var content = JsonContent.Create(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); -#else - using var content = new StringContent( - JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), - Encoding.UTF8, - "application/json; charset=utf-8" - ); -#endif - - using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _options.Endpoint) - { - Content = content, - Headers = - { - Accept = { s_applicationJsonMediaType, s_textEventStreamMediaType }, - }, - }; - - CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, _mcpSessionId); - using var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); - - response.EnsureSuccessStatusCode(); + using var response = await SendHttpRequestInternalAsync(message, cancellationToken); var rpcRequest = message as JsonRpcRequest; JsonRpcMessage? rpcResponseCandidate = null; @@ -197,22 +172,62 @@ private async Task ReceiveUnsolicitedMessagesAsync() } catch (JsonException ex) { - LogJsonException(ex, data); + LogTransportMessageParseFailed(Name, ex); } return null; } - private void LogJsonException(JsonException ex, string data) + /// + /// Sends the initial initialization request and returns the HTTP response so the status code can be checked. + /// + /// The initialize message to send, which must be a JsonRpcRequest with the method "initialize". + /// The cancellation token. + /// The HTTP response message for the initialization request. + internal async Task SendInitialRequestAsync( + JsonRpcMessage message, + CancellationToken cancellationToken = default) { - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogTransportMessageParseFailedSensitive(Name, data, ex); - } - else + return await SendHttpRequestInternalAsync(message, cancellationToken); + } + + /// + /// Internal implementation to send an HTTP request with the specified message. + /// + /// The message to send. + /// The cancellation token. + /// The HTTP response message. + private async Task SendHttpRequestInternalAsync( + JsonRpcMessage message, + CancellationToken cancellationToken) + { + using var sendCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _connectionCts.Token); + cancellationToken = sendCts.Token; + +#if NET + using var content = JsonContent.Create(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); +#else + using var content = new StringContent( + JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), + Encoding.UTF8, + "application/json; charset=utf-8" + ); +#endif + + using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _options.Endpoint) { - LogTransportMessageParseFailed(Name, ex); - } + Content = content, + Headers = + { + Accept = { s_applicationJsonMediaType, s_textEventStreamMediaType }, + }, + }; + + CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, _mcpSessionId); + + var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); + + return response; } internal static void CopyAdditionalHeaders(HttpRequestHeaders headers, Dictionary? additionalHeaders, string? sessionId = null) diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj index cdb48e14d..bd87c9ccf 100644 --- a/src/ModelContextProtocol/ModelContextProtocol.csproj +++ b/src/ModelContextProtocol/ModelContextProtocol.csproj @@ -1,7 +1,7 @@  - net9.0;net8.0;netstandard2.0 + net8.0;netstandard2.0 true true ModelContextProtocol diff --git a/src/ModelContextProtocol/UriTemplate.cs b/src/ModelContextProtocol/UriTemplate.cs index bc6b70c9f..d8ea12ea8 100644 --- a/src/ModelContextProtocol/UriTemplate.cs +++ b/src/ModelContextProtocol/UriTemplate.cs @@ -443,7 +443,7 @@ static void AppendHex(ref DefaultInterpolatedStringHandler builder, char c) Span utf8 = stackalloc byte[Encoding.UTF8.GetMaxByteCount(1)]; foreach (byte b in utf8.Slice(0, new Rune(c).EncodeToUtf8(utf8))) #else - foreach (byte b in Encoding.UTF8.GetBytes([c])) + foreach (byte b in Encoding.UTF8.GetBytes(c.ToString())) #endif { builder.AppendFormatted('%'); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs index c987bca90..35b35cf9c 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs @@ -1,5 +1,6 @@ using Microsoft.AspNetCore.Builder; using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; namespace ModelContextProtocol.AspNetCore.Tests; @@ -34,4 +35,116 @@ public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePat Assert.Equal("TestCustomRouteServer", mcpClient.ServerInfo.Name); } + + [Fact] + public async Task StreamableHttp_Mode_Should_Work_With_Root_Endpoint() + { + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new() + { + Name = "StreamableHttpTestServer", + Version = "1.0.0", + }; + }).WithHttpTransport(ConfigureStateless); + await using var app = Builder.Build(); + + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + var options = new SseClientTransportOptions + { + Endpoint = new Uri("http://localhost/"), + TransportMode = HttpTransportMode.StreamableHttp + }; + + var mcpClient = await ConnectAsync("/", options); + + Assert.Equal("StreamableHttpTestServer", mcpClient.ServerInfo.Name); + } + + [Fact] + public async Task AutoDetect_Mode_Should_Work_With_Root_Endpoint() + { + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new() + { + Name = "AutoDetectTestServer", + Version = "1.0.0", + }; + }).WithHttpTransport(ConfigureStateless); + await using var app = Builder.Build(); + + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + var options = new SseClientTransportOptions + { + Endpoint = new Uri("http://localhost/"), + TransportMode = HttpTransportMode.AutoDetect + }; + + var mcpClient = await ConnectAsync("/", options); + + Assert.Equal("AutoDetectTestServer", mcpClient.ServerInfo.Name); + } + + [Fact] + public async Task AutoDetect_Mode_Should_Work_With_Sse_Endpoint() + { + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new() + { + Name = "AutoDetectSseTestServer", + Version = "1.0.0", + }; + }).WithHttpTransport(ConfigureStateless); + await using var app = Builder.Build(); + + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + var options = new SseClientTransportOptions + { + Endpoint = new Uri("http://localhost/sse"), + TransportMode = HttpTransportMode.AutoDetect + }; + + var mcpClient = await ConnectAsync("/sse", options); + + Assert.Equal("AutoDetectSseTestServer", mcpClient.ServerInfo.Name); + } + + [Fact] + public async Task Sse_Mode_Should_Work_With_Sse_Endpoint() + { + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new() + { + Name = "SseTestServer", + Version = "1.0.0", + }; + }).WithHttpTransport(ConfigureStateless); + await using var app = Builder.Build(); + + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + var options = new SseClientTransportOptions + { + Endpoint = new Uri("http://localhost/sse"), + TransportMode = HttpTransportMode.Sse + }; + + var mcpClient = await ConnectAsync("/sse", options); + + Assert.Equal("SseTestServer", mcpClient.ServerInfo.Name); + } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index cf49fee16..44ae2e75a 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -20,14 +20,22 @@ protected void ConfigureStateless(HttpServerTransportOptions options) options.Stateless = Stateless; } - protected async Task ConnectAsync(string? path = null) + protected async Task ConnectAsync(string? path = null, SseClientTransportOptions? options = null) { + if (options != null) + { + // When options are provided, use them as-is + await using var transport = new SseClientTransport(options, HttpClient, LoggerFactory); + return await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + } + + // Default behavior when no options are provided path ??= UseStreamableHttp ? "/" : "/sse"; var sseClientTransportOptions = new SseClientTransportOptions() { Endpoint = new Uri($"http://localhost{path}"), - UseStreamableHttp = UseStreamableHttp, + TransportMode = UseStreamableHttp ? HttpTransportMode.StreamableHttp : HttpTransportMode.Sse, }; await using var transport = new SseClientTransport(sseClientTransportOptions, HttpClient, LoggerFactory); return await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs index b1b618057..a9e2e5f54 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs @@ -9,6 +9,6 @@ public class StatelessServerIntegrationTests(SseServerIntegrationTestFixture fix { Endpoint = new Uri("http://localhost/stateless"), Name = "In-memory Streamable HTTP Client", - UseStreamableHttp = true, + TransportMode = HttpTransportMode.StreamableHttp, }; } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs index 2f364be01..acfc744b9 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs @@ -18,7 +18,7 @@ public class StatelessServerTests(ITestOutputHelper outputHelper) : KestrelInMem { Endpoint = new Uri("http://localhost/"), Name = "In-memory Streamable HTTP Client", - UseStreamableHttp = true, + TransportMode = HttpTransportMode.StreamableHttp, }; private async Task StartAsync() diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs index 94540f8c2..d7f8433b3 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs @@ -98,7 +98,7 @@ public async Task CanCallToolOnSessionlessStreamableHttpServer() await using var transport = new SseClientTransport(new() { Endpoint = new("http://localhost/mcp"), - UseStreamableHttp = true, + TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); @@ -118,7 +118,7 @@ public async Task CanCallToolConcurrently() await using var transport = new SseClientTransport(new() { Endpoint = new("http://localhost/mcp"), - UseStreamableHttp = true, + TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs index 64505b3d9..7c4366f16 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs @@ -15,7 +15,7 @@ public class StreamableHttpServerIntegrationTests(SseServerIntegrationTestFixtur { Endpoint = new Uri("http://localhost/"), Name = "In-memory Streamable HTTP Client", - UseStreamableHttp = true, + TransportMode = HttpTransportMode.StreamableHttp, }; [Fact] diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportAutoDetectTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportAutoDetectTests.cs new file mode 100644 index 000000000..09346cad2 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportAutoDetectTests.cs @@ -0,0 +1,206 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Tests.Utils; +using System.Net; + +namespace ModelContextProtocol.Tests.Transport; + +public class SseClientTransportAutoDetectTests : LoggedTest +{ + public SseClientTransportAutoDetectTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + } + + [Fact] + public async Task AutoDetect_Should_Use_StreamableHttp_When_Server_Supports_It() + { + var options = new SseClientTransportOptions + { + Endpoint = new Uri("http://localhost:8080"), + TransportMode = HttpTransportMode.AutoDetect, + ConnectionTimeout = TimeSpan.FromSeconds(2), + Name = "Test Server" + }; + + using var mockHttpHandler = new MockHttpHandler(); + using var httpClient = new HttpClient(mockHttpHandler); + await using var transport = new SseClientTransport(options, httpClient, LoggerFactory); + + // Simulate successful Streamable HTTP response for initialize + mockHttpHandler.RequestHandler = (request) => + { + if (request.Method == HttpMethod.Post) + { + return Task.FromResult(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent("{\"jsonrpc\":\"2.0\",\"id\":\"init-id\",\"result\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{\"tools\":{}}}}"), + Headers = + { + { "Content-Type", "application/json" }, + { "mcp-session-id", "test-session" } + } + }); + } + + // Shouldn't reach here for successful Streamable HTTP + throw new InvalidOperationException("Unexpected request"); + }; + + await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); + + // The auto-detecting transport should be returned + Assert.NotNull(session); + Assert.True(session.IsConnected); + Assert.IsType(session); + } + + [Fact] + public async Task AutoDetect_Should_Fallback_To_Sse_When_StreamableHttp_Fails() + { + var options = new SseClientTransportOptions + { + Endpoint = new Uri("http://localhost:8080"), + TransportMode = HttpTransportMode.AutoDetect, + ConnectionTimeout = TimeSpan.FromSeconds(2), + Name = "Test Server" + }; + + using var mockHttpHandler = new MockHttpHandler(); + using var httpClient = new HttpClient(mockHttpHandler); + await using var transport = new SseClientTransport(options, httpClient, LoggerFactory); + + var requestCount = 0; + + mockHttpHandler.RequestHandler = (request) => + { + requestCount++; + + if (request.Method == HttpMethod.Post && requestCount == 1) + { + // First POST (Streamable HTTP) fails + return Task.FromResult(new HttpResponseMessage + { + StatusCode = HttpStatusCode.NotFound, + Content = new StringContent("Streamable HTTP not supported") + }); + } + + if (request.Method == HttpMethod.Get) + { + // SSE connection request + return Task.FromResult(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent("event: endpoint\r\ndata: /sse-endpoint\r\n\r\n"), + Headers = { { "Content-Type", "text/event-stream" } } + }); + } + + if (request.Method == HttpMethod.Post && requestCount > 1) + { + // Subsequent POST to SSE endpoint succeeds + return Task.FromResult(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent("accepted") + }); + } + + throw new InvalidOperationException($"Unexpected request: {request.Method}, count: {requestCount}"); + }; + + await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); + + // The auto-detecting transport should be returned + Assert.NotNull(session); + Assert.True(session.IsConnected); + Assert.IsType(session); + } + + [Fact] + public async Task TransportMode_AutoDetect_Should_Use_AutoDetectingTransport() + { + var options = new SseClientTransportOptions + { + Endpoint = new Uri("http://localhost:8080"), + TransportMode = HttpTransportMode.AutoDetect, + ConnectionTimeout = TimeSpan.FromSeconds(2), + Name = "Test Server" + }; + + using var mockHttpHandler = new MockHttpHandler(); + using var httpClient = new HttpClient(mockHttpHandler); + await using var transport = new SseClientTransport(options, httpClient, LoggerFactory); + + // Configure for successful Streamable HTTP response + mockHttpHandler.RequestHandler = (request) => + { + return Task.FromResult(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent("{\"jsonrpc\":\"2.0\",\"id\":\"test-id\",\"result\":{}}"), + Headers = { { "Content-Type", "application/json" } } + }); + }; + + await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); + + // Should return AutoDetectingClientSessionTransport when using AutoDetect mode + Assert.IsType(session); + } + + [Fact] + public async Task TransportMode_StreamableHttp_Should_Return_StreamableHttp_Transport() + { + var options = new SseClientTransportOptions + { + Endpoint = new Uri("http://localhost:8080"), + TransportMode = HttpTransportMode.StreamableHttp, + ConnectionTimeout = TimeSpan.FromSeconds(2), + Name = "Test Server" + }; + + using var mockHttpHandler = new MockHttpHandler(); + using var httpClient = new HttpClient(mockHttpHandler); + await using var transport = new SseClientTransport(options, httpClient, LoggerFactory); + + await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); + + // Should return StreamableHttpClientSessionTransport directly + Assert.IsType(session); + } + + [Fact] + public async Task Sse_Mode_Should_Return_Sse_Transport() + { + var options = new SseClientTransportOptions + { + Endpoint = new Uri("http://localhost:8080"), + TransportMode = HttpTransportMode.Sse, + ConnectionTimeout = TimeSpan.FromSeconds(2), + Name = "Test Server" + }; + + using var mockHttpHandler = new MockHttpHandler(); + using var httpClient = new HttpClient(mockHttpHandler); + await using var transport = new SseClientTransport(options, httpClient, LoggerFactory); + + mockHttpHandler.RequestHandler = (request) => + { + // Simulate SSE endpoint response + return Task.FromResult(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent("event: endpoint\r\ndata: /sse-endpoint\r\n\r\n"), + Headers = { { "Content-Type", "text/event-stream" } } + }); + }; + + await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); + + // Should return SseClientSessionTransport directly + Assert.IsType(session); + } +} \ No newline at end of file