Skip to content

Commit 75fe8ee

Browse files
halter73Copilot
andcommitted
Address review feedback: drop typed InputResponse accessors and resolve input requests with WhenAll+CTS
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 18c0df7 commit 75fe8ee

11 files changed

Lines changed: 131 additions & 64 deletions

File tree

docs/concepts/elicitation/elicitation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ public static string ElicitWithMrtr(
188188
// On retry, process the client's elicitation response
189189
if (context.Params!.InputResponses?.TryGetValue("user_input", out var response) is true)
190190
{
191-
var elicitResult = response.ElicitationResult;
191+
var elicitResult = response.Deserialize(InputResponse.ElicitResultTypeInfo);
192192
return elicitResult?.Action == "accept"
193193
? $"User accepted: {elicitResult.Content?.FirstOrDefault().Value}"
194194
: "User declined.";

docs/concepts/mrtr/mrtr.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ public static string AnswerTool(
9696
// On retry, process the client's responses
9797
if (requestState is not null && inputResponses is not null)
9898
{
99-
var elicitResult = inputResponses["user_answer"].ElicitationResult;
99+
var elicitResult = inputResponses["user_answer"].Deserialize(InputResponse.ElicitResultTypeInfo);
100100
return $"You answered: {elicitResult?.Content?.FirstOrDefault().Value}";
101101
}
102102

@@ -135,11 +135,11 @@ When the client retries a tool call, the retry data is available on the request
135135
- <xref:ModelContextProtocol.Protocol.RequestParams.InputResponses> — a dictionary of client responses keyed by the same keys used in `inputRequests`.
136136
- <xref:ModelContextProtocol.Protocol.RequestParams.RequestState> — the opaque state string echoed back by the client.
137137

138-
Each `InputResponse` has typed accessors for the response type:
138+
Use <xref:ModelContextProtocol.Protocol.InputResponse.Deserialize*> with the `JsonTypeInfo<T>` matching the response type. The expected type follows from the matching <xref:ModelContextProtocol.Protocol.InputRequest.Method> in the original `inputRequests` map — there is no on-the-wire discriminator.
139139

140-
- `ElicitationResult`the result of an elicitation request.
141-
- `SamplingResult`the result of a sampling request.
142-
- `RootsResult` — the result of a roots list request.
140+
- Elicitation`response.Deserialize(InputResponse.ElicitResultTypeInfo)`
141+
- Sampling`response.Deserialize(InputResponse.SamplingResultTypeInfo)`
142+
- Roots list `response.Deserialize(InputResponse.RootsResultTypeInfo)`
143143

144144
### Load shedding with requestState-only responses
145145

@@ -191,14 +191,14 @@ public static string WizardTool(
191191

192192
if (requestState == "step-2" && inputResponses is not null)
193193
{
194-
var name = inputResponses["name"].ElicitationResult?.Content?.FirstOrDefault().Value;
195-
var age = inputResponses["age"].ElicitationResult?.Content?.FirstOrDefault().Value;
194+
var name = inputResponses["name"].Deserialize(InputResponse.ElicitResultTypeInfo)?.Content?.FirstOrDefault().Value;
195+
var age = inputResponses["age"].Deserialize(InputResponse.ElicitResultTypeInfo)?.Content?.FirstOrDefault().Value;
196196
return $"Welcome, {name}! You are {age} years old.";
197197
}
198198

199199
if (requestState == "step-1" && inputResponses is not null)
200200
{
201-
var name = inputResponses["name"].ElicitationResult?.Content?.FirstOrDefault().Value;
201+
var name = inputResponses["name"].Deserialize(InputResponse.ElicitResultTypeInfo)?.Content?.FirstOrDefault().Value;
202202

203203
// Second round — ask for age
204204
throw new InputRequiredException(

docs/concepts/roots/roots.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ public static string ListRootsWithMrtr(
122122
// On retry, process the client's roots response
123123
if (context.Params!.InputResponses?.TryGetValue("get_roots", out var response) is true)
124124
{
125-
var roots = response.RootsResult?.Roots ?? [];
125+
var roots = response.Deserialize(InputResponse.RootsResultTypeInfo)?.Roots ?? [];
126126
return $"Found {roots.Count} roots: {string.Join(", ", roots.Select(r => r.Uri))}";
127127
}
128128

docs/concepts/sampling/sampling.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ public static string SampleWithMrtr(
139139
// On retry, process the client's sampling response
140140
if (context.Params!.InputResponses?.TryGetValue("llm_call", out var response) is true)
141141
{
142-
var text = response.SamplingResult?.Content
142+
var text = response.Deserialize(InputResponse.SamplingResultTypeInfo)?.Content
143143
.OfType<TextContentBlock>().FirstOrDefault()?.Text;
144144
return $"LLM said: {text}";
145145
}

src/ModelContextProtocol.Core/Client/McpClientImpl.cs

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -566,20 +566,41 @@ private async ValueTask<IDictionary<string, InputResponse>> ResolveInputRequests
566566
IDictionary<string, InputRequest> inputRequests,
567567
CancellationToken cancellationToken)
568568
{
569-
var responses = new Dictionary<string, InputResponse>(inputRequests.Count);
569+
// Resolve all input requests concurrently. If any fails, cancel the rest so user-facing
570+
// handlers (sampling/elicitation prompts) don't keep running for a request whose caller
571+
// has already given up, and ensure exceptions from late-completing tasks are observed.
572+
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
570573

571-
// Resolve all input requests concurrently
572-
var tasks = new List<(string Key, Task<InputResponse> Task)>(inputRequests.Count);
574+
var keyed = new (string Key, Task<InputResponse> Task)[inputRequests.Count];
575+
int i = 0;
573576
foreach (var kvp in inputRequests)
574577
{
575-
tasks.Add((kvp.Key, ResolveInputRequestAsync(kvp.Value, cancellationToken)));
578+
keyed[i++] = (kvp.Key, ResolveInputRequestAsync(kvp.Value, linkedCts.Token));
576579
}
577580

578-
foreach (var entry in tasks)
581+
try
579582
{
580-
responses[entry.Key] = await entry.Task.ConfigureAwait(false);
583+
await Task.WhenAll(Array.ConvertAll(keyed, k => k.Task)).ConfigureAwait(false);
584+
}
585+
catch
586+
{
587+
linkedCts.Cancel();
588+
try
589+
{
590+
await Task.WhenAll(Array.ConvertAll(keyed, k => k.Task)).ConfigureAwait(false);
591+
}
592+
catch
593+
{
594+
// Observed; the original exception is the one we want to surface.
595+
}
596+
throw;
581597
}
582598

599+
var responses = new Dictionary<string, InputResponse>(keyed.Length);
600+
foreach (var (key, task) in keyed)
601+
{
602+
responses[key] = task.Result;
603+
}
583604
return responses;
584605
}
585606

src/ModelContextProtocol.Core/Protocol/InputResponse.cs

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System.Diagnostics.CodeAnalysis;
22
using System.Text.Json;
33
using System.Text.Json.Serialization;
4+
using System.Text.Json.Serialization.Metadata;
45

56
namespace ModelContextProtocol.Protocol;
67

@@ -28,7 +29,10 @@ public sealed class InputResponse
2829
/// Gets or sets the raw JSON element representing the response.
2930
/// </summary>
3031
/// <remarks>
31-
/// Use <see cref="Deserialize{T}"/> or the typed factory methods to work with concrete response types.
32+
/// Use <see cref="Deserialize{T}"/> with the <c>JsonTypeInfo&lt;T&gt;</c> matching the
33+
/// associated <see cref="InputRequest.Method"/> — for elicitation, sampling, or roots see
34+
/// <see cref="ElicitResultTypeInfo"/>, <see cref="SamplingResultTypeInfo"/>, and
35+
/// <see cref="RootsResultTypeInfo"/>.
3236
/// </remarks>
3337
[JsonIgnore]
3438
public JsonElement RawValue { get; set; }
@@ -43,28 +47,25 @@ public sealed class InputResponse
4347
JsonSerializer.Deserialize(RawValue, typeInfo);
4448

4549
/// <summary>
46-
/// Gets the response as a <see cref="CreateMessageResult"/>.
50+
/// Gets the <see cref="JsonTypeInfo{T}"/> for <see cref="ElicitResult"/>, suitable for use with
51+
/// <see cref="Deserialize{T}"/> when the corresponding <see cref="InputRequest.Method"/> is
52+
/// <see cref="RequestMethods.ElicitationCreate"/>.
4753
/// </summary>
48-
/// <returns>The deserialized sampling result, or <see langword="null"/> if deserialization fails.</returns>
49-
[JsonIgnore]
50-
public CreateMessageResult? SamplingResult =>
51-
JsonSerializer.Deserialize(RawValue, McpJsonUtilities.JsonContext.Default.CreateMessageResult);
54+
public static JsonTypeInfo<ElicitResult> ElicitResultTypeInfo => McpJsonUtilities.JsonContext.Default.ElicitResult;
5255

5356
/// <summary>
54-
/// Gets the response as an <see cref="ElicitResult"/>.
57+
/// Gets the <see cref="JsonTypeInfo{T}"/> for <see cref="CreateMessageResult"/>, suitable for use with
58+
/// <see cref="Deserialize{T}"/> when the corresponding <see cref="InputRequest.Method"/> is
59+
/// <see cref="RequestMethods.SamplingCreateMessage"/>.
5560
/// </summary>
56-
/// <returns>The deserialized elicitation result, or <see langword="null"/> if deserialization fails.</returns>
57-
[JsonIgnore]
58-
public ElicitResult? ElicitationResult =>
59-
JsonSerializer.Deserialize(RawValue, McpJsonUtilities.JsonContext.Default.ElicitResult);
61+
public static JsonTypeInfo<CreateMessageResult> SamplingResultTypeInfo => McpJsonUtilities.JsonContext.Default.CreateMessageResult;
6062

6163
/// <summary>
62-
/// Gets the response as a <see cref="ListRootsResult"/>.
64+
/// Gets the <see cref="JsonTypeInfo{T}"/> for <see cref="ListRootsResult"/>, suitable for use with
65+
/// <see cref="Deserialize{T}"/> when the corresponding <see cref="InputRequest.Method"/> is
66+
/// <see cref="RequestMethods.RootsList"/>.
6367
/// </summary>
64-
/// <returns>The deserialized roots list result, or <see langword="null"/> if deserialization fails.</returns>
65-
[JsonIgnore]
66-
public ListRootsResult? RootsResult =>
67-
JsonSerializer.Deserialize(RawValue, McpJsonUtilities.JsonContext.Default.ListRootsResult);
68+
public static JsonTypeInfo<ListRootsResult> RootsResultTypeInfo => McpJsonUtilities.JsonContext.Default.ListRootsResult;
6869

6970
/// <summary>
7071
/// Creates an <see cref="InputResponse"/> from a <see cref="CreateMessageResult"/>.

src/ModelContextProtocol.Core/Server/McpServerImpl.cs

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,11 +1217,7 @@ internal bool IsLowLevelMrtrAvailable() =>
12171217
}
12181218

12191219
// Resolve each input request by sending the corresponding JSON-RPC call to the client.
1220-
var inputResponses = new Dictionary<string, InputResponse>(inputRequests.Count);
1221-
foreach (var kvp in inputRequests)
1222-
{
1223-
inputResponses[kvp.Key] = await ResolveInputRequestAsync(kvp.Value, cancellationToken).ConfigureAwait(false);
1224-
}
1220+
var inputResponses = await ResolveInputRequestsAsync(inputRequests, cancellationToken).ConfigureAwait(false);
12251221

12261222
// Reconstruct request params with inputResponses and requestState for the retry.
12271223
var paramsObj = request.Params?.DeepClone() as JsonObject ?? new JsonObject();
@@ -1244,6 +1240,52 @@ internal bool IsLowLevelMrtrAvailable() =>
12441240
}
12451241
}
12461242

1243+
/// <summary>
1244+
/// Resolves a batch of MRTR input requests concurrently by dispatching each as a standard
1245+
/// JSON-RPC request to the client. On the first failure all remaining handlers are cancelled
1246+
/// so user-facing flows (sampling/elicitation prompts) don't keep running once the caller has
1247+
/// given up, and exceptions from late-completing tasks are observed before the original
1248+
/// exception is rethrown.
1249+
/// </summary>
1250+
private async Task<IDictionary<string, InputResponse>> ResolveInputRequestsAsync(
1251+
IDictionary<string, InputRequest> inputRequests,
1252+
CancellationToken cancellationToken)
1253+
{
1254+
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
1255+
1256+
var keyed = new (string Key, Task<InputResponse> Task)[inputRequests.Count];
1257+
int i = 0;
1258+
foreach (var kvp in inputRequests)
1259+
{
1260+
keyed[i++] = (kvp.Key, ResolveInputRequestAsync(kvp.Value, linkedCts.Token));
1261+
}
1262+
1263+
try
1264+
{
1265+
await Task.WhenAll(Array.ConvertAll(keyed, k => k.Task)).ConfigureAwait(false);
1266+
}
1267+
catch
1268+
{
1269+
linkedCts.Cancel();
1270+
try
1271+
{
1272+
await Task.WhenAll(Array.ConvertAll(keyed, k => k.Task)).ConfigureAwait(false);
1273+
}
1274+
catch
1275+
{
1276+
// Observed; the original exception is the one we want to surface.
1277+
}
1278+
throw;
1279+
}
1280+
1281+
var responses = new Dictionary<string, InputResponse>(keyed.Length);
1282+
foreach (var (key, task) in keyed)
1283+
{
1284+
responses[key] = task.Result;
1285+
}
1286+
return responses;
1287+
}
1288+
12471289
/// <summary>
12481290
/// Resolves a single MRTR <see cref="InputRequest"/> by dispatching it as a standard JSON-RPC
12491291
/// request to the client. This is the server-side mirror of the client's input resolution logic,

tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.Mrtr.cs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ private static async Task<string> MrtrMixed(McpServer server, RequestContext<Cal
9191
// Round 3 entry: confirmation from round 2 available. Transition to await API.
9292
if (state == "round-2" && responses?.TryGetValue("confirm", out var confirmResponse) == true)
9393
{
94-
var confirmation = confirmResponse.ElicitationResult?.Action ?? "unknown";
94+
var confirmation = confirmResponse.Deserialize(InputResponse.ElicitResultTypeInfo)?.Action ?? "unknown";
9595

9696
// Await API: sequential sampling then elicitation
9797
var sampleResult = await server.SampleAsync(new CreateMessageRequestParams
@@ -114,10 +114,10 @@ private static async Task<string> MrtrMixed(McpServer server, RequestContext<Cal
114114
// Round 2 entry: parallel results from round 1 available.
115115
if (state == "round-1" && responses is not null)
116116
{
117-
var name = responses["name"].ElicitationResult?.Content?.FirstOrDefault().Value;
118-
var weather = responses["weather"].SamplingResult?.Content
117+
var name = responses["name"].Deserialize(InputResponse.ElicitResultTypeInfo)?.Content?.FirstOrDefault().Value;
118+
var weather = responses["weather"].Deserialize(InputResponse.SamplingResultTypeInfo)?.Content
119119
.OfType<TextContentBlock>().FirstOrDefault()?.Text ?? "";
120-
var root = responses["roots"].RootsResult?.Roots?.FirstOrDefault()?.Name ?? "";
120+
var root = responses["roots"].Deserialize(InputResponse.RootsResultTypeInfo)?.Roots?.FirstOrDefault()?.Name ?? "";
121121

122122
// Exception API: single elicitation with requestState
123123
throw new InputRequiredException(
@@ -305,7 +305,7 @@ private static string MrtrElicit(RequestContext<CallToolRequestParams> context)
305305
if (context.Params!.InputResponses is { } responses &&
306306
responses.TryGetValue("user_input", out var response))
307307
{
308-
return $"elicit-ok:{response.ElicitationResult?.Action}";
308+
return $"elicit-ok:{response.Deserialize(InputResponse.ElicitResultTypeInfo)?.Action}";
309309
}
310310

311311
throw new InputRequiredException(
@@ -329,7 +329,7 @@ public async Task Mrtr_LowLevel_Roots_CompletesViaMrtr()
329329
if (context.Params!.InputResponses is { } responses &&
330330
responses.TryGetValue("roots", out var response))
331331
{
332-
var roots = response.RootsResult?.Roots;
332+
var roots = response.Deserialize(InputResponse.RootsResultTypeInfo)?.Roots;
333333
return $"roots-ok:{string.Join(",", roots?.Select(r => r.Uri) ?? [])}";
334334
}
335335

@@ -363,13 +363,13 @@ private static string MrtrMulti(RequestContext<CallToolRequestParams> context)
363363

364364
if (requestState == "round-2" && inputResponses is not null)
365365
{
366-
var greeting = inputResponses["greeting"].ElicitationResult?.Action;
366+
var greeting = inputResponses["greeting"].Deserialize(InputResponse.ElicitResultTypeInfo)?.Action;
367367
return $"multi-done:greeting={greeting}";
368368
}
369369

370370
if (requestState == "round-1" && inputResponses is not null)
371371
{
372-
var name = inputResponses["name"].ElicitationResult?.Content?.FirstOrDefault().Value;
372+
var name = inputResponses["name"].Deserialize(InputResponse.ElicitResultTypeInfo)?.Content?.FirstOrDefault().Value;
373373
throw new InputRequiredException(
374374
inputRequests: new Dictionary<string, InputRequest>
375375
{
@@ -475,11 +475,11 @@ private static string MrtrConcurrentThree(RequestContext<CallToolRequestParams>
475475
responses.ContainsKey("sample") &&
476476
responses.ContainsKey("roots"))
477477
{
478-
var elicitAction = responses["elicit"].ElicitationResult?.Action;
479-
var sampleText = responses["sample"].SamplingResult?
478+
var elicitAction = responses["elicit"].Deserialize(InputResponse.ElicitResultTypeInfo)?.Action;
479+
var sampleText = responses["sample"].Deserialize(InputResponse.SamplingResultTypeInfo)?
480480
.Content.OfType<TextContentBlock>().FirstOrDefault()?.Text;
481481
var rootUris = string.Join(",",
482-
responses["roots"].RootsResult?.Roots.Select(r => r.Uri) ?? []);
482+
responses["roots"].Deserialize(InputResponse.RootsResultTypeInfo)?.Roots.Select(r => r.Uri) ?? []);
483483
return $"all-ok:elicit={elicitAction},sample={sampleText},roots={rootUris}";
484484
}
485485

@@ -596,7 +596,7 @@ public async Task Mrtr_Backcompat_Roots_ResolvedViaLegacyJsonRpc()
596596
if (context.Params!.InputResponses is { } responses &&
597597
responses.TryGetValue("roots", out var response))
598598
{
599-
var roots = response.RootsResult?.Roots;
599+
var roots = response.Deserialize(InputResponse.RootsResultTypeInfo)?.Roots;
600600
return $"roots-ok:{roots?.FirstOrDefault()?.Name}";
601601
}
602602

@@ -633,8 +633,8 @@ public async Task Mrtr_Backcompat_MultipleInputRequests_ResolvedViaLegacyJsonRpc
633633
responses.TryGetValue("confirm", out var elicitResponse) &&
634634
responses.TryGetValue("summarize", out var sampleResponse))
635635
{
636-
var action = elicitResponse.ElicitationResult?.Action;
637-
var text = sampleResponse.SamplingResult?.Content.OfType<TextContentBlock>().FirstOrDefault()?.Text;
636+
var action = elicitResponse.Deserialize(InputResponse.ElicitResultTypeInfo)?.Action;
637+
var text = sampleResponse.Deserialize(InputResponse.SamplingResultTypeInfo)?.Content.OfType<TextContentBlock>().FirstOrDefault()?.Text;
638638
return $"both:{action}:{text}";
639639
}
640640

tests/ModelContextProtocol.ConformanceServer/Prompts/IncompleteResultPrompts.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public static GetPromptResult IncompleteResultPrompt(RequestContext<GetPromptReq
2323
if (context.Params!.InputResponses is { } responses &&
2424
responses.TryGetValue("user_context", out var response))
2525
{
26-
var elicit = response.ElicitationResult;
26+
var elicit = response.Deserialize(InputResponse.ElicitResultTypeInfo);
2727
var contextValue = TryReadString(elicit?.Content, "context") ?? "(unknown)";
2828
return new GetPromptResult
2929
{

0 commit comments

Comments
 (0)