Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/ModelContextProtocol.Core/Client/McpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ namespace ModelContextProtocol.Client;
/// </summary>
public abstract partial class McpClient : McpSession
{
/// <summary>Initializes a new instance of the <see cref="McpClient"/> class.</summary>
private protected McpClient()
{
}

/// <summary>
/// Gets the capabilities supported by the connected server.
/// </summary>
Expand Down
5 changes: 5 additions & 0 deletions src/ModelContextProtocol.Core/Server/McpServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ namespace ModelContextProtocol.Server;
/// </summary>
public abstract partial class McpServer : McpSession
{
/// <summary>Initializes a new instance of the <see cref="McpServer"/> class.</summary>
private protected McpServer()
{
}

/// <summary>
/// Gets the capabilities supported by the client.
/// </summary>
Expand Down
9 changes: 8 additions & 1 deletion tests/Common/Utils/TestServerTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,11 @@ await WriteMessageAsync(new JsonRpcResponse
else
{
// Return a normal sampling response
var result = MockSamplingResult ?? new CreateMessageResult { Content = [new TextContentBlock { Text = "" }], Model = "model" };
await WriteMessageAsync(new JsonRpcResponse
{
Id = request.Id,
Result = JsonSerializer.SerializeToNode(new CreateMessageResult { Content = [new TextContentBlock { Text = "" }], Model = "model" }, McpJsonUtilities.DefaultOptions),
Result = JsonSerializer.SerializeToNode(result, McpJsonUtilities.DefaultOptions),
}, cancellationToken);
}
}
Expand Down Expand Up @@ -125,6 +126,12 @@ await WriteMessageAsync(new JsonRpcResponse
}
}

/// <summary>
/// Gets or sets the sampling result to return from sampling/createMessage requests.
/// When null, a default sampling response is returned.
/// </summary>
public CreateMessageResult? MockSamplingResult { get; set; }

/// <summary>
/// Gets or sets the task to return from tasks/get requests.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using Moq;
using ModelContextProtocol.Tests.Utils;
using System.Collections;
using System.ComponentModel;
using System.Text.Json;
Expand Down Expand Up @@ -314,7 +314,7 @@ public async Task WithPrompts_TargetInstance_UsesTarget()
sc.AddMcpServer().WithPrompts(target);

McpServerPrompt prompt = sc.BuildServiceProvider().GetServices<McpServerPrompt>().First(t => t.ProtocolPrompt.Name == "returns_string");
var result = await prompt.GetAsync(new RequestContext<GetPromptRequestParams>(new Mock<McpServer>().Object, new JsonRpcRequest { Method = "test", Id = new RequestId("1") })
var result = await prompt.GetAsync(new RequestContext<GetPromptRequestParams>(McpServer.Create(new TestServerTransport(), new McpServerOptions()), new JsonRpcRequest { Method = "test", Id = new RequestId("1") })
{
Params = new GetPromptRequestParams
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using Moq;
using ModelContextProtocol.Tests.Utils;
using System.Collections;
using System.ComponentModel;
using System.Threading.Channels;
Expand Down Expand Up @@ -345,7 +345,7 @@ public async Task WithResources_TargetInstance_UsesTarget()
sc.AddMcpServer().WithResources(target);

McpServerResource resource = sc.BuildServiceProvider().GetServices<McpServerResource>().First(t => t.ProtocolResource?.Name == "returns_string");
var result = await resource.ReadAsync(new RequestContext<ReadResourceRequestParams>(new Mock<McpServer>().Object, new JsonRpcRequest { Method = "test", Id = new RequestId("1") })
var result = await resource.ReadAsync(new RequestContext<ReadResourceRequestParams>(McpServer.Create(new TestServerTransport(), new McpServerOptions()), new JsonRpcRequest { Method = "test", Id = new RequestId("1") })
{
Params = new()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using Moq;
using ModelContextProtocol.Tests.Utils;
using System.Collections;
using System.Collections.Concurrent;
using System.ComponentModel;
Expand Down Expand Up @@ -594,7 +594,7 @@ public async Task WithTools_TargetInstance_UsesTarget()
sc.AddMcpServer().WithTools(target, BuilderToolsJsonContext.Default.Options);

McpServerTool tool = sc.BuildServiceProvider().GetServices<McpServerTool>().First(t => t.ProtocolTool.Name == "get_ctor_parameter");
var result = await tool.InvokeAsync(new RequestContext<CallToolRequestParams>(new Mock<McpServer>().Object, new JsonRpcRequest { Method = "test", Id = new RequestId("1") }), TestContext.Current.CancellationToken);
var result = await tool.InvokeAsync(new RequestContext<CallToolRequestParams>(McpServer.Create(new TestServerTransport(), new McpServerOptions()), new JsonRpcRequest { Method = "test", Id = new RequestId("1") }), TestContext.Current.CancellationToken);

Assert.Equal(target.GetCtorParameter(), (result.Content[0] as TextContentBlock)?.Text);
}
Expand Down
44 changes: 23 additions & 21 deletions tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
using Microsoft.Extensions.DependencyInjection;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using Moq;
using ModelContextProtocol.Tests.Utils;
using System.ComponentModel;
using System.Reflection;
using System.Runtime.InteropServices;
Expand All @@ -13,6 +13,9 @@ namespace ModelContextProtocol.Tests.Server;

public class McpServerPromptTests
{
private static McpServer CreateTestServer(IServiceProvider? services = null) =>
McpServer.Create(new TestServerTransport(), new McpServerOptions(), serviceProvider: services);

private static JsonRpcRequest CreateTestJsonRpcRequest()
{
return new JsonRpcRequest
Expand Down Expand Up @@ -43,18 +46,18 @@ public void Create_InvalidArgs_Throws()
[Fact]
public async Task SupportsMcpServer()
{
Mock<McpServer> mockServer = new();
McpServer testServer = CreateTestServer();

McpServerPrompt prompt = McpServerPrompt.Create((McpServer server) =>
{
Assert.Same(mockServer.Object, server);
Assert.Same(testServer, server);
return new ChatMessage(ChatRole.User, "Hello");
});

Assert.DoesNotContain("server", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []);

var result = await prompt.GetAsync(
new RequestContext<GetPromptRequestParams>(mockServer.Object, CreateTestJsonRpcRequest()),
new RequestContext<GetPromptRequestParams>(testServer, CreateTestJsonRpcRequest()),
TestContext.Current.CancellationToken);
Assert.NotNull(result);
Assert.NotNull(result.Messages);
Expand All @@ -71,8 +74,7 @@ public async Task SupportsCtorInjection()
sc.AddSingleton(expectedMyService);
IServiceProvider services = sc.BuildServiceProvider();

Mock<McpServer> mockServer = new();
mockServer.SetupGet(s => s.Services).Returns(services);
McpServer server = CreateTestServer(services);

MethodInfo? testMethod = typeof(HasCtorWithSpecialParameters).GetMethod(nameof(HasCtorWithSpecialParameters.TestPrompt));
Assert.NotNull(testMethod);
Expand All @@ -83,7 +85,7 @@ public async Task SupportsCtorInjection()
}, new() { Services = services });

var result = await prompt.GetAsync(
new RequestContext<GetPromptRequestParams>(mockServer.Object, CreateTestJsonRpcRequest()),
new RequestContext<GetPromptRequestParams>(server, CreateTestJsonRpcRequest()),
TestContext.Current.CancellationToken);
Assert.NotNull(result);
Assert.NotNull(result.Messages);
Expand Down Expand Up @@ -133,11 +135,11 @@ public async Task SupportsServiceFromDI()
Assert.DoesNotContain("actualMyService", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []);

await Assert.ThrowsAnyAsync<ArgumentException>(async () => await prompt.GetAsync(
new RequestContext<GetPromptRequestParams>(new Mock<McpServer>().Object, CreateTestJsonRpcRequest()),
new RequestContext<GetPromptRequestParams>(CreateTestServer(), CreateTestJsonRpcRequest()),
TestContext.Current.CancellationToken));

var result = await prompt.GetAsync(
new RequestContext<GetPromptRequestParams>(new Mock<McpServer>().Object, CreateTestJsonRpcRequest()) { Services = services },
new RequestContext<GetPromptRequestParams>(CreateTestServer(), CreateTestJsonRpcRequest()) { Services = services },
TestContext.Current.CancellationToken);
Assert.Equal("Hello", Assert.IsType<TextContentBlock>(result.Messages[0].Content).Text);
}
Expand All @@ -158,7 +160,7 @@ public async Task SupportsOptionalServiceFromDI()
}, new() { Services = services });

var result = await prompt.GetAsync(
new RequestContext<GetPromptRequestParams>(new Mock<McpServer>().Object, CreateTestJsonRpcRequest()),
new RequestContext<GetPromptRequestParams>(CreateTestServer(), CreateTestJsonRpcRequest()),
TestContext.Current.CancellationToken);
Assert.Equal("Hello", Assert.IsType<TextContentBlock>(result.Messages[0].Content).Text);
}
Expand All @@ -171,7 +173,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets()
_ => new DisposablePromptType());

var result = await prompt1.GetAsync(
new RequestContext<GetPromptRequestParams>(new Mock<McpServer>().Object, CreateTestJsonRpcRequest()),
new RequestContext<GetPromptRequestParams>(CreateTestServer(), CreateTestJsonRpcRequest()),
TestContext.Current.CancellationToken);
Assert.Equal("disposals:1", Assert.IsType<TextContentBlock>(result.Messages[0].Content).Text);
}
Expand All @@ -184,7 +186,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets()
_ => new AsyncDisposablePromptType());

var result = await prompt1.GetAsync(
new RequestContext<GetPromptRequestParams>(new Mock<McpServer>().Object, CreateTestJsonRpcRequest()),
new RequestContext<GetPromptRequestParams>(CreateTestServer(), CreateTestJsonRpcRequest()),
TestContext.Current.CancellationToken);
Assert.Equal("asyncDisposals:1", Assert.IsType<TextContentBlock>(result.Messages[0].Content).Text);
}
Expand All @@ -197,7 +199,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable
_ => new AsyncDisposableAndDisposablePromptType());

var result = await prompt1.GetAsync(
new RequestContext<GetPromptRequestParams>(new Mock<McpServer>().Object, CreateTestJsonRpcRequest()),
new RequestContext<GetPromptRequestParams>(CreateTestServer(), CreateTestJsonRpcRequest()),
TestContext.Current.CancellationToken);
Assert.Equal("disposals:0, asyncDisposals:1", Assert.IsType<TextContentBlock>(result.Messages[0].Content).Text);
}
Expand All @@ -213,7 +215,7 @@ public async Task CanReturnGetPromptResult()
});

var actual = await prompt.GetAsync(
new RequestContext<GetPromptRequestParams>(new Mock<McpServer>().Object, CreateTestJsonRpcRequest()),
new RequestContext<GetPromptRequestParams>(CreateTestServer(), CreateTestJsonRpcRequest()),
TestContext.Current.CancellationToken);

Assert.Same(expected, actual);
Expand All @@ -230,7 +232,7 @@ public async Task CanReturnText()
});

var actual = await prompt.GetAsync(
new RequestContext<GetPromptRequestParams>(new Mock<McpServer>().Object, CreateTestJsonRpcRequest()),
new RequestContext<GetPromptRequestParams>(CreateTestServer(), CreateTestJsonRpcRequest()),
TestContext.Current.CancellationToken);

Assert.NotNull(actual);
Expand All @@ -256,7 +258,7 @@ public async Task CanReturnPromptMessage()
});

var actual = await prompt.GetAsync(
new RequestContext<GetPromptRequestParams>(new Mock<McpServer>().Object, CreateTestJsonRpcRequest()),
new RequestContext<GetPromptRequestParams>(CreateTestServer(), CreateTestJsonRpcRequest()),
TestContext.Current.CancellationToken);

Assert.NotNull(actual);
Expand Down Expand Up @@ -288,7 +290,7 @@ public async Task CanReturnPromptMessages()
});

var actual = await prompt.GetAsync(
new RequestContext<GetPromptRequestParams>(new Mock<McpServer>().Object, CreateTestJsonRpcRequest()),
new RequestContext<GetPromptRequestParams>(CreateTestServer(), CreateTestJsonRpcRequest()),
TestContext.Current.CancellationToken);

Assert.NotNull(actual);
Expand All @@ -315,7 +317,7 @@ public async Task CanReturnChatMessage()
});

var actual = await prompt.GetAsync(
new RequestContext<GetPromptRequestParams>(new Mock<McpServer>().Object, CreateTestJsonRpcRequest()),
new RequestContext<GetPromptRequestParams>(CreateTestServer(), CreateTestJsonRpcRequest()),
TestContext.Current.CancellationToken);

Assert.NotNull(actual);
Expand Down Expand Up @@ -347,7 +349,7 @@ public async Task CanReturnChatMessages()
});

var actual = await prompt.GetAsync(
new RequestContext<GetPromptRequestParams>(new Mock<McpServer>().Object, CreateTestJsonRpcRequest()),
new RequestContext<GetPromptRequestParams>(CreateTestServer(), CreateTestJsonRpcRequest()),
TestContext.Current.CancellationToken);

Assert.NotNull(actual);
Expand All @@ -368,7 +370,7 @@ public async Task ThrowsForNullReturn()
});

await Assert.ThrowsAsync<InvalidOperationException>(async () => await prompt.GetAsync(
new RequestContext<GetPromptRequestParams>(new Mock<McpServer>().Object, CreateTestJsonRpcRequest()),
new RequestContext<GetPromptRequestParams>(CreateTestServer(), CreateTestJsonRpcRequest()),
TestContext.Current.CancellationToken));
}

Expand All @@ -381,7 +383,7 @@ public async Task ThrowsForUnexpectedTypeReturn()
});

await Assert.ThrowsAsync<InvalidOperationException>(async () => await prompt.GetAsync(
new RequestContext<GetPromptRequestParams>(new Mock<McpServer>().Object, CreateTestJsonRpcRequest()),
new RequestContext<GetPromptRequestParams>(CreateTestServer(), CreateTestJsonRpcRequest()),
TestContext.Current.CancellationToken));
}

Expand Down
Loading