diff --git a/README.MD b/README.MD index 61b3335c5..1997a962d 100644 --- a/README.MD +++ b/README.MD @@ -80,7 +80,7 @@ var response = await chatClient.GetResponseAsync( Here is an example of how to create an MCP server and register all tools from the current application. It includes a simple echo tool as an example (this is included in the same file here for easy of copy and paste, but it needn't be in the same file... -the employed overload of `WithTools` examines the current assembly for classes with the `McpToolType` attribute, and registers all methods with the +the employed overload of `WithTools` examines the current assembly for classes with the `McpServerToolType` attribute, and registers all methods with the `McpTool` attribute as tools.) ```csharp diff --git a/src/ModelContextProtocol/README.md b/src/ModelContextProtocol/README.md index a109addce..90a39907c 100644 --- a/src/ModelContextProtocol/README.md +++ b/src/ModelContextProtocol/README.md @@ -85,7 +85,7 @@ var response = await chatClient.GetResponseAsync( Here is an example of how to create an MCP server and register all tools from the current application. It includes a simple echo tool as an example (this is included in the same file here for easy of copy and paste, but it needn't be in the same file... -the employed overload of `WithTools` examines the current assembly for classes with the `McpToolType` attribute, and registers all methods with the +the employed overload of `WithTools` examines the current assembly for classes with the `McpServerToolType` attribute, and registers all methods with the `McpTool` attribute as tools.) ```csharp @@ -101,7 +101,7 @@ builder.Services .WithTools(); await builder.Build().RunAsync(); -[McpToolType] +[McpServerToolType] public static class EchoTool { [McpTool, Description("Echoes the message back to the client.")] diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 4c0c6b564..7f425b1fb 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -15,6 +15,8 @@ internal sealed class McpServer : McpJsonRpcEndpoint, IMcpServer { private readonly IServerTransport? _serverTransport; private readonly string _serverDescription; + private readonly EventHandler? _toolsChangedDelegate; + private volatile bool _isInitializing; /// @@ -32,17 +34,29 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? Throw.IfNull(options); _serverTransport = transport as IServerTransport; - ServerInstructions = options.ServerInstructions; + ServerOptions = options; Services = serviceProvider; _serverDescription = $"{options.ServerInfo.Name} {options.ServerInfo.Version}"; + _toolsChangedDelegate = delegate + { + _ = SendMessageAsync(new JsonRpcNotification() + { + Method = NotificationMethods.ToolListChangedNotification, + }); + }; AddNotificationHandler("notifications/initialized", _ => { + if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools) + { + tools.Changed += _toolsChangedDelegate; + } + IsInitialized = true; return Task.CompletedTask; }); - SetToolsHandler(ref options); + SetToolsHandler(options); SetInitializeHandler(options); SetCompletionHandler(options); @@ -50,18 +64,15 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? SetPromptsHandler(options); SetResourcesHandler(options); SetSetLoggingLevelHandler(options); - - ServerOptions = options; } + public ServerCapabilities? ServerCapabilities { get; set; } + public ClientCapabilities? ClientCapabilities { get; set; } /// public Implementation? ClientInfo { get; set; } - /// - public string? ServerInstructions { get; set; } - /// public McpServerOptions ServerOptions { get; } @@ -111,6 +122,15 @@ public async Task StartAsync(CancellationToken cancellationToken = default) } } + protected override Task CleanupAsync() + { + if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools) + { + tools.Changed -= _toolsChangedDelegate; + } + return base.CleanupAsync(); + } + private void SetPingHandler() { SetRequestHandler("ping", @@ -127,9 +147,9 @@ private void SetInitializeHandler(McpServerOptions options) return Task.FromResult(new InitializeResult() { ProtocolVersion = options.ProtocolVersion, - Instructions = ServerInstructions, + Instructions = options.ServerInstructions, ServerInfo = options.ServerInfo, - Capabilities = options.Capabilities ?? new ServerCapabilities(), + Capabilities = ServerCapabilities ?? new(), }); }); } @@ -198,7 +218,7 @@ private void SetPromptsHandler(McpServerOptions options) SetRequestHandler("prompts/get", (request, ct) => getPromptHandler(new(this, request), ct)); } - private void SetToolsHandler(ref McpServerOptions options) + private void SetToolsHandler(McpServerOptions options) { ToolsCapability? toolsCapability = options.Capabilities?.Tools; var listToolsHandler = toolsCapability?.ListToolsHandler; @@ -261,25 +281,25 @@ private void SetToolsHandler(ref McpServerOptions options) return tool.InvokeAsync(request, cancellationToken); }; - toolsCapability ??= new(); - toolsCapability.CallToolHandler = callToolHandler; - toolsCapability.ListToolsHandler = listToolsHandler; - toolsCapability.ToolCollection = tools; - toolsCapability.ListChanged = true; - - options.Capabilities ??= new(); - options.Capabilities.Tools = toolsCapability; - - tools.Changed += delegate + ServerCapabilities = new() { - _ = SendMessageAsync(new JsonRpcNotification() + Experimental = options.Capabilities?.Experimental, + Logging = options.Capabilities?.Logging, + Prompts = options.Capabilities?.Prompts, + Resources = options.Capabilities?.Resources, + Tools = new() { - Method = NotificationMethods.ToolListChangedNotification, - }); + ListToolsHandler = listToolsHandler, + CallToolHandler = callToolHandler, + ToolCollection = tools, + ListChanged = true, + } }; } else { + ServerCapabilities = options.Capabilities; + if (toolsCapability is null) { // No tools, and no tools capability was declared, so nothing to do. diff --git a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs index 175a60be3..90824c586 100644 --- a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs @@ -359,7 +359,7 @@ protected void SetRequestHandler(string method, Func /// - protected async Task CleanupAsync() + protected virtual async Task CleanupAsync() { if (_isDisposed) return; diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index 43d513128..03d0e9933 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -38,8 +38,8 @@ private async Task CreateMcpClientForServer() { await _server.StartAsync(TestContext.Current.CancellationToken); - var stdin = new StreamReader(_serverToClientPipe.Reader.AsStream()); - var stdout = new StreamWriter(_clientToServerPipe.Writer.AsStream()); + var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream()); + var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream()); var serverConfig = new McpServerConfig() { @@ -50,7 +50,7 @@ private async Task CreateMcpClientForServer() return await McpClientFactory.CreateAsync( serverConfig, - createTransportFunc: (_, _) => new StreamClientTransport(stdin, stdout), + createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader), cancellationToken: TestContext.Current.CancellationToken); } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index dbf135363..3ae313010 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -11,38 +11,45 @@ using Microsoft.Extensions.AI; using System.Threading.Channels; using ModelContextProtocol.Protocol.Messages; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Tests.Utils; +using Microsoft.Extensions.Logging; namespace ModelContextProtocol.Tests.Configuration; -public class McpServerBuilderExtensionsToolsTests : IAsyncDisposable +public class McpServerBuilderExtensionsToolsTests : LoggedTest, IAsyncDisposable { - private Pipe _clientToServerPipe = new(); - private Pipe _serverToClientPipe = new(); + private readonly Pipe _clientToServerPipe = new(); + private readonly Pipe _serverToClientPipe = new(); + private readonly ServiceProvider _serviceProvider; private readonly IMcpServerBuilder _builder; private readonly IMcpServer _server; - public McpServerBuilderExtensionsToolsTests() + public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) { ServiceCollection sc = new(); + sc.AddSingleton(LoggerFactory); sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream())); sc.AddSingleton(new ObjectWithId()); _builder = sc.AddMcpServer().WithTools(); - _server = sc.BuildServiceProvider().GetRequiredService(); + _serviceProvider = sc.BuildServiceProvider(); + _server = _serviceProvider.GetRequiredService(); } public ValueTask DisposeAsync() { _clientToServerPipe.Writer.Complete(); _serverToClientPipe.Writer.Complete(); - return _server.DisposeAsync(); + return _serviceProvider.DisposeAsync(); } private async Task CreateMcpClientForServer() { await _server.StartAsync(TestContext.Current.CancellationToken); - var stdin = new StreamReader(_serverToClientPipe.Reader.AsStream()); - var stdout = new StreamWriter(_clientToServerPipe.Writer.AsStream()); + var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream()); + var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream()); var serverConfig = new McpServerConfig() { @@ -53,7 +60,7 @@ private async Task CreateMcpClientForServer() return await McpClientFactory.CreateAsync( serverConfig, - createTransportFunc: (_, _) => new StreamClientTransport(stdin, stdout), + createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader), cancellationToken: TestContext.Current.CancellationToken); } @@ -86,6 +93,63 @@ public async Task Can_List_Registered_Tools() Assert.Equal("Echoes the input back to the client.", doubleEchoTool.Description); } + + [Fact] + public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_Tools() + { + var options = _serviceProvider.GetRequiredService>().Value; + var loggerFactory = _serviceProvider.GetRequiredService(); + + for (int i = 0; i < 2; i++) + { + var stdinPipe = new Pipe(); + var stdoutPipe = new Pipe(); + + try + { + var transport = new StdioServerTransport($"TestServer_{i}", stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream()); + var server = McpServerFactory.Create(transport, options, loggerFactory, _serviceProvider); + + await server.StartAsync(TestContext.Current.CancellationToken); + + var serverStdinWriter = new StreamWriter(stdinPipe.Writer.AsStream()); + var serverStdoutReader = new StreamReader(stdoutPipe.Reader.AsStream()); + + var serverConfig = new McpServerConfig() + { + Id = $"TestServer_{i}", + Name = $"TestServer_{i}", + TransportType = "ignored", + }; + + var client = await McpClientFactory.CreateAsync( + serverConfig, + createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader), + cancellationToken: TestContext.Current.CancellationToken); + + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + Assert.Equal(11, tools.Count); + + McpClientTool echoTool = tools.First(t => t.Name == "Echo"); + Assert.Equal("Echo", echoTool.Name); + Assert.Equal("Echoes the input back to the client.", echoTool.Description); + Assert.Equal("object", echoTool.JsonSchema.GetProperty("type").GetString()); + Assert.Equal(JsonValueKind.Object, echoTool.JsonSchema.GetProperty("properties").GetProperty("message").ValueKind); + Assert.Equal("the echoes message", echoTool.JsonSchema.GetProperty("properties").GetProperty("message").GetProperty("description").GetString()); + Assert.Equal(1, echoTool.JsonSchema.GetProperty("required").GetArrayLength()); + + McpClientTool doubleEchoTool = tools.First(t => t.Name == "double_echo"); + Assert.Equal("double_echo", doubleEchoTool.Name); + Assert.Equal("Echoes the input back to the client.", doubleEchoTool.Description); + } + finally + { + stdinPipe.Writer.Complete(); + stdoutPipe.Writer.Complete(); + } + } + } + [Fact] public async Task Can_Be_Notified_Of_Tool_Changes() { diff --git a/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs b/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs index a8014e3ab..49c7ca7ad 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs @@ -9,16 +9,16 @@ namespace ModelContextProtocol.Tests.Transport; internal sealed class StreamClientTransport : TransportBase, IClientTransport { private readonly JsonSerializerOptions _jsonOptions = McpJsonUtilities.DefaultOptions; - private Task? _readTask; - private CancellationTokenSource _shutdownCts = new CancellationTokenSource(); - private readonly TextReader _stdin; - private readonly TextWriter _stdout; + private readonly Task? _readTask; + private readonly CancellationTokenSource _shutdownCts = new CancellationTokenSource(); + private readonly TextReader _serverStdoutReader; + private readonly TextWriter _serverStdinWriter; - public StreamClientTransport(TextReader stdin, TextWriter stdout) + public StreamClientTransport(TextWriter serverStdinWriter, TextReader serverStdoutReader) : base(NullLoggerFactory.Instance) { - _stdin = stdin; - _stdout = stdout; + _serverStdoutReader = serverStdoutReader; + _serverStdinWriter = serverStdinWriter; _readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None); SetConnected(true); } @@ -31,13 +31,13 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio messageWithId.Id.ToString() : "(no id)"; - await _stdout.WriteLineAsync(JsonSerializer.Serialize(message)).ConfigureAwait(false); - await _stdout.FlushAsync(cancellationToken).ConfigureAwait(false); + await _serverStdinWriter.WriteLineAsync(JsonSerializer.Serialize(message)).ConfigureAwait(false); + await _serverStdinWriter.FlushAsync(cancellationToken).ConfigureAwait(false); } private async Task ReadMessagesAsync(CancellationToken cancellationToken) { - while (await _stdin.ReadLineAsync(cancellationToken).ConfigureAwait(false) is string line) + while (await _serverStdoutReader.ReadLineAsync(cancellationToken).ConfigureAwait(false) is string line) { if (!string.IsNullOrWhiteSpace(line)) {