diff --git a/samples/EverythingServer/Program.cs b/samples/EverythingServer/Program.cs index 5041dc875..d46e40b57 100644 --- a/samples/EverythingServer/Program.cs +++ b/samples/EverythingServer/Program.cs @@ -46,7 +46,7 @@ { subscriptions.Add(uri); - await ctx.Server.RequestSamplingAsync([ + await ctx.Server.SampleAsync([ new ChatMessage(ChatRole.System, "You are a helpful test server"), new ChatMessage(ChatRole.User, $"Resource {uri}, context: A new subscription was started"), ], diff --git a/samples/EverythingServer/Tools/SampleLlmTool.cs b/samples/EverythingServer/Tools/SampleLlmTool.cs index 43e7e7b3f..1b49cd294 100644 --- a/samples/EverythingServer/Tools/SampleLlmTool.cs +++ b/samples/EverythingServer/Tools/SampleLlmTool.cs @@ -15,7 +15,7 @@ public static async Task SampleLLM( CancellationToken cancellationToken) { var samplingParams = CreateRequestSamplingParams(prompt ?? string.Empty, "sampleLLM", maxTokens); - var sampleResult = await server.RequestSamplingAsync(samplingParams, cancellationToken); + var sampleResult = await server.SampleAsync(samplingParams, cancellationToken); return $"LLM sampling result: {sampleResult.Content.Text}"; } diff --git a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs index fc405d5a8..964f7b31a 100644 --- a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs +++ b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs @@ -18,7 +18,7 @@ public static async Task SampleLLM( CancellationToken cancellationToken) { var samplingParams = CreateRequestSamplingParams(prompt ?? string.Empty, "sampleLLM", maxTokens); - var sampleResult = await thisServer.RequestSamplingAsync(samplingParams, cancellationToken); + var sampleResult = await thisServer.SampleAsync(samplingParams, cancellationToken); return $"LLM sampling result: {sampleResult.Content.Text}"; } diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index 957eb7372..8dad491ec 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -76,6 +76,20 @@ public McpClient(IClientTransport clientTransport, McpClientOptions? options, IL McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, McpJsonUtilities.JsonContext.Default.ListRootsResult); } + + if (capabilities.Elicitation is { } elicitationCapability) + { + if (elicitationCapability.ElicitationHandler is not { } elicitationHandler) + { + throw new InvalidOperationException("Elicitation capability was set but it did not provide a handler."); + } + + RequestHandlers.Set( + RequestMethods.ElicitationCreate, + (request, _, cancellationToken) => elicitationHandler(request, cancellationToken), + McpJsonUtilities.JsonContext.Default.ElicitRequestParams, + McpJsonUtilities.JsonContext.Default.ElicitResult); + } } } diff --git a/src/ModelContextProtocol/McpJsonUtilities.cs b/src/ModelContextProtocol/McpJsonUtilities.cs index 162bc2343..fda08f76b 100644 --- a/src/ModelContextProtocol/McpJsonUtilities.cs +++ b/src/ModelContextProtocol/McpJsonUtilities.cs @@ -96,6 +96,8 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(CompleteResult))] [JsonSerializable(typeof(CreateMessageRequestParams))] [JsonSerializable(typeof(CreateMessageResult))] + [JsonSerializable(typeof(ElicitRequestParams))] + [JsonSerializable(typeof(ElicitResult))] [JsonSerializable(typeof(EmptyResult))] [JsonSerializable(typeof(GetPromptRequestParams))] [JsonSerializable(typeof(GetPromptResult))] diff --git a/src/ModelContextProtocol/Protocol/ClientCapabilities.cs b/src/ModelContextProtocol/Protocol/ClientCapabilities.cs index 050d1336f..88680e2a6 100644 --- a/src/ModelContextProtocol/Protocol/ClientCapabilities.cs +++ b/src/ModelContextProtocol/Protocol/ClientCapabilities.cs @@ -58,6 +58,13 @@ public class ClientCapabilities [JsonPropertyName("sampling")] public SamplingCapability? Sampling { get; set; } + /// + /// Gets or sets the client's elicitation capability, which indicates whether the client + /// supports elicitation of additional information from the user on behalf of the server. + /// + [JsonPropertyName("elicitation")] + public ElicitationCapability? Elicitation { get; set; } + /// Gets or sets notification handlers to register with the client. /// /// diff --git a/src/ModelContextProtocol/Protocol/ElicitRequestParams.cs b/src/ModelContextProtocol/Protocol/ElicitRequestParams.cs new file mode 100644 index 000000000..2cee4d1fa --- /dev/null +++ b/src/ModelContextProtocol/Protocol/ElicitRequestParams.cs @@ -0,0 +1,230 @@ +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Represents a message issued from the server to elicit additional information from the user via the client. +/// +public class ElicitRequestParams +{ + /// + /// Gets or sets the message to present to the user. + /// + [JsonPropertyName("message")] + public string Message { get; set; } = string.Empty; + + /// + /// Gets or sets the requested schema. + /// + /// + /// May be one of , , , or . + /// + [JsonPropertyName("requestedSchema")] + [field: MaybeNull] + public RequestSchema RequestedSchema + { + get => field ??= new RequestSchema(); + set => field = value; + } + + /// Represents a request schema used in an elicitation request. + public class RequestSchema + { + /// Gets the type of the schema. + /// This is always "object". + [JsonPropertyName("type")] + public string Type => "object"; + + /// Gets or sets the properties of the schema. + [JsonPropertyName("properties")] + [field: MaybeNull] + public IDictionary Properties + { + get => field ??= new Dictionary(); + set + { + Throw.IfNull(value); + field = value; + } + } + + /// Gets or sets the required properties of the schema. + [JsonPropertyName("required")] + public IList? Required { get; set; } + } + + + /// + /// Represents restricted subset of JSON Schema: + /// , , , or . + /// + [JsonDerivedType(typeof(BooleanSchema))] + [JsonDerivedType(typeof(EnumSchema))] + [JsonDerivedType(typeof(NumberSchema))] + [JsonDerivedType(typeof(StringSchema))] + public abstract class PrimitiveSchemaDefinition + { + protected private PrimitiveSchemaDefinition() + { + } + } + + /// Represents a schema for a string type. + public sealed class StringSchema : PrimitiveSchemaDefinition + { + /// Gets the type of the schema. + /// This is always "string". + [JsonPropertyName("type")] + public string Type => "string"; + + /// Gets or sets a title for the string. + [JsonPropertyName("title")] + public string? Title { get; set; } + + /// Gets or sets a description for the string. + [JsonPropertyName("description")] + public string? Description { get; set; } + + /// Gets or sets the minimum length for the string. + [JsonPropertyName("minLength")] + public int? MinLength + { + get => field; + set + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "Minimum length cannot be negative."); + } + + field = value; + } + } + + /// Gets or sets the maximum length for the string. + [JsonPropertyName("maxLength")] + public int? MaxLength + { + get => field; + set + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "Maximum length cannot be negative."); + } + + field = value; + } + } + + /// Gets or sets a specific format for the string ("email", "uri", "date", or "date-time"). + [JsonPropertyName("format")] + public string? Format + { + get => field; + set + { + if (value is not (null or "email" or "uri" or "date" or "date-time")) + { + throw new ArgumentException("Format must be 'email', 'uri', 'date', or 'date-time'.", nameof(value)); + } + + field = value; + } + } + } + + /// Represents a schema for a number or integer type. + public sealed class NumberSchema : PrimitiveSchemaDefinition + { + /// Gets the type of the schema. + /// This should be "number" or "integer". + [JsonPropertyName("type")] + [field: MaybeNull] + public string Type + { + get => field ??= "number"; + set + { + if (value is not ("number" or "integer")) + { + throw new ArgumentException("Type must be 'number' or 'integer'.", nameof(value)); + } + + field = value; + } + } + + /// Gets or sets a title for the number input. + [JsonPropertyName("title")] + public string? Title { get; set; } + + /// Gets or sets a description for the number input. + [JsonPropertyName("description")] + public string? Description { get; set; } + + /// Gets or sets the minimum allowed value. + [JsonPropertyName("minimum")] + public double? Minimum { get; set; } + + /// Gets or sets the maximum allowed value. + [JsonPropertyName("maximum")] + public double? Maximum { get; set; } + } + + /// Represents a schema for a Boolean type. + public sealed class BooleanSchema : PrimitiveSchemaDefinition + { + /// Gets the type of the schema. + /// This is always "boolean". + [JsonPropertyName("type")] + public string Type => "boolean"; + + /// Gets or sets a title for the Boolean. + [JsonPropertyName("title")] + public string? Title { get; set; } + + /// Gets or sets a description for the Boolean. + [JsonPropertyName("description")] + public string? Description { get; set; } + + /// Gets or sets the default value for the Boolean. + [JsonPropertyName("default")] + public bool? Default { get; set; } + } + + /// Represents a schema for an enum type. + public sealed class EnumSchema : PrimitiveSchemaDefinition + { + /// Gets the type of the schema. + /// This is always "string". + [JsonPropertyName("type")] + public string Type => "string"; + + /// Gets or sets a title for the enum. + [JsonPropertyName("title")] + public string? Title { get; set; } + + /// Gets or sets a description for the enum. + [JsonPropertyName("description")] + public string? Description { get; set; } + + /// Gets or sets the list of allowed string values for the enum. + [JsonPropertyName("enum")] + [field: MaybeNull] + public IList Enum + { + get => field ??= []; + set + { + Throw.IfNull(value); + field = value; + } + } + + /// Gets or sets optional display names corresponding to the enum values. + [JsonPropertyName("enumNames")] + public IList? EnumNames { get; set; } + } +} diff --git a/src/ModelContextProtocol/Protocol/ElicitResult.cs b/src/ModelContextProtocol/Protocol/ElicitResult.cs new file mode 100644 index 000000000..e45e887f3 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/ElicitResult.cs @@ -0,0 +1,41 @@ +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Represents the client's response to an elicitation request. +/// +public class ElicitResult +{ + /// + /// Gets or sets the user action in response to the elicitation. + /// + /// + /// + /// + /// "accept" + /// User submitted the form/confirmed the action + /// + /// + /// "decline" + /// User explicitly declined the action + /// + /// + /// "cancel" + /// User dismissed without making an explicit choice + /// + /// + /// + [JsonPropertyName("action")] + public string Action { get; set; } = "cancel"; + + /// + /// Gets or sets the submitted form data. + /// + /// + /// This is typically omitted if the action is "cancel" or "decline". + /// + [JsonPropertyName("content")] + public JsonElement? Content { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/ElicitationCapability.cs b/src/ModelContextProtocol/Protocol/ElicitationCapability.cs new file mode 100644 index 000000000..7b918affb --- /dev/null +++ b/src/ModelContextProtocol/Protocol/ElicitationCapability.cs @@ -0,0 +1,36 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Represents the capability for a client to provide server-requested additional information during interactions. +/// +/// +/// +/// This capability enables the MCP client to respond to elicitation requests from an MCP server. +/// +/// +/// When this capability is enabled, an MCP server can request the client to provide additional information +/// during interactions. The client must set a to process these requests. +/// +/// +public class ElicitationCapability +{ + // Currently empty in the spec, but may be extended in the future. + + /// + /// Gets or sets the handler for processing requests. + /// + /// + /// + /// This handler function is called when an MCP server requests the client to provide additional + /// information during interactions. The client must set this property for the elicitation capability to work. + /// + /// + /// The handler receives message parameters and a cancellation token. + /// It should return a containing the response to the elicitation request. + /// + /// + [JsonIgnore] + public Func>? ElicitationHandler { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/RequestMethods.cs b/src/ModelContextProtocol/Protocol/RequestMethods.cs index 3b511ef90..2081a6049 100644 --- a/src/ModelContextProtocol/Protocol/RequestMethods.cs +++ b/src/ModelContextProtocol/Protocol/RequestMethods.cs @@ -93,6 +93,15 @@ public static class RequestMethods /// public const string SamplingCreateMessage = "sampling/createMessage"; + /// + /// The name of the request method sent from the client to the server to elicit additional information from the user via the client. + /// + /// + /// This request is used when the server needs more information from the client to proceed with a task or interaction. + /// Servers can request structured data from users, with optional JSON schemas to validate responses. + /// + public const string ElicitationCreate = "elicitation/create"; + /// /// The name of the request method sent from the client to the server when it first connects, asking it initialize. /// diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs index 8626d8a17..bdd6eb18d 100644 --- a/src/ModelContextProtocol/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs @@ -20,13 +20,13 @@ public static class McpServerExtensions /// The to monitor for cancellation requests. /// A task containing the sampling result from the client. /// is . - /// The client does not support sampling. + /// The client does not support sampling. /// /// This method requires the client to support sampling capabilities. /// It allows detailed control over sampling parameters including messages, system prompt, temperature, /// and token limits. /// - public static ValueTask RequestSamplingAsync( + public static ValueTask SampleAsync( this IMcpServer server, CreateMessageRequestParams request, CancellationToken cancellationToken = default) { Throw.IfNull(server); @@ -50,12 +50,12 @@ public static ValueTask RequestSamplingAsync( /// A task containing the chat response from the model. /// is . /// is . - /// The client does not support sampling. + /// The client does not support sampling. /// /// This method converts the provided chat messages into a format suitable for the sampling API, /// handling different content types such as text, images, and audio. /// - public static async Task RequestSamplingAsync( + public static async Task SampleAsync( this IMcpServer server, IEnumerable messages, ChatOptions? options = default, CancellationToken cancellationToken = default) { @@ -125,7 +125,7 @@ public static async Task RequestSamplingAsync( modelPreferences = new() { Hints = [new() { Name = modelId }] }; } - var result = await server.RequestSamplingAsync(new() + var result = await server.SampleAsync(new() { Messages = samplingMessages, MaxTokens = options?.MaxOutputTokens, @@ -152,7 +152,7 @@ public static async Task RequestSamplingAsync( /// The server to be wrapped as an . /// The that can be used to issue sampling requests to the client. /// is . - /// The client does not support sampling. + /// The client does not support sampling. public static IChatClient AsSamplingChatClient(this IMcpServer server) { Throw.IfNull(server); @@ -179,7 +179,7 @@ public static ILoggerProvider AsClientLoggerProvider(this IMcpServer server) /// The to monitor for cancellation requests. /// A task containing the list of roots exposed by the client. /// is . - /// The client does not support roots. + /// The client does not support roots. /// /// This method requires the client to support the roots capability. /// Root resources allow clients to expose a hierarchical structure of resources that can be @@ -200,6 +200,32 @@ public static ValueTask RequestRootsAsync( cancellationToken: cancellationToken); } + /// + /// Requests additional information from the user via the client, allowing the server to elicit structured data. + /// + /// The server initiating the request. + /// The parameters for the elicitation request. + /// The to monitor for cancellation requests. + /// A task containing the elicitation result. + /// is . + /// The client does not support elicitation. + /// + /// This method requires the client to support the elicitation capability. + /// + public static ValueTask ElicitAsync( + this IMcpServer server, ElicitRequestParams request, CancellationToken cancellationToken = default) + { + Throw.IfNull(server); + ThrowIfElicitationUnsupported(server); + + return server.SendRequestAsync( + RequestMethods.ElicitationCreate, + request, + McpJsonUtilities.JsonContext.Default.ElicitRequestParams, + McpJsonUtilities.JsonContext.Default.ElicitResult, + cancellationToken: cancellationToken); + } + private static void ThrowIfSamplingUnsupported(IMcpServer server) { if (server.ClientCapabilities?.Sampling is null) @@ -226,12 +252,25 @@ private static void ThrowIfRootsUnsupported(IMcpServer server) } } + private static void ThrowIfElicitationUnsupported(IMcpServer server) + { + if (server.ClientCapabilities?.Elicitation is null) + { + if (server.ServerOptions.KnownClientInfo is not null) + { + throw new InvalidOperationException("Elicitation is not supported in stateless mode."); + } + + throw new InvalidOperationException("Client does not support elicitation requests."); + } + } + /// Provides an implementation that's implemented via client sampling. private sealed class SamplingChatClient(IMcpServer server) : IChatClient { /// public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => - server.RequestSamplingAsync(messages, options, cancellationToken); + server.SampleAsync(messages, options, cancellationToken); /// async IAsyncEnumerable IChatClient.GetStreamingResponseAsync( diff --git a/tests/Common/Utils/TestServerTransport.cs b/tests/Common/Utils/TestServerTransport.cs index ec2712242..cd12504a0 100644 --- a/tests/Common/Utils/TestServerTransport.cs +++ b/tests/Common/Utils/TestServerTransport.cs @@ -39,9 +39,11 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can if (message is JsonRpcRequest request) { if (request.Method == RequestMethods.RootsList) - await ListRoots(request, cancellationToken); + await ListRootsAsync(request, cancellationToken); else if (request.Method == RequestMethods.SamplingCreateMessage) - await Sampling(request, cancellationToken); + await SamplingAsync(request, cancellationToken); + else if (request.Method == RequestMethods.ElicitationCreate) + await ElicitAsync(request, cancellationToken); else await WriteMessageAsync(request, cancellationToken); } @@ -53,7 +55,7 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can OnMessageSent?.Invoke(message); } - private async Task ListRoots(JsonRpcRequest request, CancellationToken cancellationToken) + private async Task ListRootsAsync(JsonRpcRequest request, CancellationToken cancellationToken) { await WriteMessageAsync(new JsonRpcResponse { @@ -65,7 +67,7 @@ await WriteMessageAsync(new JsonRpcResponse }, cancellationToken); } - private async Task Sampling(JsonRpcRequest request, CancellationToken cancellationToken) + private async Task SamplingAsync(JsonRpcRequest request, CancellationToken cancellationToken) { await WriteMessageAsync(new JsonRpcResponse { @@ -74,6 +76,15 @@ await WriteMessageAsync(new JsonRpcResponse }, cancellationToken); } + private async Task ElicitAsync(JsonRpcRequest request, CancellationToken cancellationToken) + { + await WriteMessageAsync(new JsonRpcResponse + { + Id = request.Id, + Result = JsonSerializer.SerializeToNode(new ElicitResult { Action = "decline" }, McpJsonUtilities.DefaultOptions), + }, cancellationToken); + } + private async Task WriteMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { await _messageChannel.Writer.WriteAsync(message, cancellationToken); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs index acfc744b9..8786da265 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs @@ -136,6 +136,26 @@ public async Task RootsRequest_Fails_WithInvalidOperationException() Assert.Equal("Server to client requests are not supported in stateless mode.", toolContent.Text); } + [Fact] + public async Task ElicitRequest_Fails_WithInvalidOperationException() + { + await StartAsync(); + + var mcpClientOptions = new McpClientOptions(); + mcpClientOptions.Capabilities = new(); + mcpClientOptions.Capabilities.Elicitation ??= new(); + mcpClientOptions.Capabilities.Elicitation.ElicitationHandler = (_, _) => + { + throw new UnreachableException(); + }; + + await using var client = await ConnectMcpClientAsync(mcpClientOptions); + + var toolResponse = await client.CallToolAsync("testElicitationErrors", cancellationToken: TestContext.Current.CancellationToken); + var toolContent = Assert.Single(toolResponse.Content); + Assert.Equal("Server to client requests are not supported in stateless mode.", toolContent.Text); + } + [Fact] public async Task UnsolicitedNotification_Fails_WithInvalidOperationException() { @@ -184,7 +204,7 @@ public static async Task TestSamplingErrors(IMcpServer server) var asSamplingChatClientEx = Assert.Throws(() => server.AsSamplingChatClient()); Assert.Equal(expectedSamplingErrorMessage, asSamplingChatClientEx.Message); - var requestSamplingEx = await Assert.ThrowsAsync(() => server.RequestSamplingAsync([])); + var requestSamplingEx = await Assert.ThrowsAsync(() => server.SampleAsync([])); Assert.Equal(expectedSamplingErrorMessage, requestSamplingEx.Message); var ex = await Assert.ThrowsAsync(() => server.SendRequestAsync(new JsonRpcRequest { Method = RequestMethods.SamplingCreateMessage })); @@ -206,6 +226,21 @@ public static async Task TestRootsErrors(IMcpServer server) return ex.Message; } + [McpServerTool(Name = "testElicitationErrors")] + public static async Task TestElicitationErrors(IMcpServer server) + { + const string expectedElicitationErrorMessage = "Elicitation is not supported in stateless mode."; + + // Even when the client has elicitation support, it should not be advertised in stateless mode. + Assert.Null(server.ClientCapabilities); + + var requestElicitationEx = Assert.Throws(() => server.ElicitAsync(new())); + Assert.Equal(expectedElicitationErrorMessage, requestElicitationEx.Message); + + var ex = await Assert.ThrowsAsync(() => server.SendRequestAsync(new JsonRpcRequest { Method = RequestMethods.ElicitationCreate })); + return ex.Message; + } + [McpServerTool(Name = "testScope")] public static string? TestScope(ScopedService scopedService) => scopedService.State; diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index 10f2aec78..4b9d64de0 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -179,7 +179,7 @@ private static ToolsCapability ConfigureTools() { throw new McpException("Missing required arguments 'prompt' and 'maxTokens'", McpErrorCode.InvalidParams); } - var sampleResult = await request.Server.RequestSamplingAsync(CreateRequestSamplingParams(prompt.ToString(), "sampleLLM", Convert.ToInt32(maxTokens.GetRawText())), + var sampleResult = await request.Server.SampleAsync(CreateRequestSamplingParams(prompt.ToString(), "sampleLLM", Convert.ToInt32(maxTokens.GetRawText())), cancellationToken); return new CallToolResponse() diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 583d68780..2eb31c21a 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -177,7 +177,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { throw new McpException("Missing required arguments 'prompt' and 'maxTokens'", McpErrorCode.InvalidParams); } - var sampleResult = await request.Server.RequestSamplingAsync(CreateRequestSamplingParams(prompt.ToString(), "sampleLLM", Convert.ToInt32(maxTokens.ToString())), + var sampleResult = await request.Server.SampleAsync(CreateRequestSamplingParams(prompt.ToString(), "sampleLLM", Convert.ToInt32(maxTokens.ToString())), cancellationToken); return new CallToolResponse() diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 2fd52cb9b..75966b6eb 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -92,21 +92,21 @@ public async Task RunAsync_Should_Throw_InvalidOperationException_If_Already_Run } [Fact] - public async Task RequestSamplingAsync_Should_Throw_Exception_If_Client_Does_Not_Support_Sampling() + public async Task SampleAsync_Should_Throw_Exception_If_Client_Does_Not_Support_Sampling() { // Arrange await using var transport = new TestServerTransport(); await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities()); - var action = async () => await server.RequestSamplingAsync(new CreateMessageRequestParams { Messages = [] }, CancellationToken.None); + var action = async () => await server.SampleAsync(new CreateMessageRequestParams { Messages = [] }, CancellationToken.None); // Act & Assert await Assert.ThrowsAsync(action); } [Fact] - public async Task RequestSamplingAsync_Should_SendRequest() + public async Task SampleAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); @@ -116,7 +116,7 @@ public async Task RequestSamplingAsync_Should_SendRequest() var runTask = server.RunAsync(TestContext.Current.CancellationToken); // Act - var result = await server.RequestSamplingAsync(new CreateMessageRequestParams { Messages = [] }, CancellationToken.None); + var result = await server.SampleAsync(new CreateMessageRequestParams { Messages = [] }, CancellationToken.None); Assert.NotNull(result); Assert.NotEmpty(transport.SentMessages); @@ -161,6 +161,40 @@ public async Task RequestRootsAsync_Should_SendRequest() await runTask; } + [Fact] + public async Task ElicitAsync_Should_Throw_Exception_If_Client_Does_Not_Support_Elicitation() + { + // Arrange + await using var transport = new TestServerTransport(); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + SetClientCapabilities(server, new ClientCapabilities()); + + // Act & Assert + await Assert.ThrowsAsync(async () => await server.ElicitAsync(new ElicitRequestParams(), CancellationToken.None)); + } + + [Fact] + public async Task ElicitAsync_Should_SendRequest() + { + // Arrange + await using var transport = new TestServerTransport(); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + SetClientCapabilities(server, new ClientCapabilities { Elicitation = new ElicitationCapability() }); + var runTask = server.RunAsync(TestContext.Current.CancellationToken); + + // Act + var result = await server.ElicitAsync(new ElicitRequestParams(), CancellationToken.None); + + // Assert + Assert.NotNull(result); + Assert.NotEmpty(transport.SentMessages); + Assert.IsType(transport.SentMessages[0]); + Assert.Equal(RequestMethods.ElicitationCreate, ((JsonRpcRequest)transport.SentMessages[0]).Method); + + await transport.DisposeAsync(); + await runTask; + } + [Fact] public async Task Can_Handle_Ping_Requests() {