forked from modelcontextprotocol/csharp-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMcpEndpoint.cs
More file actions
144 lines (117 loc) · 5.16 KB
/
McpEndpoint.cs
File metadata and controls
144 lines (117 loc) · 5.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
namespace ModelContextProtocol;
/// <summary>
/// Base class for an MCP JSON-RPC endpoint. This covers both MCP clients and servers.
/// It is not supported, nor necessary, to implement both client and server functionality in the same class.
/// If an application needs to act as both a client and a server, it should use separate objects for each.
/// This is especially true as a client represents a connection to one and only one server, and vice versa.
/// Any multi-client or multi-server functionality should be implemented at a higher level of abstraction.
/// </summary>
internal abstract partial class McpEndpoint : IAsyncDisposable
{
    /// <summary>Cached naming information used for name/version when none is specified.</summary>
    internal static AssemblyName DefaultAssemblyName { get; } = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName();

    private McpSession? _session;
    private CancellationTokenSource? _sessionCts;

    // Serializes DisposeAsync so that, through that entry point, teardown runs at most once.
    private readonly SemaphoreSlim _disposeLock = new(1, 1);
    private bool _disposed;

    protected readonly ILogger _logger;

    /// <summary>
    /// Initializes a new instance of the <see cref="McpEndpoint"/> class.
    /// </summary>
    /// <param name="loggerFactory">
    /// The logger factory used to create a logger for the concrete endpoint type;
    /// when <see langword="null"/>, <see cref="NullLogger.Instance"/> is used.
    /// </param>
    protected McpEndpoint(ILoggerFactory? loggerFactory = null)
    {
        _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance;
    }

    /// <summary>Gets the request handlers handed to the session created by <see cref="InitializeSession"/>.</summary>
    protected RequestHandlers RequestHandlers { get; } = [];

    /// <summary>Gets the notification handlers handed to the session created by <see cref="InitializeSession"/>.</summary>
    protected NotificationHandlers NotificationHandlers { get; } = new();

    /// <summary>Sends <paramref name="request"/> through the current session.</summary>
    /// <param name="request">The JSON-RPC request to send.</param>
    /// <param name="cancellationToken">A token to cancel the operation.</param>
    /// <returns>The task returned by the session's request send.</returns>
    /// <exception cref="ObjectDisposedException">The endpoint has been disposed.</exception>
    /// <exception cref="InvalidOperationException">No session has been initialized.</exception>
    public Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default)
        => GetSessionOrThrow().SendRequestAsync(request, cancellationToken);

    /// <summary>Sends <paramref name="message"/> through the current session.</summary>
    /// <param name="message">The JSON-RPC message to send.</param>
    /// <param name="cancellationToken">A token to cancel the operation.</param>
    /// <returns>The task returned by the session's message send.</returns>
    /// <exception cref="ObjectDisposedException">The endpoint has been disposed.</exception>
    /// <exception cref="InvalidOperationException">No session has been initialized.</exception>
    public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
        => GetSessionOrThrow().SendMessageAsync(message, cancellationToken);

    /// <summary>Registers <paramref name="handler"/> for notifications named <paramref name="method"/> on the current session.</summary>
    /// <param name="method">The notification method name.</param>
    /// <param name="handler">The callback to invoke for matching notifications.</param>
    /// <returns>The registration object returned by the session; presumably disposing it unregisters the handler — confirm in McpSession.</returns>
    /// <exception cref="ObjectDisposedException">The endpoint has been disposed.</exception>
    /// <exception cref="InvalidOperationException">No session has been initialized.</exception>
    public IAsyncDisposable RegisterNotificationHandler(string method, Func<JsonRpcNotification, CancellationToken, ValueTask> handler) =>
        GetSessionOrThrow().RegisterNotificationHandler(method, handler);

    /// <summary>
    /// Gets the name of the endpoint for logging and debug purposes.
    /// </summary>
    public abstract string EndpointName { get; }

    /// <summary>
    /// Task that processes incoming messages from the transport.
    /// </summary>
    protected Task? MessageProcessingTask { get; private set; }

    /// <summary>
    /// Creates the session over <paramref name="sessionTransport"/>, wiring in this endpoint's
    /// request/notification handlers and logger. Must be called before any send or registration.
    /// </summary>
    /// <param name="sessionTransport">The transport the session communicates over.</param>
    protected void InitializeSession(ITransport sessionTransport)
    {
        _session = new McpSession(this is IMcpServer, sessionTransport, EndpointName, RequestHandlers, NotificationHandlers, _logger);
    }

    /// <summary>
    /// Starts processing incoming messages for the session created by <see cref="InitializeSession"/>,
    /// storing the resulting task in <see cref="MessageProcessingTask"/>.
    /// </summary>
    /// <param name="sessionTransport">Currently unused; retained for signature compatibility.</param>
    /// <param name="fullSessionCancellationToken">Token that, when cancelled, also cancels message processing.</param>
    /// <exception cref="ObjectDisposedException">The endpoint has been disposed.</exception>
    /// <exception cref="InvalidOperationException">No session has been initialized.</exception>
    [MemberNotNull(nameof(MessageProcessingTask))]
    protected void StartSession(ITransport sessionTransport, CancellationToken fullSessionCancellationToken)
    {
        // Resolve the session first: if this throws, no linked CTS has been created yet, so nothing
        // is leaked (a linked CTS keeps a registration on the source token until it is disposed).
        McpSession session = GetSessionOrThrow();

        _sessionCts = CancellationTokenSource.CreateLinkedTokenSource(fullSessionCancellationToken);
        MessageProcessingTask = session.ProcessMessagesAsync(_sessionCts.Token);
    }

    /// <summary>
    /// Requests cancellation of message processing started by <see cref="StartSession"/>.
    /// Does nothing if no session has been started or the endpoint has already been torn down.
    /// </summary>
    protected void CancelSession()
    {
        try
        {
            _sessionCts?.Cancel();
        }
        catch (ObjectDisposedException)
        {
            // DisposeUnsynchronizedAsync cancels and then disposes _sessionCts without clearing the
            // field, so a late cancellation request after teardown is a no-op rather than an error.
        }
    }

    /// <summary>
    /// Disposes the endpoint once; subsequent calls return immediately.
    /// </summary>
    /// <returns>A task that completes when teardown has finished.</returns>
    public async ValueTask DisposeAsync()
    {
        using var _ = await _disposeLock.LockAsync().ConfigureAwait(false);

        if (_disposed)
        {
            return;
        }

        // Set before teardown so GetSessionOrThrow rejects new operations while shutdown is in progress.
        _disposed = true;

        await DisposeUnsynchronizedAsync().ConfigureAwait(false);
    }

    /// <summary>
    /// Cleans up the endpoint and releases resources: cancels message processing, waits for the
    /// processing task to finish, then disposes the session and its cancellation source.
    /// This method does no locking or once-only check of its own; <see cref="DisposeAsync"/> provides both.
    /// </summary>
    /// <returns>A task that completes when the processing task has finished and resources are released.</returns>
    public virtual async ValueTask DisposeUnsynchronizedAsync()
    {
        LogEndpointShuttingDown(EndpointName);

        try
        {
            if (_sessionCts is not null)
            {
                await _sessionCts.CancelAsync().ConfigureAwait(false);
            }

            if (MessageProcessingTask is not null)
            {
                try
                {
                    await MessageProcessingTask.ConfigureAwait(false);
                }
                catch (OperationCanceledException)
                {
                    // Expected: processing was cancelled just above.
                }
            }
        }
        finally
        {
            // Release resources even if cancellation or the processing task faulted.
            _session?.Dispose();
            _sessionCts?.Dispose();
        }

        LogEndpointShutDown(EndpointName);
    }

    /// <summary>
    /// Returns the current session.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The endpoint has been disposed.</exception>
    /// <exception cref="InvalidOperationException"><see cref="InitializeSession"/> has not been called.</exception>
    protected McpSession GetSessionOrThrow()
    {
#if NET
        ObjectDisposedException.ThrowIf(_disposed, this);
#else
        if (_disposed)
        {
            throw new ObjectDisposedException(GetType().Name);
        }
#endif

        return _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages.");
    }

    [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} shutting down.")]
    private partial void LogEndpointShuttingDown(string endpointName);

    [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} shut down.")]
    private partial void LogEndpointShutDown(string endpointName);
}