Skip to content

Commit 6aa993d

Browse files
Tyler-R-Kendrickstephentoub
authored andcommitted
Extend progress notification support
1 parent b12d728 commit 6aa993d

18 files changed

+207
-56
lines changed

src/ModelContextProtocol/Client/McpClient.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public McpClient(IClientTransport clientTransport, McpClientOptions options, Mcp
4242

4343
SetRequestHandler<CreateMessageRequestParams, CreateMessageResult>(
4444
RequestMethods.SamplingCreateMessage,
45-
(request, ct) => samplingHandler(request, ct));
45+
(request, ct) => samplingHandler(request, new ClientTokenProgress(this, request?.Meta?.ProgressToken), ct));
4646
}
4747

4848
if (options.Capabilities?.Roots is { } rootsCapability)

src/ModelContextProtocol/Client/McpClientExtensions.cs

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -531,17 +531,34 @@ internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chat
531531
/// </summary>
532532
/// <param name="chatClient">The <see cref="IChatClient"/> with which to satisfy sampling requests.</param>
533533
/// <returns>The created handler delegate.</returns>
534-
public static Func<CreateMessageRequestParams?, CancellationToken, Task<CreateMessageResult>> CreateSamplingHandler(this IChatClient chatClient)
534+
public static Func<CreateMessageRequestParams?, IProgress<ProgressNotificationValue>, CancellationToken, Task<CreateMessageResult>> CreateSamplingHandler(
535+
this IChatClient chatClient)
535536
{
536537
Throw.IfNull(chatClient);
537538

538-
return async (requestParams, cancellationToken) =>
539+
return async (requestParams, progress, cancellationToken) =>
539540
{
540541
Throw.IfNull(requestParams);
541542

542543
var (messages, options) = requestParams.ToChatClientArguments();
543-
var response = await chatClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
544-
return response.ToCreateMessageResult();
544+
var progressToken = requestParams.Meta?.ProgressToken;
545+
int progressValue = 0;
546+
var streamingResponses = chatClient.GetStreamingResponseAsync(
547+
messages, options, cancellationToken);
548+
List<ChatResponseUpdate> updates = [];
549+
await foreach (var streamingResponse in streamingResponses)
550+
{
551+
updates.Add(streamingResponse);
552+
if (progressToken is not null)
553+
{
554+
progress.Report(new()
555+
{
556+
Progress = ++progressValue,
557+
});
558+
}
559+
}
560+
561+
return updates.ToChatResponse().ToCreateMessageResult();
545562
};
546563
}
547564

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using ModelContextProtocol.Client;
2+
using ModelContextProtocol.Protocol.Messages;
3+
4+
namespace ModelContextProtocol;
5+
6+
internal sealed class ClientTokenProgress(IMcpClient client, ProgressToken? progressToken)
7+
: IProgress<ProgressNotificationValue>
8+
{
9+
/// <inheritdoc />
10+
public void Report(ProgressNotificationValue value)
11+
{
12+
if (progressToken is null) return;
13+
_ = client.SendMessageAsync(new JsonRpcNotification()
14+
{
15+
Method = NotificationMethods.ProgressNotification,
16+
Params = new ProgressNotification()
17+
{
18+
ProgressToken = progressToken.Value,
19+
Progress = new()
20+
{
21+
Progress = value.Progress,
22+
Total = value.Total,
23+
Message = value.Message,
24+
},
25+
},
26+
}, CancellationToken.None);
27+
}
28+
}

src/ModelContextProtocol/Protocol/Types/Capabilities.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public class SamplingCapability
5555

5656
/// <summary>Gets or sets the handler for sampling requests.</summary>
5757
[JsonIgnore]
58-
public Func<CreateMessageRequestParams?, CancellationToken, Task<CreateMessageResult>>? SamplingHandler { get; set; }
58+
public Func<CreateMessageRequestParams?, IProgress<ProgressNotificationValue>, CancellationToken, Task<CreateMessageResult>>? SamplingHandler { get; set; }
5959
}
6060

6161
/// <summary>

src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,4 @@
44
/// Sent from the client to request a list of prompts and prompt templates the server has.
55
/// <see href="https://github.com/modelcontextprotocol/specification/blob/main/schema/">See the schema for details</see>
66
/// </summary>
7-
public class ListPromptsRequestParams
8-
{
9-
/// <summary>
10-
/// An opaque token representing the current pagination position.
11-
/// If provided, the server should return results starting after this cursor.
12-
/// </summary>
13-
[System.Text.Json.Serialization.JsonPropertyName("cursor")]
14-
public string? Cursor { get; init; }
15-
}
7+
public class ListPromptsRequestParams : PaginatedRequestParams;

src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,4 @@
44
/// Sent from the client to request a list of resource templates the server has.
55
/// <see href="https://github.com/modelcontextprotocol/specification/blob/main/schema/">See the schema for details</see>
66
/// </summary>
7-
public class ListResourceTemplatesRequestParams
8-
{
9-
/// <summary>
10-
/// An opaque token representing the current pagination position.
11-
/// If provided, the server should return results starting after this cursor.
12-
/// </summary>
13-
[System.Text.Json.Serialization.JsonPropertyName("cursor")]
14-
public string? Cursor { get; init; }
15-
}
7+
public class ListResourceTemplatesRequestParams : PaginatedRequestParams;

src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,4 @@
44
/// Sent from the client to request a list of resources the server has.
55
/// <see href="https://github.com/modelcontextprotocol/specification/blob/main/schema/">See the schema for details</see>
66
/// </summary>
7-
public class ListResourcesRequestParams
8-
{
9-
/// <summary>
10-
/// An opaque token representing the current pagination position.
11-
/// If provided, the server should return results starting after this cursor.
12-
/// </summary>
13-
[System.Text.Json.Serialization.JsonPropertyName("cursor")]
14-
public string? Cursor { get; init; }
15-
}
7+
public class ListResourcesRequestParams : PaginatedRequestParams;

src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,4 @@ namespace ModelContextProtocol.Protocol.Types;
66
/// A request from the server to get a list of root URIs from the client.
77
/// <see href="https://github.com/modelcontextprotocol/specification/blob/main/schema/">See the schema for details</see>
88
/// </summary>
9-
public class ListRootsRequestParams
10-
{
11-
/// <summary>
12-
/// Optional progress token for out-of-band progress notifications.
13-
/// </summary>
14-
[System.Text.Json.Serialization.JsonPropertyName("progressToken")]
15-
public ProgressToken? ProgressToken { get; init; }
16-
}
9+
public class ListRootsRequestParams : RequestParams;

src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,4 @@
44
/// Sent from the client to request a list of tools the server has.
55
/// <see href="https://github.com/modelcontextprotocol/specification/blob/main/schema/">See the schema for details</see>
66
/// </summary>
7-
public class ListToolsRequestParams
8-
{
9-
/// <summary>
10-
/// An opaque token representing the current pagination position.
11-
/// If provided, the server should return results starting after this cursor.
12-
/// </summary>
13-
[System.Text.Json.Serialization.JsonPropertyName("cursor")]
14-
public string? Cursor { get; init; }
15-
}
7+
public class ListToolsRequestParams : PaginatedRequestParams;
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
namespace ModelContextProtocol.Protocol.Types;
2+
3+
/// <summary>
4+
/// Used as a base class for paginated requests.
5+
/// <see href="https://github.com/modelcontextprotocol/specification/blob/main/schema/2024-11-05/schema.json">See the schema for details</see>
6+
/// </summary>
7+
public class PaginatedRequestParams : RequestParams
8+
{
9+
/// <summary>
10+
/// An opaque token representing the current pagination position.
11+
/// If provided, the server should return results starting after this cursor.
12+
/// </summary>
13+
[System.Text.Json.Serialization.JsonPropertyName("cursor")]
14+
public string? Cursor { get; init; }
15+
}

0 commit comments

Comments
 (0)