Skip to content

Commit 3c30338

Browse files
halter73Copilot
andcommitted
Remove internal MRTR members from mockable base classes
Move MRTR logic out of McpServer and McpClient base classes into their internal implementations, keeping the mockable API surface clean. Server side: - Remove McpServer.ActiveMrtrContext (was internal) - Add MRTR interception to DestinationBoundMcpServer.SendRequestAsync with task guard (SampleAsTaskAsync/ElicitAsTaskAsync bypass MRTR) - Remove MRTR branches from SampleAsync, ElicitAsync, RequestRootsCoreAsync - Task status tracking (InputRequired) now works during MRTR Client side: - Remove McpClient.ResolveInputRequestsAsync (was internal abstract) - Move MRTR retry loop into McpClientImpl.SendRequestAsync override - Replace SendRequestWithMrtrAsync with existing McpSession typed helper - Make resolve methods private on McpClientImpl Add 4 new tests for MRTR+Tasks interaction: - Task-augmented tool call with MRTR sampling - MRTR elicitation through tool call - SampleAsTaskAsync bypasses MRTR interception - MRTR tool call and task-based sampling coexist Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent e1bd3f6 commit 3c30338

8 files changed

Lines changed: 407 additions & 158 deletions

File tree

src/ModelContextProtocol.Core/Client/McpClient.Methods.cs

Lines changed: 16 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,12 @@ public ValueTask<ListToolsResult> ListToolsAsync(
184184
{
185185
Throw.IfNull(requestParams);
186186

187-
return SendRequestWithMrtrAsync(
187+
return SendRequestAsync(
188188
RequestMethods.ToolsList,
189189
requestParams,
190190
McpJsonUtilities.JsonContext.Default.ListToolsRequestParams,
191191
McpJsonUtilities.JsonContext.Default.ListToolsResult,
192-
cancellationToken);
192+
cancellationToken: cancellationToken);
193193
}
194194

195195
/// <summary>
@@ -240,12 +240,12 @@ public ValueTask<ListPromptsResult> ListPromptsAsync(
240240
{
241241
Throw.IfNull(requestParams);
242242

243-
return SendRequestWithMrtrAsync(
243+
return SendRequestAsync(
244244
RequestMethods.PromptsList,
245245
requestParams,
246246
McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams,
247247
McpJsonUtilities.JsonContext.Default.ListPromptsResult,
248-
cancellationToken);
248+
cancellationToken: cancellationToken);
249249
}
250250

251251
/// <summary>
@@ -294,12 +294,12 @@ public ValueTask<GetPromptResult> GetPromptAsync(
294294
{
295295
Throw.IfNull(requestParams);
296296

297-
return SendRequestWithMrtrAsync(
297+
return SendRequestAsync(
298298
RequestMethods.PromptsGet,
299299
requestParams,
300300
McpJsonUtilities.JsonContext.Default.GetPromptRequestParams,
301301
McpJsonUtilities.JsonContext.Default.GetPromptResult,
302-
cancellationToken);
302+
cancellationToken: cancellationToken);
303303
}
304304

305305
/// <summary>
@@ -350,12 +350,12 @@ public ValueTask<ListResourceTemplatesResult> ListResourceTemplatesAsync(
350350
{
351351
Throw.IfNull(requestParams);
352352

353-
return SendRequestWithMrtrAsync(
353+
return SendRequestAsync(
354354
RequestMethods.ResourcesTemplatesList,
355355
requestParams,
356356
McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams,
357357
McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult,
358-
cancellationToken);
358+
cancellationToken: cancellationToken);
359359
}
360360

361361
/// <summary>
@@ -406,12 +406,12 @@ public ValueTask<ListResourcesResult> ListResourcesAsync(
406406
{
407407
Throw.IfNull(requestParams);
408408

409-
return SendRequestWithMrtrAsync(
409+
return SendRequestAsync(
410410
RequestMethods.ResourcesList,
411411
requestParams,
412412
McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams,
413413
McpJsonUtilities.JsonContext.Default.ListResourcesResult,
414-
cancellationToken);
414+
cancellationToken: cancellationToken);
415415
}
416416

417417
/// <summary>
@@ -490,12 +490,12 @@ public ValueTask<ReadResourceResult> ReadResourceAsync(
490490
{
491491
Throw.IfNull(requestParams);
492492

493-
return SendRequestWithMrtrAsync(
493+
return SendRequestAsync(
494494
RequestMethods.ResourcesRead,
495495
requestParams,
496496
McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams,
497497
McpJsonUtilities.JsonContext.Default.ReadResourceResult,
498-
cancellationToken);
498+
cancellationToken: cancellationToken);
499499
}
500500

501501
/// <summary>
@@ -541,12 +541,12 @@ public ValueTask<CompleteResult> CompleteAsync(
541541
{
542542
Throw.IfNull(requestParams);
543543

544-
return SendRequestWithMrtrAsync(
544+
return SendRequestAsync(
545545
RequestMethods.CompletionComplete,
546546
requestParams,
547547
McpJsonUtilities.JsonContext.Default.CompleteRequestParams,
548548
McpJsonUtilities.JsonContext.Default.CompleteResult,
549-
cancellationToken);
549+
cancellationToken: cancellationToken);
550550
}
551551

552552
/// <summary>
@@ -906,12 +906,12 @@ public ValueTask<CallToolResult> CallToolAsync(
906906
{
907907
Throw.IfNull(requestParams);
908908

909-
return SendRequestWithMrtrAsync(
909+
return SendRequestAsync(
910910
RequestMethods.ToolsCall,
911911
requestParams,
912912
McpJsonUtilities.JsonContext.Default.CallToolRequestParams,
913913
McpJsonUtilities.JsonContext.Default.CallToolResult,
914-
cancellationToken);
914+
cancellationToken: cancellationToken);
915915
}
916916

917917
/// <summary>
@@ -1290,91 +1290,6 @@ public Task SetLoggingLevelAsync(
12901290
cancellationToken: cancellationToken).AsTask();
12911291
}
12921292

1293-
/// <summary>
1294-
/// Sends a request with MRTR (Multi Round-Trip Request) support. If the server returns an
1295-
/// <see cref="IncompleteResult"/>, this method automatically resolves the input requests
1296-
/// via the client's handlers and retries until a complete result is obtained.
1297-
/// </summary>
1298-
private async ValueTask<TResult> SendRequestWithMrtrAsync<TParams, TResult>(
1299-
string method,
1300-
TParams parameters,
1301-
JsonTypeInfo<TParams> parametersTypeInfo,
1302-
JsonTypeInfo<TResult> resultTypeInfo,
1303-
CancellationToken cancellationToken)
1304-
where TParams : RequestParams
1305-
where TResult : Result
1306-
{
1307-
const int maxRetries = 10;
1308-
1309-
for (int attempt = 0; attempt <= maxRetries; attempt++)
1310-
{
1311-
JsonRpcRequest jsonRpcRequest = new()
1312-
{
1313-
Method = method,
1314-
Params = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo),
1315-
};
1316-
1317-
JsonRpcResponse response = await SendRequestAsync(jsonRpcRequest, cancellationToken).ConfigureAwait(false);
1318-
1319-
// Check if the result is an IncompleteResult by looking at result_type
1320-
if (response.Result is JsonObject resultObj &&
1321-
resultObj.TryGetPropertyValue("result_type", out var resultTypeNode) &&
1322-
resultTypeNode?.GetValue<string>() is "incomplete")
1323-
{
1324-
var incompleteResult = JsonSerializer.Deserialize(response.Result, McpJsonUtilities.JsonContext.Default.IncompleteResult)
1325-
?? throw new JsonException("Failed to deserialize IncompleteResult.");
1326-
1327-
if (incompleteResult.InputRequests is { Count: > 0 } inputRequests)
1328-
{
1329-
IDictionary<string, InputResponse> inputResponses =
1330-
await ResolveInputRequestsAsync(inputRequests, cancellationToken).ConfigureAwait(false);
1331-
1332-
// Serialize input responses into the parameters for the retry
1333-
var paramsNode = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo) as JsonObject
1334-
?? throw new JsonException("Failed to serialize request parameters as JsonObject.");
1335-
1336-
paramsNode["inputResponses"] = JsonSerializer.SerializeToNode(
1337-
inputResponses, McpJsonUtilities.JsonContext.Default.IDictionaryStringInputResponse);
1338-
1339-
if (incompleteResult.RequestState is { } requestState)
1340-
{
1341-
paramsNode["requestState"] = requestState;
1342-
}
1343-
1344-
// Deserialize back to TParams to pick up the inputResponses and requestState
1345-
parameters = JsonSerializer.Deserialize(paramsNode, parametersTypeInfo)
1346-
?? throw new JsonException("Failed to deserialize retry parameters.");
1347-
}
1348-
else if (incompleteResult.RequestState is not null)
1349-
{
1350-
// No input requests but has requestState (e.g., load shedding) — just retry with state
1351-
var paramsNode = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo) as JsonObject
1352-
?? throw new JsonException("Failed to serialize request parameters as JsonObject.");
1353-
1354-
paramsNode["requestState"] = incompleteResult.RequestState;
1355-
1356-
// Remove any old inputResponses from previous iteration
1357-
paramsNode.Remove("inputResponses");
1358-
1359-
parameters = JsonSerializer.Deserialize(paramsNode, parametersTypeInfo)
1360-
?? throw new JsonException("Failed to deserialize retry parameters.");
1361-
}
1362-
else
1363-
{
1364-
throw new McpException("Server returned an IncompleteResult without inputRequests or requestState.");
1365-
}
1366-
1367-
continue; // retry with the updated parameters
1368-
}
1369-
1370-
// Normal complete result
1371-
return JsonSerializer.Deserialize(response.Result, resultTypeInfo)
1372-
?? throw new JsonException("Unexpected JSON result in response.");
1373-
}
1374-
1375-
throw new McpException($"Server returned IncompleteResult more than {maxRetries} times.");
1376-
}
1377-
13781293
/// <summary>Converts a dictionary with <see cref="object"/> values to a dictionary with <see cref="JsonElement"/> values.</summary>
13791294
private static Dictionary<string, JsonElement>? ToArgumentsDictionary(
13801295
IReadOnlyDictionary<string, object?>? arguments, JsonSerializerOptions options)

src/ModelContextProtocol.Core/Client/McpClient.cs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,4 @@ protected McpClient()
7171
/// </remarks>
7272
public abstract Task<ClientCompletionDetails> Completion { get; }
7373

74-
/// <summary>
75-
/// Resolves input requests from an <see cref="IncompleteResult"/> by dispatching each request
76-
/// to the appropriate handler (sampling, elicitation, or roots).
77-
/// </summary>
78-
/// <param name="inputRequests">The input requests to resolve.</param>
79-
/// <param name="cancellationToken">A cancellation token.</param>
80-
/// <returns>A dictionary of responses keyed by the same keys as the input requests.</returns>
81-
internal abstract ValueTask<IDictionary<string, InputResponse>> ResolveInputRequestsAsync(
82-
IDictionary<string, InputRequest> inputRequests,
83-
CancellationToken cancellationToken);
8474
}

src/ModelContextProtocol.Core/Client/McpClientImpl.cs

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ private void RegisterTaskHandlers(RequestHandlers requestHandlers, IMcpTaskStore
532532
public override Task<ClientCompletionDetails> Completion => _sessionHandler.CompletionTask;
533533

534534
/// <inheritdoc/>
535-
internal override async ValueTask<IDictionary<string, InputResponse>> ResolveInputRequestsAsync(
535+
private async ValueTask<IDictionary<string, InputResponse>> ResolveInputRequestsAsync(
536536
IDictionary<string, InputRequest> inputRequests,
537537
CancellationToken cancellationToken)
538538
{
@@ -710,8 +710,62 @@ internal void ResumeSession(ResumeClientSessionOptions resumeOptions)
710710
}
711711

712712
/// <inheritdoc/>
713-
public override Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default)
714-
=> _sessionHandler.SendRequestAsync(request, cancellationToken);
713+
public override async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default)
714+
{
715+
const int maxRetries = 10;
716+
717+
for (int attempt = 0; attempt <= maxRetries; attempt++)
718+
{
719+
JsonRpcResponse response = await _sessionHandler.SendRequestAsync(request, cancellationToken).ConfigureAwait(false);
720+
721+
// Check if the result is an IncompleteResult by looking at result_type.
722+
if (response.Result is JsonObject resultObj &&
723+
resultObj.TryGetPropertyValue("result_type", out var resultTypeNode) &&
724+
resultTypeNode?.GetValue<string>() is "incomplete")
725+
{
726+
var incompleteResult = JsonSerializer.Deserialize(response.Result, McpJsonUtilities.JsonContext.Default.IncompleteResult)
727+
?? throw new JsonException("Failed to deserialize IncompleteResult.");
728+
729+
if (incompleteResult.InputRequests is { Count: > 0 } inputRequests)
730+
{
731+
IDictionary<string, InputResponse> inputResponses =
732+
await ResolveInputRequestsAsync(inputRequests, cancellationToken).ConfigureAwait(false);
733+
734+
// Clone the original request params and add inputResponses + requestState for the retry.
735+
var paramsObj = request.Params?.DeepClone() as JsonObject ?? new JsonObject();
736+
737+
paramsObj["inputResponses"] = JsonSerializer.SerializeToNode(
738+
inputResponses, McpJsonUtilities.JsonContext.Default.IDictionaryStringInputResponse);
739+
740+
if (incompleteResult.RequestState is { } requestState)
741+
{
742+
paramsObj["requestState"] = requestState;
743+
}
744+
745+
request = new JsonRpcRequest { Method = request.Method, Params = paramsObj };
746+
}
747+
else if (incompleteResult.RequestState is not null)
748+
{
749+
// No input requests but has requestState (e.g., load shedding) — just retry with state.
750+
var paramsObj = request.Params?.DeepClone() as JsonObject ?? new JsonObject();
751+
paramsObj["requestState"] = incompleteResult.RequestState;
752+
paramsObj.Remove("inputResponses");
753+
754+
request = new JsonRpcRequest { Method = request.Method, Params = paramsObj };
755+
}
756+
else
757+
{
758+
throw new McpException("Server returned an IncompleteResult without inputRequests or requestState.");
759+
}
760+
761+
continue; // retry with the updated request
762+
}
763+
764+
return response;
765+
}
766+
767+
throw new McpException($"Server returned IncompleteResult more than {maxRetries} times.");
768+
}
715769

716770
/// <inheritdoc/>
717771
public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)

src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using ModelContextProtocol.Protocol;
2-
using System.Diagnostics;
2+
using System.Text.Json;
3+
using System.Text.Json.Nodes;
34

45
namespace ModelContextProtocol.Server;
56

@@ -15,6 +16,12 @@ internal sealed class DestinationBoundMcpServer(McpServerImpl server, ITransport
1516
public override IServiceProvider? Services => server.Services;
1617
public override LoggingLevel? LoggingLevel => server.LoggingLevel;
1718

19+
/// <summary>
20+
/// Gets or sets the MRTR context for the current request, if any.
21+
/// Set by <see cref="McpServerImpl.CreateDestinationBoundServer"/> when an MRTR-aware handler invocation is in progress.
22+
/// </summary>
23+
internal MrtrContext? ActiveMrtrContext { get; set; }
24+
1825
public override ValueTask DisposeAsync() => server.DisposeAsync();
1926

2027
public override IAsyncDisposable RegisterNotificationHandler(string method, Func<JsonRpcNotification, CancellationToken, ValueTask> handler) => server.RegisterNotificationHandler(method, handler);
@@ -39,6 +46,16 @@ public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken
3946

4047
public override Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default)
4148
{
49+
// When an MRTR context is active, intercept server-to-client requests (sampling, elicitation, roots)
50+
// and route them through the MRTR mechanism instead of sending them over the wire.
51+
// Task-based requests (SampleAsTaskAsync/ElicitAsTaskAsync) have a "task" property on their params
52+
// and expect a CreateTaskResult response, so they must bypass MRTR and go over the wire.
53+
if (ActiveMrtrContext is { } mrtrContext &&
54+
!(request.Params is JsonObject paramsObj && paramsObj.ContainsKey("task")))
55+
{
56+
return SendRequestViaMrtrAsync(mrtrContext, request, cancellationToken);
57+
}
58+
4259
if (request.Context is not null)
4360
{
4461
throw new ArgumentException("Only transports can provide a JsonRpcMessageContext.");
@@ -51,4 +68,23 @@ public override Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, C
5168

5269
return server.SendRequestAsync(request, cancellationToken);
5370
}
71+
72+
private async Task<JsonRpcResponse> SendRequestViaMrtrAsync(
73+
MrtrContext mrtrContext, JsonRpcRequest request, CancellationToken cancellationToken)
74+
{
75+
var inputRequest = new InputRequest
76+
{
77+
Method = request.Method,
78+
Params = request.Params is { } paramsNode
79+
? JsonSerializer.Deserialize(paramsNode, McpJsonUtilities.JsonContext.Default.JsonElement)
80+
: null,
81+
};
82+
var inputResponse = await mrtrContext.RequestInputAsync(inputRequest, cancellationToken).ConfigureAwait(false);
83+
84+
return new JsonRpcResponse
85+
{
86+
Id = request.Id,
87+
Result = JsonSerializer.SerializeToNode(inputResponse.RawValue, McpJsonUtilities.JsonContext.Default.JsonElement),
88+
};
89+
}
5490
}

0 commit comments

Comments
 (0)