diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs index 4733fce16..e11a4ab9c 100644 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs +++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs @@ -14,32 +14,51 @@ public static partial class McpServerBuilderExtensions { private const string RequiresUnreferencedCodeMessage = "This method requires dynamic lookup of method metadata and might not work in Native AOT."; - /// - /// Adds a tool to the server. - /// + /// Adds instances to the service collection backing . /// The tool type. /// The builder instance. /// is . - public static IMcpServerBuilder WithTools<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods | DynamicallyAccessedMemberTypes.NonPublicMethods)] TTool>( + /// + /// This method discovers all instance and static methods (public and non-public) on the specified + /// type, where the methods are attributed as , and adds an + /// instance for each. For instance methods, an instance will be constructed for each invocation of the tool. + /// + public static IMcpServerBuilder WithTools<[DynamicallyAccessedMembers( + DynamicallyAccessedMemberTypes.PublicMethods | + DynamicallyAccessedMemberTypes.NonPublicMethods | + DynamicallyAccessedMemberTypes.PublicConstructors)] TTool>( this IMcpServerBuilder builder) { Throw.IfNull(builder); - foreach (var toolMethod in GetToolMethods(typeof(TTool))) + foreach (var toolMethod in typeof(TTool).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) { - builder.Services.AddSingleton(services => McpServerTool.Create(toolMethod, services: services)); + if (toolMethod.GetCustomAttribute() is not null) + { + if (toolMethod.IsStatic) + { + builder.Services.AddSingleton(services => McpServerTool.Create(toolMethod, services: services)); + } + else + { + builder.Services.AddSingleton(services => McpServerTool.Create(toolMethod, typeof(TTool), services: services)); + } + } } return builder; } - /// - /// Adds tools to the server. - /// + /// Adds instances to the service collection backing . /// The builder instance. /// Types with marked methods to add as tools to the server. /// is . /// is . + /// + /// This method discovers all instance and static methods (public and non-public) on the specified + /// types, where the methods are attributed as , and adds an + /// instance for each. For instance methods, an instance will be constructed for each invocation of the tool. + /// [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)] public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params IEnumerable toolTypes) { @@ -50,13 +69,23 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params { if (toolType is not null) { - foreach (var toolMethod in GetToolMethods(toolType)) + foreach (var method in toolType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) { - builder.Services.AddSingleton(services => McpServerTool.Create(toolMethod, services: services)); + if (method.GetCustomAttribute() is not null) + { + if (method.IsStatic) + { + builder.Services.AddSingleton(services => McpServerTool.Create(method, services: services)); + } + else + { + builder.Services.AddSingleton(services => McpServerTool.Create(method, toolType, services: services)); + } + } } } } - + return builder; } @@ -78,10 +107,4 @@ from t in toolAssembly.GetTypes() where t.GetCustomAttribute() is not null select t); } - - private static IEnumerable GetToolMethods( - [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicMethods | DynamicallyAccessedMemberTypes.NonPublicMethods)] Type toolType) => - from method in toolType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static) - where method.GetCustomAttribute() is not null - select method; } diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index d3fbd93c5..ff3f92887 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -3,6 +3,7 @@ using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; +using System.Diagnostics.CodeAnalysis; using System.Reflection; using System.Text.Json; @@ -13,7 +14,7 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool { /// Key used temporarily for flowing request context into an AIFunction. /// This will be replaced with use of AIFunctionArguments.Context. - private const string RequestContextKey = "__temporary_RequestContext"; + internal const string RequestContextKey = "__temporary_RequestContext"; /// /// Creates an instance for a method, specified via a instance. @@ -48,7 +49,27 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool // AIFunctionFactory, delete the TemporaryXx types, and fix-up the mechanism by // which the arguments are passed. - return Create(TemporaryAIFunctionFactory.Create(method, target, new TemporaryAIFunctionFactoryOptions() + return Create(TemporaryAIFunctionFactory.Create(method, target, CreateAIFunctionFactoryOptions(method, name, description, services))); + } + + /// + /// Creates an instance for a method, specified via a instance. + /// + public static new AIFunctionMcpServerTool Create( + MethodInfo method, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, + string? name = null, + string? description = null, + IServiceProvider? services = null) + { + Throw.IfNull(method); + + return Create(TemporaryAIFunctionFactory.Create(method, targetType, CreateAIFunctionFactoryOptions(method, name, description, services))); + } + + private static TemporaryAIFunctionFactoryOptions CreateAIFunctionFactoryOptions( + MethodInfo method, string? name, string? description, IServiceProvider? services) => + new TemporaryAIFunctionFactoryOptions() { Name = name ?? method.GetCustomAttribute()?.Name, Description = description, @@ -115,8 +136,7 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool return null; } }, - })); - } + }; /// Creates an that wraps the specified . public static new AIFunctionMcpServerTool Create(AIFunction function) diff --git a/src/ModelContextProtocol/Server/McpServerTool.cs b/src/ModelContextProtocol/Server/McpServerTool.cs index f6122764c..c262df75a 100644 --- a/src/ModelContextProtocol/Server/McpServerTool.cs +++ b/src/ModelContextProtocol/Server/McpServerTool.cs @@ -1,6 +1,7 @@ using Microsoft.Extensions.AI; using ModelContextProtocol.Protocol.Types; using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; using System.Reflection; namespace ModelContextProtocol.Server; @@ -40,7 +41,7 @@ public abstract Task InvokeAsync( /// /// /// Optional services used in the construction of the . These services will be - /// used to determine which parameters should be satisifed from dependency injection, and so what services + /// used to determine which parameters should be satisifed from dependency injection; what services /// are satisfied via this provider should match what's satisfied via the provider passed in at invocation time. /// /// The created for invoking . @@ -68,7 +69,7 @@ public static McpServerTool Create( /// /// /// Optional services used in the construction of the . These services will be - /// used to determine which parameters should be satisifed from dependency injection, and so what services + /// used to determine which parameters should be satisifed from dependency injection; what services /// are satisfied via this provider should match what's satisfied via the provider passed in at invocation time. /// /// The created for invoking . @@ -82,6 +83,43 @@ public static McpServerTool Create( IServiceProvider? services = null) => AIFunctionMcpServerTool.Create(method, target, name, description, services); + /// + /// Creates an instance for a method, specified via an for + /// and instance method, along with a representing the type of the target object to + /// instantiate each time the method is invoked. + /// + /// The instance method to be represented via the created . + /// + /// The to construct an instance of on which to invoke when + /// the resulting is invoked. If services are provided, + /// ActivatorUtilities.CreateInstance will be used to construct the instance using those services; otherwise, + /// is used, utilizing the type's public parameterless constructor. + /// If an instance can't be constructed, an exception is thrown during the function's invocation. + /// + /// + /// The name to use for the . If , but an + /// is applied to , the name from the attribute will be used. If that's not present, the name based + /// on 's name will be used. + /// + /// + /// The description to use for the . If , but a + /// is applied to , the description from that attribute will be used. + /// + /// + /// Optional services used in the construction of the . These services will be + /// used to determine which parameters should be satisifed from dependency injection; what services + /// are satisfied via this provider should match what's satisfied via the provider passed in at invocation time. + /// + /// The created for invoking . + /// is . + public static McpServerTool Create( + MethodInfo method, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, + string? name = null, + string? description = null, + IServiceProvider? services = null) => + AIFunctionMcpServerTool.Create(method, targetType, name, description, services); + /// Creates an that wraps the specified . /// The function to wrap. /// is . diff --git a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs index bf0ae8ae9..67e7a99b1 100644 --- a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs +++ b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs @@ -1,10 +1,15 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; using ModelContextProtocol.Utils; using System.Collections.Concurrent; using System.ComponentModel; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; + #if !NET using System.Linq; #endif @@ -110,6 +115,42 @@ public static AIFunction Create(MethodInfo method, object? target, TemporaryAIFu return ReflectionAIFunction.Build(method, target, options ?? _defaultOptions); } + /// + /// Creates an instance for a method, specified via an instance + /// and an optional target object if the method is an instance method. + /// + /// The instance method to be represented via the created . + /// + /// The to construct an instance of on which to invoke when + /// the resulting is invoked. If services are provided, + /// ActivatorUtilities.CreateInstance will be used to construct the instance using those services; otherwise, + /// is used, utilizing the type's public parameterless constructor. + /// If an instance can't be constructed, an exception is thrown during the function's invocation. + /// + /// Metadata to use to override defaults inferred from . + /// The created for invoking . + /// + /// + /// Return values are serialized to using 's + /// . Arguments that are not already of the expected type are + /// marshaled to the expected type via JSON and using 's + /// . If the argument is a , + /// , or , it is deserialized directly. If the argument is anything else unknown, + /// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. + /// + /// + /// is . + public static AIFunction Create( + MethodInfo method, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, + TemporaryAIFunctionFactoryOptions? options = null) + { + Throw.IfNull(method); + Throw.IfNull(targetType); + + return ReflectionAIFunction.Build(method, targetType, options ?? _defaultOptions); + } + /// /// Creates an instance for a method, specified via an instance /// and an optional target object if the method is an instance method. @@ -176,6 +217,32 @@ public static ReflectionAIFunction Build(MethodInfo method, object? target, Temp return new(functionDescriptor, target, options); } + public static ReflectionAIFunction Build( + MethodInfo method, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, + TemporaryAIFunctionFactoryOptions options) + { + Throw.IfNull(method); + + if (method.ContainsGenericParameters) + { + throw new ArgumentException("Open generic methods are not supported", nameof(method)); + } + + if (method.IsStatic) + { + throw new ArgumentException("The method must be an instance method.", nameof(method)); + } + + if (method.DeclaringType is { } declaringType && + !declaringType.IsAssignableFrom(targetType)) + { + throw new ArgumentException("The target type must be assignable to the method's declaring type.", nameof(targetType)); + } + + return new(ReflectionAIFunctionDescriptor.GetOrCreate(method, options), targetType, options); + } + private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, object? target, TemporaryAIFunctionFactoryOptions options) { FunctionDescriptor = functionDescriptor; @@ -183,8 +250,20 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, AdditionalProperties = options.AdditionalProperties ?? new Dictionary(); } + private ReflectionAIFunction( + ReflectionAIFunctionDescriptor functionDescriptor, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, + TemporaryAIFunctionFactoryOptions options) + { + FunctionDescriptor = functionDescriptor; + TargetType = targetType; + AdditionalProperties = options.AdditionalProperties ?? new Dictionary(); + } + public ReflectionAIFunctionDescriptor FunctionDescriptor { get; } public object? Target { get; } + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] + public Type? TargetType { get; } public override IReadOnlyDictionary AdditionalProperties { get; } public override string Name => FunctionDescriptor.Name; public override string Description => FunctionDescriptor.Description; @@ -192,22 +271,59 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema; public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions; - protected override Task InvokeCoreAsync( + protected override async Task InvokeCoreAsync( IEnumerable> arguments, CancellationToken cancellationToken) { - var paramMarshallers = FunctionDescriptor.ParameterMarshallers; - object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : []; - Dictionary argumentsDictionary = arguments.ToDictionary(); - for (int i = 0; i < args.Length; i++) + bool disposeTarget = false; + object? target = Target; + try { - args[i] = paramMarshallers[i](argumentsDictionary, cancellationToken); - } + if (TargetType is { } targetType) + { + Debug.Assert(target is null, "Expected target to be null when we have a non-null target type"); + Debug.Assert(!FunctionDescriptor.Method.IsStatic, "Expected an instance method"); + + if (argumentsDictionary.TryGetValue(AIFunctionMcpServerTool.RequestContextKey, out object? value) && + value is RequestContext requestContext && + requestContext.Server?.Services is { } services) + { + target = ActivatorUtilities.CreateInstance(services, targetType!); + } + else + { + target = Activator.CreateInstance(targetType); + } + + disposeTarget = true; + } + var paramMarshallers = FunctionDescriptor.ParameterMarshallers; + object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : []; - return FunctionDescriptor.ReturnParameterMarshaller( - ReflectionInvoke(FunctionDescriptor.Method, Target, args), cancellationToken); + for (int i = 0; i < args.Length; i++) + { + args[i] = paramMarshallers[i](argumentsDictionary, cancellationToken); + } + + return await FunctionDescriptor.ReturnParameterMarshaller( + ReflectionInvoke(FunctionDescriptor.Method, target, args), cancellationToken).ConfigureAwait(false); + } + finally + { + if (disposeTarget) + { + if (target is IAsyncDisposable ad) + { + await ad.DisposeAsync().ConfigureAwait(false); + } + else if (target is IDisposable d) + { + d.Dispose(); + } + } + } } } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 3ecf9a69f..dbf135363 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -25,6 +25,7 @@ public McpServerBuilderExtensionsToolsTests() { ServiceCollection sc = new(); sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream())); + sc.AddSingleton(new ObjectWithId()); _builder = sc.AddMcpServer().WithTools(); _server = sc.BuildServiceProvider().GetRequiredService(); } @@ -70,7 +71,7 @@ public async Task Can_List_Registered_Tools() IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(10, tools.Count); + Assert.Equal(11, tools.Count); McpClientTool echoTool = tools.First(t => t.Name == "Echo"); Assert.Equal("Echo", echoTool.Name); @@ -91,7 +92,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(10, tools.Count); + Assert.Equal(11, tools.Count); Channel listChanged = Channel.CreateUnbounded(); client.AddNotificationHandler("notifications/tools/list_changed", notification => @@ -111,7 +112,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() await notificationRead; tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(11, tools.Count); + Assert.Equal(12, tools.Count); Assert.Contains(tools, t => t.Name == "NewTool"); notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); @@ -120,7 +121,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() await notificationRead; tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(10, tools.Count); + Assert.Equal(11, tools.Count); Assert.DoesNotContain(tools, t => t.Name == "NewTool"); } @@ -224,6 +225,35 @@ public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() Assert.Equal("text", result.Content[0].Type); } + [Fact] + public async Task Can_Call_Registered_Tool_With_Instance_Method() + { + IMcpClient client = await CreateMcpClientForServer(); + + string[][] parts = new string[2][]; + for (int i = 0; i < 2; i++) + { + var result = await client.CallToolAsync( + nameof(EchoTool.GetCtorParameter), + cancellationToken: TestContext.Current.CancellationToken); + + Assert.NotNull(result); + Assert.NotNull(result.Content); + Assert.NotEmpty(result.Content); + + parts[i] = result.Content[0].Text?.Split(':') ?? []; + Assert.Equal(2, parts[i].Length); + } + + string random1 = parts[0][0]; + string random2 = parts[1][0]; + Assert.NotEqual(random1, random2); + + string id1 = parts[0][1]; + string id2 = parts[1][1]; + Assert.Equal(id1, id2); + } + [Fact] public async Task Returns_IsError_Content_When_Tool_Fails() { @@ -334,8 +364,10 @@ public void Register_Tools_From_Multiple_Sources() } [McpServerToolType] - public sealed class EchoTool + public sealed class EchoTool(ObjectWithId objectFromDI) { + private string _randomValue = Guid.NewGuid().ToString("N"); + [McpServerTool, Description("Echoes the input back to the client.")] public static string Echo([Description("the echoes message")] string message) { @@ -395,6 +427,9 @@ public static string EchoComplex(ComplexObject complex) { return complex.Name!; } + + [McpServerTool] + public string GetCtorParameter() => $"{_randomValue}:{objectFromDI.Id}"; } [McpServerToolType] @@ -421,4 +456,9 @@ public class ComplexObject public string? Name { get; set; } public int Age { get; set; } } + + public class ObjectWithId + { + public string Id { get; set; } = Guid.NewGuid().ToString("N"); + } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index 49a823195..3f066dd5c 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -4,6 +4,7 @@ using Moq; using System.Reflection; using System.Text.Json; +using System.Text.Json.Serialization; namespace ModelContextProtocol.Tests.Server; @@ -89,5 +90,120 @@ public async Task SupportsOptionalServiceFromDI() Assert.Equal("42", result.Content[0].Text); } + [Fact] + public async Task SupportsDisposingInstantiatedDisposableTargets() + { + McpServerTool tool1 = McpServerTool.Create( + typeof(DisposableToolType).GetMethod(nameof(DisposableToolType.InstanceMethod))!, + typeof(DisposableToolType)); + + var result = await tool1.InvokeAsync( + new RequestContext(null!, null), + TestContext.Current.CancellationToken); + Assert.Equal("""{"disposals":1}""", result.Content[0].Text); + } + + [Fact] + public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() + { + McpServerTool tool1 = McpServerTool.Create( + typeof(AsyncDisposableToolType).GetMethod(nameof(AsyncDisposableToolType.InstanceMethod))!, + typeof(AsyncDisposableToolType)); + + var result = await tool1.InvokeAsync( + new RequestContext(null!, null), + TestContext.Current.CancellationToken); + Assert.Equal("""{"asyncDisposals":1}""", result.Content[0].Text); + } + + [Fact] + public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposableTargets() + { + McpServerTool tool1 = McpServerTool.Create( + typeof(AsyncDisposableAndDisposableToolType).GetMethod(nameof(AsyncDisposableAndDisposableToolType.InstanceMethod))!, + typeof(AsyncDisposableAndDisposableToolType)); + + var result = await tool1.InvokeAsync( + new RequestContext(null!, null), + TestContext.Current.CancellationToken); + Assert.Equal("""{"asyncDisposals":1,"disposals":0}""", result.Content[0].Text); + } + private sealed class MyService; + + private class DisposableToolType : IDisposable + { + public int Disposals { get; private set; } + + public void Dispose() + { + Disposals++; + } + + public object InstanceMethod() + { + if (Disposals != 0) + { + throw new InvalidOperationException("Dispose was called"); + } + + return this; + } + } + + private class AsyncDisposableToolType : IAsyncDisposable + { + public int AsyncDisposals { get; private set; } + + public ValueTask DisposeAsync() + { + AsyncDisposals++; + return default; + } + + public object InstanceMethod() + { + if (AsyncDisposals != 0) + { + throw new InvalidOperationException("DisposeAsync was called"); + } + + return this; + } + } + + private class AsyncDisposableAndDisposableToolType : IAsyncDisposable, IDisposable + { + [JsonPropertyOrder(0)] + public int AsyncDisposals { get; private set; } + + [JsonPropertyOrder(1)] + public int Disposals { get; private set; } + + public void Dispose() + { + Disposals++; + } + + public ValueTask DisposeAsync() + { + AsyncDisposals++; + return default; + } + + public object InstanceMethod() + { + if (Disposals != 0) + { + throw new InvalidOperationException("Dispose was called"); + } + + if (AsyncDisposals != 0) + { + throw new InvalidOperationException("DisposeAsync was called"); + } + + return this; + } + } }