Skip to content

Commit 0722799

Browse files
Make all requests invoke MCP cancellable when cancellation requested.
1 parent 25bcb44 commit 0722799

File tree

3 files changed

+79
-10
lines changed

3 files changed

+79
-10
lines changed

src/ModelContextProtocol/McpEndpointExtensions.cs

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,48 @@ namespace ModelContextProtocol;
66
/// <summary>Provides extension methods for interacting with an <see cref="IMcpEndpoint"/>.</summary>
77
public static class McpEndpointExtensions
88
{
9+
/// <summary>
10+
/// Notifies the connected endpoint of an event.
11+
/// </summary>
12+
/// <param name="endpoint">The endpoint issueing the notification.</param>
13+
/// <param name="notification">The notification to send.</param>
14+
/// <param name="cancellationToken">A token to cancel the operation.</param>
15+
/// <exception cref="ArgumentNullException"><paramref name="endpoint"/> is <see langword="null"/>.</exception>
16+
/// <returns>A task representing the completion of the operation.</returns>
17+
public static Task NotifyAsync(
18+
this IMcpEndpoint endpoint,
19+
JsonRpcNotification notification,
20+
CancellationToken cancellationToken = default)
21+
{
22+
Throw.IfNull(endpoint);
23+
24+
return endpoint.SendMessageAsync(notification, cancellationToken);
25+
}
26+
27+
/// <summary>
28+
/// Notifies the connected endpoint of an event.
29+
/// </summary>
30+
/// <param name="endpoint">The endpoint issueing the notification.</param>
31+
/// <param name="method">The method to call.</param>
32+
/// <param name="parameters">The parameters to send.</param>
33+
/// <param name="cancellationToken">A token to cancel the operation.</param>
34+
/// <exception cref="ArgumentNullException"><paramref name="endpoint"/> is <see langword="null"/>.</exception>
35+
/// <returns>A task representing the completion of the operation.</returns>
36+
public static Task NotifyAsync(
37+
this IMcpEndpoint endpoint,
38+
string method,
39+
object? parameters = null,
40+
CancellationToken cancellationToken = default)
41+
{
42+
Throw.IfNull(endpoint);
43+
44+
return endpoint.NotifyAsync(new()
45+
{
46+
Method = method,
47+
Params = parameters,
48+
}, cancellationToken);
49+
}
50+
951
/// <summary>Notifies the connected endpoint of progress.</summary>
1052
/// <param name="endpoint">The endpoint issueing the notification.</param>
1153
/// <param name="progressToken">The <see cref="ProgressToken"/> identifying the operation.</param>
@@ -21,14 +63,38 @@ public static Task NotifyProgressAsync(
2163
{
2264
Throw.IfNull(endpoint);
2365

24-
return endpoint.SendMessageAsync(new JsonRpcNotification()
25-
{
26-
Method = NotificationMethods.ProgressNotification,
27-
Params = new ProgressNotification()
66+
return endpoint.NotifyAsync(
67+
NotificationMethods.ProgressNotification,
68+
new ProgressNotification()
2869
{
2970
ProgressToken = progressToken,
3071
Progress = progress,
31-
},
32-
}, cancellationToken);
72+
}, cancellationToken);
73+
}
74+
75+
/// <summary>
76+
/// Notifies the connected endpoint that a request has been cancelled.
77+
/// </summary>
78+
/// <param name="endpoint">The endpoint issueing the notification.</param>
79+
/// <param name="requestId">The ID of the request to cancel.</param>
80+
/// <param name="reason">An optional reason for the cancellation.</param>
81+
/// <param name="cancellationToken">A token to cancel the operation.</param>
82+
/// <returns>A task representing the completion of the operation.</returns>
83+
/// <exception cref="ArgumentNullException"><paramref name="endpoint"/> is <see langword="null"/>.</exception>
84+
public static Task NotifyCancelAsync(
85+
this IMcpEndpoint endpoint,
86+
RequestId requestId,
87+
string? reason = null,
88+
CancellationToken cancellationToken = default)
89+
{
90+
Throw.IfNull(endpoint);
91+
92+
return endpoint.NotifyAsync(
93+
NotificationMethods.CancelledNotification,
94+
new CancelledNotification()
95+
{
96+
RequestId = requestId,
97+
Reason = reason,
98+
}, cancellationToken);
3399
}
34100
}

src/ModelContextProtocol/Protocol/Messages/CancelledNotification.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public sealed class CancelledNotification
1111
/// The ID of the request to cancel.
1212
/// </summary>
1313
[JsonPropertyName("requestId")]
14-
public RequestId RequestId { get; set; }
14+
public required RequestId RequestId { get; set; }
1515

1616
/// <summary>
1717
/// An optional string describing the reason for the cancellation.

src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace ModelContextProtocol.Shared;
1515
/// This is especially true as a client represents a connection to one and only one server, and vice versa.
1616
/// Any multi-client or multi-server functionality should be implemented at a higher level of abstraction.
1717
/// </summary>
18-
internal abstract class McpJsonRpcEndpoint : IAsyncDisposable
18+
internal abstract class McpJsonRpcEndpoint : IMcpEndpoint, IAsyncDisposable
1919
{
2020
private readonly RequestHandlers _requestHandlers = [];
2121
private readonly NotificationHandlers _notificationHandlers = [];
@@ -44,8 +44,11 @@ protected void SetRequestHandler<TRequest, TResponse>(string method, Func<TReque
4444
public void AddNotificationHandler(string method, Func<JsonRpcNotification, Task> handler)
4545
=> _notificationHandlers.Add(method, handler);
4646

47-
public Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class
48-
=> GetSessionOrThrow().SendRequestAsync<TResult>(request, cancellationToken);
47+
public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class
48+
{
49+
using var registration = cancellationToken.Register(() => _ = this.NotifyCancelAsync(request.Id));
50+
return await GetSessionOrThrow().SendRequestAsync<TResult>(request, cancellationToken);
51+
}
4952

5053
public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
5154
=> GetSessionOrThrow().SendMessageAsync(message, cancellationToken);

0 commit comments

Comments
 (0)