Skip to content

Commit b0f1e30

Browse files
halter73Copilot
andcommitted
Address tarekgh PR review feedback
- McpSessionHandler.SendRequestAsync no longer double-wraps SendToRelatedTransportAsync in _outgoingMessageFilter and no longer duplicates the per-request logging that SendToRelatedTransportAsync already emits. Restores main's once-per-send semantics. - McpClientImpl.ResolveInputRequestAsync gracefully handles roots/list InputRequests with no params (ListRootsRequestParams is optional per spec) by falling back to a default instance, matching the server-side resolver. - Rename local var (McpClientImpl) and parameter (McpServerImpl.SerializeInputRequiredResult) from PascalCase 'InputRequiredResult' to camelCase 'inputRequiredResult'. - StreamableHttpHandler.ValidateProtocolVersionHeader restored to private static (uses only a const and a static field; no instance state). - Tighten InputRequiredResult XML doc to note that this SDK currently only wires the MRTR interceptor into tools/call, even though SEP-2322 defines the wire format for prompts/get and resources/read too. - Tighten outgoing- and incoming-filter tests (AddOutgoingMessageFilter_Sees_Responses_Notifications_And_Requests, OutgoingFilter_SeesResponsesAndRequests, AddIncomingMessageFilter_Intercepts_Request_Messages, and AddIncomingMessageFilter_Multiple_Filters_Execute_In_Order) from substring/Contains/IndexOf checks to strict per-category counts. The substring assertions passed even when SendRequestAsync invoked the outgoing filter twice per request, so the regression went undetected; the new counts catch it (sampling/createMessage and tool-call response counts double when the bug is present). The symmetric incoming-side tightening guards against an analogous future regression on the receive pipeline. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 5bf84d2 commit b0f1e30

7 files changed

Lines changed: 47 additions & 32 deletions

File tree

src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ internal static Task RunSessionAsync(HttpContext httpContext, McpServer session,
548548
/// Validates the MCP-Protocol-Version header if present. A missing header is allowed for backwards compatibility,
549549
/// but an invalid or unsupported value must be rejected with 400 Bad Request per the MCP spec.
550550
/// </summary>
551-
private bool ValidateProtocolVersionHeader(HttpContext context, out string? errorMessage)
551+
private static bool ValidateProtocolVersionHeader(HttpContext context, out string? errorMessage)
552552
{
553553
var protocolVersionHeader = context.Request.Headers[McpProtocolVersionHeaderName].ToString();
554554
if (!string.IsNullOrEmpty(protocolVersionHeader) &&

src/ModelContextProtocol.Core/Client/McpClientImpl.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -639,8 +639,8 @@ private async Task<InputResponse> ResolveInputRequestAsync(InputRequest inputReq
639639
case RequestMethods.RootsList:
640640
if (_options.Handlers.RootsHandler is { } rootsHandler)
641641
{
642-
var rootsParams = inputRequest.RootsParams
643-
?? throw new McpException($"Failed to deserialize roots parameters from MRTR input request.");
642+
// ListRootsRequest params are optional per the spec, so fall back to an empty params instance.
643+
var rootsParams = inputRequest.RootsParams ?? new ListRootsRequestParams();
644644
var result = await rootsHandler(rootsParams, cancellationToken).ConfigureAwait(false);
645645
return InputResponse.FromRootsResult(result);
646646
}
@@ -859,10 +859,10 @@ request.Params is System.Text.Json.Nodes.JsonObject paramsObjForHeaders &&
859859
{
860860
WarnIfInputRequiredResultOnNonMrtrSession(request.Method);
861861

862-
var InputRequiredResult = JsonSerializer.Deserialize(response.Result, McpJsonUtilities.JsonContext.Default.InputRequiredResult)
862+
var inputRequiredResult = JsonSerializer.Deserialize(response.Result, McpJsonUtilities.JsonContext.Default.InputRequiredResult)
863863
?? throw new JsonException("Failed to deserialize InputRequiredResult.");
864864

865-
if (InputRequiredResult.InputRequests is { Count: > 0 } inputRequests)
865+
if (inputRequiredResult.InputRequests is { Count: > 0 } inputRequests)
866866
{
867867
IDictionary<string, InputResponse> inputResponses =
868868
await ResolveInputRequestsAsync(inputRequests, cancellationToken).ConfigureAwait(false);
@@ -873,18 +873,18 @@ request.Params is System.Text.Json.Nodes.JsonObject paramsObjForHeaders &&
873873
paramsObj["inputResponses"] = JsonSerializer.SerializeToNode(
874874
inputResponses, McpJsonUtilities.JsonContext.Default.IDictionaryStringInputResponse);
875875

876-
if (InputRequiredResult.RequestState is { } requestState)
876+
if (inputRequiredResult.RequestState is { } requestState)
877877
{
878878
paramsObj["requestState"] = requestState;
879879
}
880880

881881
request = new JsonRpcRequest { Method = request.Method, Params = paramsObj, Context = request.Context };
882882
}
883-
else if (InputRequiredResult.RequestState is not null)
883+
else if (inputRequiredResult.RequestState is not null)
884884
{
885885
// No input requests but has requestState (e.g., load shedding) — just retry with state.
886886
var paramsObj = request.Params?.DeepClone() as JsonObject ?? new JsonObject();
887-
paramsObj["requestState"] = InputRequiredResult.RequestState;
887+
paramsObj["requestState"] = inputRequiredResult.RequestState;
888888
paramsObj.Remove("inputResponses");
889889

890890
request = new JsonRpcRequest { Method = request.Method, Params = paramsObj, Context = request.Context };

src/ModelContextProtocol.Core/McpSessionHandler.cs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -591,16 +591,7 @@ public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, Canc
591591
AddTags(ref tags, activity, request, method, target);
592592
}
593593

594-
if (_logger.IsEnabled(LogLevel.Trace))
595-
{
596-
LogSendingRequestSensitive(EndpointName, request.Method, JsonSerializer.Serialize(request, McpJsonUtilities.JsonContext.Default.JsonRpcMessage));
597-
}
598-
else
599-
{
600-
LogSendingRequest(EndpointName, request.Method);
601-
}
602-
603-
await _outgoingMessageFilter(SendToRelatedTransportAsync)(request, cancellationToken).ConfigureAwait(false);
594+
await SendToRelatedTransportAsync(request, cancellationToken).ConfigureAwait(false);
604595

605596
// Now that the request has been sent, register for cancellation. If we registered before,
606597
// a cancellation request could arrive before the server knew about that request ID, in which

src/ModelContextProtocol.Core/Protocol/InputRequiredResult.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@ namespace ModelContextProtocol.Protocol;
99
/// </summary>
1010
/// <remarks>
1111
/// <para>
12-
/// An <see cref="InputRequiredResult"/> is returned in response to a client-initiated request (such as
13-
/// <see cref="RequestMethods.ToolsCall"/> or <see cref="RequestMethods.PromptsGet"/>) when the server
14-
/// needs the client to fulfill one or more server-initiated requests before it can produce a final result.
12+
/// An <see cref="InputRequiredResult"/> is returned in response to a client-initiated request when
13+
/// the server needs the client to fulfill one or more server-initiated requests before it can produce
14+
/// a final result. Per SEP-2322 the wire format is valid for <see cref="RequestMethods.ToolsCall"/>,
15+
/// <see cref="RequestMethods.PromptsGet"/>, and <c>resources/read</c>, but this SDK currently only wires
16+
/// the MRTR interceptor into <see cref="RequestMethods.ToolsCall"/>; throwing
17+
/// <see cref="InputRequiredException"/> from a prompts or resources handler will surface as an internal
18+
/// error until the other methods are opted in.
1519
/// </para>
1620
/// <para>
1721
/// At least one of <see cref="InputRequests"/> or <see cref="RequestState"/> must be present.

src/ModelContextProtocol.Core/Server/McpServerImpl.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,8 +1318,8 @@ private static async Task<InputResponse> ResolveInputRequestAsync(McpServer dest
13181318
}
13191319
}
13201320

1321-
private static JsonNode? SerializeInputRequiredResult(InputRequiredResult InputRequiredResult) =>
1322-
JsonSerializer.SerializeToNode(InputRequiredResult, McpJsonUtilities.JsonContext.Default.InputRequiredResult);
1321+
private static JsonNode? SerializeInputRequiredResult(InputRequiredResult inputRequiredResult) =>
1322+
JsonSerializer.SerializeToNode(inputRequiredResult, McpJsonUtilities.JsonContext.Default.InputRequiredResult);
13231323

13241324
[LoggerMessage(Level = LogLevel.Error, Message = "\"{ToolName}\" threw an unhandled exception.")]
13251325
private partial void ToolCallError(string toolName, Exception exception);

tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -385,10 +385,12 @@ await client.CallToolAsync("sampling-tool",
385385
new Dictionary<string, object?> { ["prompt"] = "Hello" },
386386
cancellationToken: TestContext.Current.CancellationToken);
387387

388-
Assert.Contains("initialize-response", observedMessageTypes);
389-
Assert.Contains("tools-list-response", observedMessageTypes);
390-
Assert.Contains("tool-call-response", observedMessageTypes);
391-
Assert.Contains($"request:{RequestMethods.SamplingCreateMessage}", observedMessageTypes);
388+
// Exact counts catch regressions where the outgoing filter pipeline gets applied more than once
389+
// per outbound message (e.g., SendRequestAsync double-wrapping SendToRelatedTransportAsync).
390+
Assert.Equal(1, observedMessageTypes.Count(m => m == "initialize-response"));
391+
Assert.Equal(1, observedMessageTypes.Count(m => m == "tools-list-response"));
392+
Assert.Equal(2, observedMessageTypes.Count(m => m == "tool-call-response")); // one per CallToolAsync
393+
Assert.Equal(2, observedMessageTypes.Count(m => m == $"request:{RequestMethods.SamplingCreateMessage}")); // sampling-tool makes two SampleAsync calls
392394
}
393395

394396
[Fact]

tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsMessageFilterTests.cs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,14 @@ public async Task AddIncomingMessageFilter_Intercepts_Request_Messages()
8787

8888
await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
8989

90-
// The message filter should intercept JsonRpcRequest messages
91-
Assert.Contains("JsonRpcRequest", messageTypes);
90+
// The message filter should intercept JsonRpcRequest messages.
91+
// Use strict counts so a regression that invokes the filter pipeline more than once per
92+
// incoming message (analogous to the SendRequestAsync double-wrap regression on the outgoing
93+
// side) would fail this test instead of slipping through Assert.Contains.
94+
// A single ListToolsAsync drives three server-bound messages: initialize (request),
95+
// notifications/initialized (notification), and tools/list (request).
96+
Assert.Equal(2, messageTypes.Count(m => m == nameof(JsonRpcRequest)));
97+
Assert.Equal(1, messageTypes.Count(m => m == nameof(JsonRpcNotification)));
9298
}
9399

94100
[Fact]
@@ -142,6 +148,13 @@ public async Task AddIncomingMessageFilter_Multiple_Filters_Execute_In_Order()
142148
Assert.True(idx1Before < idx2Before);
143149
Assert.True(idx2Before < idx2After);
144150
Assert.True(idx2After < idx1After);
151+
152+
// Verify each filter ran exactly once per incoming message (initialize + notifications/initialized + tools/list).
153+
// Strict counts catch regressions where the incoming filter pipeline gets invoked more than once per message.
154+
Assert.Equal(3, logMessages.Count(m => m == "MessageFilter1 before"));
155+
Assert.Equal(3, logMessages.Count(m => m == "MessageFilter2 before"));
156+
Assert.Equal(3, logMessages.Count(m => m == "MessageFilter2 after"));
157+
Assert.Equal(3, logMessages.Count(m => m == "MessageFilter1 after"));
145158
}
146159

147160
[Fact]
@@ -372,15 +385,20 @@ public async Task AddOutgoingMessageFilter_Sees_Responses_Notifications_And_Requ
372385
await client.CallToolAsync("sampling-tool", new Dictionary<string, object?> { ["prompt"] = "Hello" },
373386
cancellationToken: TestContext.Current.CancellationToken);
374387

388+
// Exact counts catch regressions where the outgoing filter pipeline gets applied more than once
389+
// per outbound message (e.g., SendRequestAsync double-wrapping SendToRelatedTransportAsync).
390+
Assert.Equal(1, observedMessages.Count(m => m == "initialize"));
391+
Assert.Equal(2, observedMessages.Count(m => m == "progress")); // ProgressTool sends two NotifyProgressAsync calls
392+
Assert.Equal(2, observedMessages.Count(m => m == "response")); // one tool-call response per CallToolAsync
393+
Assert.Equal(1, observedMessages.Count(m => m == $"request:{RequestMethods.SamplingCreateMessage}"));
394+
395+
// Preserve the original ordering intent: initialize first, then progress, then the final response.
375396
int initializeIndex = observedMessages.IndexOf("initialize");
376397
int progressIndex = observedMessages.IndexOf("progress");
377398
int responseIndex = observedMessages.LastIndexOf("response");
378-
int requestIndex = observedMessages.IndexOf($"request:{RequestMethods.SamplingCreateMessage}");
379399

380-
Assert.True(initializeIndex >= 0);
381400
Assert.True(progressIndex > initializeIndex);
382401
Assert.True(responseIndex > progressIndex);
383-
Assert.True(requestIndex >= 0);
384402
}
385403

386404
[Fact]

0 commit comments

Comments
 (0)