Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/ModelContextProtocol.Core/Client/McpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,43 @@ protected McpClient()
/// </para>
/// </remarks>
public abstract Task<ClientCompletionDetails> Completion { get; }

/// <summary>
/// Registers one or more tool definitions in the client's tool cache, enabling the transport
/// to send <c>Mcp-Param-*</c> headers for those tools without requiring a prior <see cref="McpClient.ListToolsAsync(RequestOptions?, CancellationToken)"/> call.
/// </summary>
/// <param name="tools">The tool definitions to register.</param>
/// <remarks>
/// <para>
/// This method allows callers who already have tool schema information (e.g., from a previous session,
/// hardcoded configuration, or an out-of-band source) to provide it directly to the client. Once registered,
/// any <see cref="McpClient.CallToolAsync(string, IReadOnlyDictionary{string, object?}?, IProgress{ProgressNotificationValue}?, RequestOptions?, CancellationToken)"/>
/// call for a registered tool will automatically include <c>Mcp-Param-*</c> HTTP headers based on
/// the tool's <c>x-mcp-header</c> schema annotations, exactly as if the tool had been discovered
/// via <see cref="McpClient.ListToolsAsync(RequestOptions?, CancellationToken)"/>.
/// </para>
/// <para>
/// <b>Cache interaction behavior:</b>
/// <list type="bullet">
/// <item>Registered tools are added to the same internal tool cache used by <see cref="McpClient.ListToolsAsync(RequestOptions?, CancellationToken)"/>.</item>
/// <item>Calling <see cref="McpClient.ListToolsAsync(RequestOptions?, CancellationToken)"/> after <see cref="RegisterTools"/> preserves
/// manually registered tools — only server-discovered tools are cleared and repopulated.</item>
/// <item>If the server returns a tool with the same name as a manually registered tool, the server's
/// definition overwrites the registered one in the cache, but the tool retains its registered status
/// and will survive subsequent cache clears.</item>
Comment thread
tarekgh marked this conversation as resolved.
Outdated
/// <item>Tools can be registered at any time — before or after <see cref="McpClient.ListToolsAsync(RequestOptions?, CancellationToken)"/>,
/// and across multiple calls.</item>
/// <item>Re-registering a tool with the same name overwrites the previous definition in the cache (last write wins).</item>
/// </list>
/// </para>
/// <para>
/// Tools with invalid <c>x-mcp-header</c> annotations are rejected and not added to the cache.
Comment thread
tarekgh marked this conversation as resolved.
Outdated
/// </para>
/// </remarks>
/// <exception cref="ArgumentNullException"><paramref name="tools"/> is <see langword="null"/>.</exception>
public virtual void RegisterTools(IEnumerable<Tool> tools)
Comment thread
tarekgh marked this conversation as resolved.
Outdated
{
Throw.IfNull(tools);
Comment thread
tarekgh marked this conversation as resolved.
}

Comment thread
tarekgh marked this conversation as resolved.
}
39 changes: 38 additions & 1 deletion src/ModelContextProtocol.Core/Client/McpClientImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ internal sealed partial class McpClientImpl : McpClient
private readonly SemaphoreSlim _disposeLock = new(1, 1);
private readonly McpTaskCancellationTokenProvider? _taskCancellationTokenProvider;
private readonly ConcurrentDictionary<string, Tool> _toolCache = new(StringComparer.Ordinal);
private readonly ConcurrentDictionary<string, byte> _registeredToolNames = new(StringComparer.Ordinal);

private ServerCapabilities? _serverCapabilities;
private Implementation? _serverInfo;
Expand Down Expand Up @@ -72,7 +73,23 @@ internal McpClientImpl(ITransport transport, string endpointName, McpClientOptio

ToolDiscovered = tool => _toolCache[tool.Name] = tool;
ToolRejected = (tool, reason) => LogToolRejected(tool.Name, reason);
ToolCacheClearing = () => _toolCache.Clear();
ToolCacheClearing = () =>
{
if (_registeredToolNames.IsEmpty)
{
_toolCache.Clear();
return;
}

// Only remove server-discovered tools; preserve manually registered tools.
foreach (var key in _toolCache.Keys)
{
if (!_registeredToolNames.ContainsKey(key))
{
_toolCache.TryRemove(key, out _);
}
}
};
Comment thread
tarekgh marked this conversation as resolved.
}

private void RegisterHandlers(McpClientOptions options, NotificationHandlers notificationHandlers, RequestHandlers requestHandlers)
Expand Down Expand Up @@ -637,6 +654,26 @@ internal void ResumeSession(ResumeClientSessionOptions resumeOptions)
LogClientSessionResumed(_endpointName);
}

/// <inheritdoc/>
public override void RegisterTools(IEnumerable<Tool> tools)
{
Throw.IfNull(tools);

foreach (var tool in tools)
{
Throw.IfNull(tool);

if (!McpHeaderExtractor.ValidateToolSchema(tool, out var rejectionReason))
{
ToolRejected?.Invoke(tool, rejectionReason!);
continue;
}

_registeredToolNames[tool.Name] = 0;
_toolCache[tool.Name] = tool;
}
}
Comment thread
tarekgh marked this conversation as resolved.

/// <inheritdoc/>
public override Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Json;
using Microsoft.Extensions.DependencyInjection;
using ModelContextProtocol.AspNetCore.Tests.Utils;
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Tests.Utils;
using System.Collections.Concurrent;
using System.Text.Json;

namespace ModelContextProtocol.AspNetCore.Tests;

/// <summary>
/// Tests that <see cref="McpClient.RegisterTools"/> allows sending Mcp-Param-* headers
/// without a prior <see cref="McpClient.ListToolsAsync"/> call.
/// </summary>
public class RegisterToolsHeaderTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper), IAsyncDisposable
{
private WebApplication? _app;

/// <summary>
/// Captured headers from tools/call requests, keyed by JSON-RPC request id.
/// </summary>
private readonly ConcurrentDictionary<string, Dictionary<string, string>> _capturedHeaders = new();

private async Task StartAsync()
{
Builder.Services.Configure<JsonOptions>(options =>
{
options.SerializerOptions.TypeInfoResolverChain.Add(McpJsonUtilities.DefaultOptions.TypeInfoResolver!);
});
_app = Builder.Build();

_app.MapPost("/mcp", (JsonRpcMessage message, HttpContext context) =>
{
if (message is not JsonRpcRequest request)
{
return Results.Accepted();
}

if (request.Method == "initialize")
{
return Results.Json(new JsonRpcResponse
{
Id = request.Id,
Result = JsonSerializer.SerializeToNode(new InitializeResult
{
ProtocolVersion = "DRAFT-2026-v1",
Capabilities = new() { Tools = new() },
ServerInfo = new Implementation { Name = "header-capture-test", Version = "1.0" },
}, McpJsonUtilities.DefaultOptions)
});
}

if (request.Method == "tools/call")
{
// Capture all Mcp-Param-* headers from the incoming HTTP request
var paramHeaders = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
foreach (var header in context.Request.Headers)
{
if (header.Key.StartsWith("Mcp-Param-", StringComparison.OrdinalIgnoreCase))
{
paramHeaders[header.Key] = header.Value.ToString();
}
}

_capturedHeaders[request.Id.ToString()!] = paramHeaders;

var parameters = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(CallToolRequestParams))) as CallToolRequestParams;

return Results.Json(new JsonRpcResponse
{
Id = request.Id,
Result = JsonSerializer.SerializeToNode(new CallToolResult
{
Content = [new TextContentBlock { Text = $"ok" }],
}, McpJsonUtilities.DefaultOptions),
});
}

if (request.Method == "tools/list")
{
return Results.Json(new JsonRpcResponse
{
Id = request.Id,
Result = JsonSerializer.SerializeToNode(new ListToolsResult
{
Tools = [],
}, McpJsonUtilities.DefaultOptions),
});
}

return Results.Accepted();
});

await _app.StartAsync(TestContext.Current.CancellationToken);

HttpClient.DefaultRequestHeaders.Accept.Add(new("application/json"));
HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream"));
}

public async ValueTask DisposeAsync()
{
if (_app is not null)
{
await _app.DisposeAsync();
}
base.Dispose();
}

private static Tool CreateToolWithHeaders()
{
var schemaJson = """
{
"type": "object",
"properties": {
"region": {
"type": "string",
"x-mcp-header": "Region"
},
"priority": {
"type": "integer",
"x-mcp-header": "Priority"
}
},
"required": ["region", "priority"]
}
""";

return new Tool
{
Name = "my_tool",
InputSchema = JsonDocument.Parse(schemaJson).RootElement.Clone(),
};
}

[Fact]
public async Task RegisterTools_ThenCallTool_SendsMcpParamHeaders_WithoutListToolsAsync()
{
await StartAsync();

await using var transport = new HttpClientTransport(new()
{
Endpoint = new("http://localhost:5000/mcp"),
TransportMode = HttpTransportMode.StreamableHttp,
}, HttpClient, LoggerFactory);

await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory,
cancellationToken: TestContext.Current.CancellationToken);

// Register the tool WITHOUT calling ListToolsAsync first — this is the core scenario from issue #1577
client.RegisterTools([CreateToolWithHeaders()]);

// Call the tool
var result = await client.CallToolAsync(
"my_tool",
new Dictionary<string, object?> { ["region"] = "us-west-2", ["priority"] = 42 },
cancellationToken: TestContext.Current.CancellationToken);

Assert.NotNull(result);

// Verify that Mcp-Param-* headers were captured by the server
Assert.Single(_capturedHeaders);
var headers = _capturedHeaders.Values.First();
Assert.True(headers.ContainsKey("Mcp-Param-Region"), "Expected Mcp-Param-Region header to be sent");
Assert.Equal("us-west-2", headers["Mcp-Param-Region"]);
Assert.True(headers.ContainsKey("Mcp-Param-Priority"), "Expected Mcp-Param-Priority header to be sent");
Assert.Equal("42", headers["Mcp-Param-Priority"]);
}

[Fact]
public async Task CallToolWithoutRegisterOrList_DoesNotSendMcpParamHeaders()
{
await StartAsync();

await using var transport = new HttpClientTransport(new()
{
Endpoint = new("http://localhost:5000/mcp"),
TransportMode = HttpTransportMode.StreamableHttp,
}, HttpClient, LoggerFactory);

await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory,
cancellationToken: TestContext.Current.CancellationToken);

// Call the tool without RegisterTools or ListToolsAsync — no Mcp-Param-* headers should be sent
var result = await client.CallToolAsync(
"my_tool",
new Dictionary<string, object?> { ["region"] = "us-west-2", ["priority"] = 42 },
cancellationToken: TestContext.Current.CancellationToken);

Assert.NotNull(result);

// Verify that NO Mcp-Param-* headers were sent
Assert.Single(_capturedHeaders);
var headers = _capturedHeaders.Values.First();
Assert.Empty(headers);
}

[Fact]
public async Task RegisterTools_SurvivesListToolsAsync_HeadersStillSent()
{
await StartAsync();

await using var transport = new HttpClientTransport(new()
{
Endpoint = new("http://localhost:5000/mcp"),
TransportMode = HttpTransportMode.StreamableHttp,
}, HttpClient, LoggerFactory);

await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory,
cancellationToken: TestContext.Current.CancellationToken);

// Register the tool first
client.RegisterTools([CreateToolWithHeaders()]);

// Call ListToolsAsync — server returns empty list, but registered tool should survive
await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);

// Call the registered tool — Mcp-Param-* headers should still be sent
var result = await client.CallToolAsync(
"my_tool",
new Dictionary<string, object?> { ["region"] = "eu-central-1", ["priority"] = 99 },
cancellationToken: TestContext.Current.CancellationToken);

Assert.NotNull(result);

// Verify headers were sent
Assert.Single(_capturedHeaders);
var headers = _capturedHeaders.Values.First();
Assert.True(headers.ContainsKey("Mcp-Param-Region"), "Expected Mcp-Param-Region header after ListToolsAsync");
Assert.Equal("eu-central-1", headers["Mcp-Param-Region"]);
Assert.True(headers.ContainsKey("Mcp-Param-Priority"), "Expected Mcp-Param-Priority header after ListToolsAsync");
Assert.Equal("99", headers["Mcp-Param-Priority"]);
}
}
Loading
Loading