Skip to content

Commit c5e37fb

Browse files
added tests
1 parent 0722799 commit c5e37fb

7 files changed

Lines changed: 302 additions & 8 deletions

File tree

src/ModelContextProtocol/McpEndpointExtensions.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ public static class McpEndpointExtensions
99
/// <summary>
1010
/// Notifies the connected endpoint of an event.
1111
/// </summary>
12-
/// <param name="endpoint">The endpoint issueing the notification.</param>
12+
/// <param name="endpoint">The endpoint issuing the notification.</param>
1313
/// <param name="notification">The notification to send.</param>
1414
/// <param name="cancellationToken">A token to cancel the operation.</param>
1515
/// <exception cref="ArgumentNullException"><paramref name="endpoint"/> is <see langword="null"/>.</exception>
@@ -27,7 +27,7 @@ public static Task NotifyAsync(
2727
/// <summary>
2828
/// Notifies the connected endpoint of an event.
2929
/// </summary>
30-
/// <param name="endpoint">The endpoint issueing the notification.</param>
30+
/// <param name="endpoint">The endpoint issuing the notification.</param>
3131
/// <param name="method">The method to call.</param>
3232
/// <param name="parameters">The parameters to send.</param>
3333
/// <param name="cancellationToken">A token to cancel the operation.</param>
@@ -49,7 +49,7 @@ public static Task NotifyAsync(
4949
}
5050

5151
/// <summary>Notifies the connected endpoint of progress.</summary>
52-
/// <param name="endpoint">The endpoint issueing the notification.</param>
52+
/// <param name="endpoint">The endpoint issuing the notification.</param>
5353
/// <param name="progressToken">The <see cref="ProgressToken"/> identifying the operation.</param>
5454
/// <param name="progress">The progress update to send.</param>
5555
/// <param name="cancellationToken">A token to cancel the operation.</param>
@@ -75,7 +75,7 @@ public static Task NotifyProgressAsync(
7575
/// <summary>
7676
/// Notifies the connected endpoint that a request has been cancelled.
7777
/// </summary>
78-
/// <param name="endpoint">The endpoint issueing the notification.</param>
78+
/// <param name="endpoint">The endpoint issuing the notification.</param>
7979
/// <param name="requestId">The ID of the request to cancel.</param>
8080
/// <param name="reason">An optional reason for the cancellation.</param>
8181
/// <param name="cancellationToken">A token to cancel the operation.</param>
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
using ModelContextProtocol.Protocol.Messages;
2+
3+
namespace ModelContextProtocol.Shared;
4+
5+
/// <summary>
6+
/// Class for managing an MCP JSON-RPC session. This covers both MCP clients and servers.
7+
/// </summary>
8+
public interface IMcpSession : IDisposable
9+
{
10+
/// <summary>
11+
/// The name of the endpoint for logging and debug purposes.
12+
/// </summary>
13+
string EndpointName { get; set; }
14+
15+
/// <summary>
16+
/// Starts processing messages from the transport. This method will block until the transport is disconnected.
17+
/// This is generally started in a background task or thread from the initialization logic of the derived class.
18+
/// </summary>
19+
Task ProcessMessagesAsync(CancellationToken cancellationToken);
20+
21+
/// <summary>
22+
/// Sends a generic JSON-RPC request to the server.
23+
/// </summary>
24+
/// <param name="message">The JSON-RPC request to send.</param>
25+
/// <param name="cancellationToken">A token to cancel the operation.</param>
26+
/// <returns>A task containing the server's response.</returns>
27+
Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default);
28+
29+
/// <summary>
30+
/// Sends a request over the protocol
31+
/// </summary>
32+
/// <typeparam name="TResult">The MCP Response type.</typeparam>
33+
/// <param name="request">The request instance</param>
34+
/// <param name="cancellationToken">The token for cancellation.</param>
35+
/// <returns>The MCP response.</returns>
36+
Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class;
37+
}

src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace ModelContextProtocol.Shared;
1515
/// This is especially true as a client represents a connection to one and only one server, and vice versa.
1616
/// Any multi-client or multi-server functionality should be implemented at a higher level of abstraction.
1717
/// </summary>
18-
internal abstract class McpJsonRpcEndpoint : IMcpEndpoint, IAsyncDisposable
18+
public abstract class McpJsonRpcEndpoint : IMcpEndpoint, IAsyncDisposable
1919
{
2020
private readonly RequestHandlers _requestHandlers = [];
2121
private readonly NotificationHandlers _notificationHandlers = [];
@@ -27,6 +27,9 @@ internal abstract class McpJsonRpcEndpoint : IMcpEndpoint, IAsyncDisposable
2727
private readonly SemaphoreSlim _disposeLock = new(1, 1);
2828
private bool _disposed;
2929

30+
/// <summary>
31+
/// The logger for this endpoint.
32+
/// </summary>
3033
protected readonly ILogger _logger;
3134

3235
/// <summary>
@@ -38,18 +41,53 @@ protected McpJsonRpcEndpoint(ILoggerFactory? loggerFactory = null)
3841
_logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance;
3942
}
4043

44+
/// <summary>
45+
/// Sets the request handler for a specific method.
46+
/// </summary>
47+
/// <typeparam name="TRequest">The MCP Request type</typeparam>
48+
/// <typeparam name="TResponse">The MCP Response type</typeparam>
49+
/// <param name="method">The method name.</param>
50+
/// <param name="handler">The handler function.</param>
4151
protected void SetRequestHandler<TRequest, TResponse>(string method, Func<TRequest?, CancellationToken, Task<TResponse>> handler)
4252
=> _requestHandlers.Set(method, handler);
4353

54+
/// <summary>
55+
/// Sets the request handler for a specific method.
56+
/// </summary>
57+
/// <param name="method">The method name.</param>
58+
/// <param name="handler">The handler function.</param>
4459
public void AddNotificationHandler(string method, Func<JsonRpcNotification, Task> handler)
4560
=> _notificationHandlers.Add(method, handler);
4661

62+
/// <summary>
63+
/// Sends a request over the protocol
64+
/// </summary>
65+
/// <typeparam name="TResult">The MCP Response type.</typeparam>
66+
/// <param name="request">The request instance</param>
67+
/// <param name="cancellationToken">The token for cancellation.</param>
68+
/// <returns>The MCP response.</returns>
4769
public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class
4870
{
49-
using var registration = cancellationToken.Register(() => _ = this.NotifyCancelAsync(request.Id));
71+
using var registration = cancellationToken.Register(async () =>
72+
{
73+
try
74+
{
75+
await this.NotifyCancelAsync(request.Id).ConfigureAwait(false);
76+
}
77+
catch (Exception ex)
78+
{
79+
_logger.LogError(ex, "An error occurred while notifying cancellation for request {RequestId}.", request.Id);
80+
}
81+
});
5082
return await GetSessionOrThrow().SendRequestAsync<TResult>(request, cancellationToken);
5183
}
5284

85+
/// <summary>
86+
/// Sends a notification over the protocol.
87+
/// </summary>
88+
/// <param name="message">The message to send.</param>
89+
/// <param name="cancellationToken">The token for cancellation.</param>
90+
/// <returns>A task representing the completion of the operation.</returns>
5391
public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
5492
=> GetSessionOrThrow().SendMessageAsync(message, cancellationToken);
5593

@@ -63,6 +101,12 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella
63101
/// </summary>
64102
protected Task? MessageProcessingTask { get; set; }
65103

104+
/// <summary>
105+
/// Starts the session with the given transport.
106+
/// </summary>
107+
/// <param name="sessionTransport">The transport to use for the session.</param>
108+
/// <param name="fullSessionCancellationToken">A cancellation token for the full session.</param>
109+
/// <exception cref="InvalidOperationException">Thrown if the session has already started.</exception>
66110
[MemberNotNull(nameof(MessageProcessingTask))]
67111
protected void StartSession(ITransport sessionTransport, CancellationToken fullSessionCancellationToken = default)
68112
{
@@ -76,6 +120,10 @@ protected void StartSession(ITransport sessionTransport, CancellationToken fullS
76120
MessageProcessingTask = _session.ProcessMessagesAsync(_sessionCts.Token);
77121
}
78122

123+
/// <summary>
124+
/// Disposes the endpoint and releases resources.
125+
/// </summary>
126+
/// <returns>A task representing the completion of the operation.</returns>
79127
public async ValueTask DisposeAsync()
80128
{
81129
using var _ = await _disposeLock.LockAsync().ConfigureAwait(false);
@@ -125,6 +173,11 @@ public virtual async ValueTask DisposeUnsynchronizedAsync()
125173
_logger.EndpointCleanedUp(EndpointName);
126174
}
127175

128-
protected McpSession GetSessionOrThrow()
176+
/// <summary>
177+
/// Gets the current session.
178+
/// </summary>
179+
/// <returns>The current session.</returns>
180+
/// <exception cref="InvalidOperationException">Thrown if the session is not started.</exception>
181+
protected IMcpSession GetSessionOrThrow()
129182
=> _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(StartSession)} before sending messages.");
130183
}

src/ModelContextProtocol/Shared/McpSession.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace ModelContextProtocol.Shared;
1414
/// <summary>
1515
/// Class for managing an MCP JSON-RPC session. This covers both MCP clients and servers.
1616
/// </summary>
17-
internal sealed class McpSession : IDisposable
17+
internal sealed class McpSession : IMcpSession
1818
{
1919
private readonly ITransport _transport;
2020
private readonly RequestHandlers _requestHandlers;
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
using ModelContextProtocol.Protocol.Messages;
2+
using ModelContextProtocol.Protocol.Types;
3+
using ModelContextProtocol.Tests.Utils;
4+
5+
namespace ModelContextProtocol.Tests;
6+
7+
/// <summary>
8+
/// Tests for the cancelled notifications against an IMcpEndpoint.
9+
/// </summary>
10+
public class CancelledNotificationTests(
11+
McpEndpointTestFixture fixture, ITestOutputHelper testOutputHelper)
12+
: LoggedTest(testOutputHelper), IClassFixture<McpEndpointTestFixture>
13+
{
14+
[Fact]
15+
public async Task NotifyCancelAsync_SendsCorrectNotification()
16+
{
17+
// Arrange
18+
await using var endpoint = fixture.CreateEndpoint();
19+
await using var transport = fixture.CreateTransport();
20+
var cancellationToken = TestContext.Current.CancellationToken;
21+
endpoint.Start(transport, cancellationToken);
22+
23+
var requestId = new RequestId("test-request-id-123");
24+
const string reason = "Operation was cancelled by the user";
25+
26+
// Act
27+
await endpoint.NotifyCancelAsync(requestId, reason, cancellationToken);
28+
29+
// Assert
30+
Assert.Single(transport.SentMessages);
31+
var notification = Assert.IsType<JsonRpcNotification>(transport.SentMessages[0]);
32+
Assert.Equal(NotificationMethods.CancelledNotification, notification.Method);
33+
34+
var cancelParams = Assert.IsType<CancelledNotification>(notification.Params);
35+
Assert.Equal(requestId, cancelParams.RequestId);
36+
Assert.Equal(reason, cancelParams.Reason);
37+
}
38+
39+
[Fact]
40+
public async Task SendRequestAsync_Cancellation_SendsNotification()
41+
{
42+
// Arrange
43+
await using var endpoint = fixture.CreateEndpoint();
44+
await using var transport = fixture.CreateTransport();
45+
endpoint.Start(transport, CancellationToken.None);
46+
47+
var requestId = new RequestId("test-request-id-123");
48+
JsonRpcRequest request = new()
49+
{
50+
Id = requestId,
51+
Method = "test.method",
52+
Params = new { },
53+
};
54+
using CancellationTokenSource cancellationSource = new();
55+
await cancellationSource.CancelAsync();
56+
// Act
57+
try
58+
{
59+
await endpoint.SendRequestAsync<EmptyResult>(request, cancellationSource.Token);
60+
}
61+
catch (OperationCanceledException)
62+
{
63+
// Expected exception
64+
}
65+
catch (Exception ex)
66+
{
67+
Assert.Fail($"Unexpected exception: {ex.Message}");
68+
}
69+
70+
// Assert
71+
Assert.NotEmpty(transport.SentMessages);
72+
Assert.Equal(2, transport.SentMessages.Count);
73+
var notification = Assert.IsType<JsonRpcNotification>(transport.SentMessages[0]);
74+
Assert.Equal(NotificationMethods.CancelledNotification, notification.Method);
75+
76+
var cancelParams = Assert.IsType<CancelledNotification>(notification.Params);
77+
Assert.Equal(requestId, cancelParams.RequestId);
78+
79+
var requestMessage = Assert.IsType<JsonRpcRequest>(transport.SentMessages[1]);
80+
Assert.Equal(request.Id, requestMessage.Id);
81+
Assert.Equal(request.Method, requestMessage.Method);
82+
Assert.Equal(request.Params, requestMessage.Params);
83+
}
84+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
using Microsoft.Extensions.Logging;
2+
using ModelContextProtocol.Protocol.Messages;
3+
using ModelContextProtocol.Shared;
4+
using ModelContextProtocol.Protocol.Transport;
5+
using System.Threading.Channels;
6+
7+
namespace ModelContextProtocol.Tests;
8+
9+
/// <summary>
10+
/// Test fixture for McpEndpoint tests that provides shared transport implementations.
11+
/// </summary>
12+
public class McpEndpointTestFixture() : IAsyncDisposable
13+
{
14+
private readonly ILoggerFactory _loggerFactory = LoggerFactory.Create(builder => builder.AddConsole());
15+
16+
/// <summary>
17+
/// Creates a test transport.
18+
/// </summary>
19+
internal TestCancellationTransport CreateTransport() => new();
20+
21+
/// <summary>
22+
/// Creates a test client endpoint.
23+
/// </summary>
24+
internal TestMcpJsonRpcEndpoint CreateEndpoint() => new();
25+
26+
27+
public ValueTask DisposeAsync() => ValueTask.CompletedTask;
28+
29+
internal class TestCancellationTransport : ITransport
30+
{
31+
public bool IsConnected => true;
32+
public List<IJsonRpcMessage> SentMessages { get; } = [];
33+
public ChannelReader<IJsonRpcMessage> MessageReader { get; init; }
34+
= Channel.CreateUnbounded<IJsonRpcMessage>().Reader;
35+
public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken)
36+
{
37+
SentMessages.Add(message);
38+
return Task.CompletedTask;
39+
}
40+
41+
public ValueTask DisposeAsync() => ValueTask.CompletedTask;
42+
}
43+
44+
internal class TestMcpJsonRpcEndpoint(LoggerFactory? loggerFactory = null)
45+
: McpJsonRpcEndpoint(loggerFactory ?? new())
46+
{
47+
public override string EndpointName => "TestEndpoint";
48+
49+
public void Start(
50+
ITransport transport,
51+
CancellationToken token)
52+
=> StartSession(transport, token);
53+
}
54+
}

0 commit comments

Comments
 (0)