Skip to content

Commit 441c895

Browse files
committed
Remove Channel
1 parent 1fe273d commit 441c895

2 files changed

Lines changed: 44 additions & 94 deletions

File tree

src/ModelContextProtocol.Core/Server/McpServerImpl.cs

Lines changed: 20 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
using System.Text.Json;
88
using System.Text.Json.Nodes;
99
using System.Text.Json.Serialization.Metadata;
10-
using System.Threading.Channels;
1110

1211
namespace ModelContextProtocol.Server;
1312

@@ -211,10 +210,7 @@ public override async ValueTask DisposeAsync()
211210
{
212211
if (_mrtrContinuations.TryRemove(kvp.Key, out var continuation))
213212
{
214-
foreach (var exchange in continuation.PendingExchanges)
215-
{
216-
exchange.ResponseTcs.TrySetCanceled();
217-
}
213+
continuation.PendingExchange.ResponseTcs.TrySetCanceled();
218214
}
219215
}
220216

@@ -1192,19 +1188,20 @@ private void WrapHandlerWithMrtr(string method)
11921188
inputResponses = JsonSerializer.Deserialize(responsesNode, McpJsonUtilities.JsonContext.Default.IDictionaryStringInputResponse);
11931189
}
11941190

1195-
// Complete pending exchanges with the client's responses.
1196-
foreach (var exchange in continuation.PendingExchanges)
1191+
// Prepare for the next potential exchange before resuming the handler.
1192+
continuation.MrtrContext.ResetForNextExchange();
1193+
1194+
// Complete the pending exchange with the client's response.
1195+
var exchange = continuation.PendingExchange;
1196+
if (inputResponses is not null &&
1197+
inputResponses.TryGetValue(exchange.Key, out var response))
11971198
{
1198-
if (inputResponses is not null &&
1199-
inputResponses.TryGetValue(exchange.Key, out var response))
1200-
{
1201-
exchange.ResponseTcs.TrySetResult(response);
1202-
}
1203-
else
1204-
{
1205-
exchange.ResponseTcs.TrySetException(
1206-
new McpProtocolException($"Missing input response for key '{exchange.Key}'.", McpErrorCode.InvalidParams));
1207-
}
1199+
exchange.ResponseTcs.TrySetResult(response);
1200+
}
1201+
else
1202+
{
1203+
exchange.ResponseTcs.TrySetException(
1204+
new McpProtocolException($"Missing input response for key '{exchange.Key}'.", McpErrorCode.InvalidParams));
12081205
}
12091206

12101207
// Race again: handler completion vs new exchange.
@@ -1228,7 +1225,7 @@ private void WrapHandlerWithMrtr(string method)
12281225
Task<JsonNode?> handlerTask;
12291226
try
12301227
{
1231-
handlerTask = InvokeOriginalHandlerAsync(originalHandler, request, mrtrContext, cancellationToken);
1228+
handlerTask = originalHandler(request, cancellationToken);
12321229
}
12331230
finally
12341231
{
@@ -1241,31 +1238,7 @@ private void WrapHandlerWithMrtr(string method)
12411238
}
12421239

12431240
/// <summary>
1244-
/// Invokes the original request handler and marks the MrtrContext as complete when done.
1245-
/// </summary>
1246-
private static async Task<JsonNode?> InvokeOriginalHandlerAsync(
1247-
Func<JsonRpcRequest, CancellationToken, Task<JsonNode?>> handler,
1248-
JsonRpcRequest request,
1249-
MrtrContext mrtrContext,
1250-
CancellationToken cancellationToken)
1251-
{
1252-
try
1253-
{
1254-
return await handler(request, cancellationToken).ConfigureAwait(false);
1255-
}
1256-
catch (Exception ex)
1257-
{
1258-
mrtrContext.Fault(ex);
1259-
throw;
1260-
}
1261-
finally
1262-
{
1263-
mrtrContext.Complete();
1264-
}
1265-
}
1266-
1267-
/// <summary>
1268-
/// Races between handler completion and the MrtrContext exchange channel.
1241+
/// Races between handler completion and the MrtrContext exchange TCS.
12691242
/// If the handler completes, returns its result. If an exchange arrives (handler needs input),
12701243
/// builds and returns an IncompleteResult and stores the continuation for future retries.
12711244
/// </summary>
@@ -1280,10 +1253,7 @@ private void WrapHandlerWithMrtr(string method)
12801253
return await handlerTask.ConfigureAwait(false);
12811254
}
12821255

1283-
// Start reading from the exchange channel.
1284-
var readTask = mrtrContext.ExchangeReader.ReadAsync(cancellationToken).AsTask();
1285-
1286-
var completedTask = await Task.WhenAny(handlerTask, readTask).ConfigureAwait(false);
1256+
var completedTask = await Task.WhenAny(handlerTask, mrtrContext.ExchangeTask).ConfigureAwait(false);
12871257

12881258
if (completedTask == handlerTask)
12891259
{
@@ -1292,40 +1262,17 @@ private void WrapHandlerWithMrtr(string method)
12921262
}
12931263

12941264
// Exchange arrived - handler needs input from the client.
1295-
MrtrExchange firstExchange;
1296-
try
1297-
{
1298-
firstExchange = await readTask.ConfigureAwait(false);
1299-
}
1300-
catch (ChannelClosedException)
1301-
{
1302-
// Channel was closed (handler completed between WhenAny and ReadAsync).
1303-
return await handlerTask.ConfigureAwait(false);
1304-
}
1305-
1306-
// Collect all currently available exchanges (handles concurrent ElicitAsync/SampleAsync calls).
1307-
var exchanges = new List<MrtrExchange> { firstExchange };
1308-
while (mrtrContext.ExchangeReader.TryRead(out var additionalExchange))
1309-
{
1310-
exchanges.Add(additionalExchange);
1311-
}
1312-
1313-
// Build the IncompleteResult with input requests.
1314-
var inputRequests = new Dictionary<string, InputRequest>(exchanges.Count);
1315-
foreach (var exchange in exchanges)
1316-
{
1317-
inputRequests[exchange.Key] = exchange.InputRequest;
1318-
}
1265+
var exchange = await mrtrContext.ExchangeTask.ConfigureAwait(false);
13191266

13201267
var correlationId = Guid.NewGuid().ToString("N");
13211268
var incompleteResult = new IncompleteResult
13221269
{
1323-
InputRequests = inputRequests,
1270+
InputRequests = new Dictionary<string, InputRequest> { [exchange.Key] = exchange.InputRequest },
13241271
RequestState = correlationId,
13251272
};
13261273

13271274
// Store the continuation so the retry can resume the handler.
1328-
_mrtrContinuations[correlationId] = new MrtrContinuation(handlerTask, mrtrContext, exchanges);
1275+
_mrtrContinuations[correlationId] = new MrtrContinuation(handlerTask, mrtrContext, exchange);
13291276

13301277
return JsonSerializer.SerializeToNode(incompleteResult, McpJsonUtilities.JsonContext.Default.IncompleteResult);
13311278
}

src/ModelContextProtocol.Core/Server/MrtrContext.cs

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using System.Text.Json.Nodes;
2-
using System.Threading.Channels;
32
using ModelContextProtocol.Protocol;
43

54
namespace ModelContextProtocol.Server;
@@ -8,8 +7,9 @@ namespace ModelContextProtocol.Server;
87
/// Manages the MRTR (Multi Round-Trip Request) coordination between a handler and the pipeline.
98
/// When a handler calls <see cref="McpServer.ElicitAsync(ModelContextProtocol.Protocol.ElicitRequestParams, System.Threading.CancellationToken)"/> or
109
/// <see cref="McpServer.SampleAsync(ModelContextProtocol.Protocol.CreateMessageRequestParams, System.Threading.CancellationToken)"/>,
11-
/// the handler writes to the channel and suspends on a TCS. The pipeline reads from the channel,
12-
/// sends an <see cref="IncompleteResult"/>, and later completes the TCS when the retry arrives.
10+
/// the handler sets the exchange TCS and suspends on a response TCS. The pipeline detects the exchange
11+
/// via <see cref="ExchangeTask"/>, sends an <see cref="IncompleteResult"/>, and later completes the
12+
/// response TCS when the retry arrives.
1313
/// </summary>
1414
internal sealed class MrtrContext
1515
{
@@ -18,15 +18,14 @@ internal sealed class MrtrContext
1818
/// </summary>
1919
internal const string ExperimentalCapabilityKey = "mrtr";
2020

21-
private readonly Channel<MrtrExchange> _exchanges = Channel.CreateUnbounded<MrtrExchange>(
22-
new UnboundedChannelOptions { SingleReader = true });
21+
private TaskCompletionSource<MrtrExchange> _exchangeTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
2322

2423
private int _nextInputRequestId;
2524

2625
/// <summary>
27-
/// Gets the channel reader for consuming exchanges produced by the handler.
26+
/// Gets a task that completes when the handler produces an exchange (calls ElicitAsync/SampleAsync/RequestRootsAsync).
2827
/// </summary>
29-
public ChannelReader<MrtrExchange> ExchangeReader => _exchanges.Reader;
28+
public Task<MrtrExchange> ExchangeTask => _exchangeTcs.Task;
3029

3130
/// <summary>
3231
/// Called by <see cref="McpServer.ElicitAsync(ModelContextProtocol.Protocol.ElicitRequestParams, System.Threading.CancellationToken)"/>
@@ -36,26 +35,30 @@ internal sealed class MrtrContext
3635
/// <param name="inputRequest">The input request describing what the server needs.</param>
3736
/// <param name="cancellationToken">A token to cancel the wait for input.</param>
3837
/// <returns>The client's response to the input request.</returns>
38+
/// <exception cref="InvalidOperationException">A concurrent server-to-client request is already pending.</exception>
3939
public async Task<InputResponse> RequestInputAsync(InputRequest inputRequest, CancellationToken cancellationToken)
4040
{
41-
var key = $"input_{Interlocked.Increment(ref _nextInputRequestId)}";
41+
var tcs = _exchangeTcs;
42+
if (tcs.Task.IsCompleted)
43+
{
44+
throw new InvalidOperationException("Concurrent server-to-client requests are not supported. Await each ElicitAsync, SampleAsync, or RequestRootsAsync call before making another.");
45+
}
4246

47+
var key = $"input_{Interlocked.Increment(ref _nextInputRequestId)}";
4348
var exchange = new MrtrExchange(key, inputRequest);
44-
45-
await _exchanges.Writer.WriteAsync(exchange, cancellationToken).ConfigureAwait(false);
49+
tcs.TrySetResult(exchange);
4650

4751
return await exchange.ResponseTcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false);
4852
}
4953

5054
/// <summary>
51-
/// Signals that the handler has completed normally.
52-
/// </summary>
53-
public void Complete() => _exchanges.Writer.TryComplete();
54-
55-
/// <summary>
56-
/// Signals that the handler has faulted.
55+
/// Prepares the context for the next round of exchange after a retry arrives.
56+
/// Must be called before completing the previous exchange's response TCS.
5757
/// </summary>
58-
public void Fault(Exception exception) => _exchanges.Writer.TryComplete(exception);
58+
public void ResetForNextExchange()
59+
{
60+
_exchangeTcs = new TaskCompletionSource<MrtrExchange>(TaskCreationOptions.RunContinuationsAsynchronously);
61+
}
5962
}
6063

6164
/// <summary>
@@ -93,11 +96,11 @@ public MrtrExchange(string key, InputRequest inputRequest)
9396
/// </summary>
9497
internal sealed class MrtrContinuation
9598
{
96-
public MrtrContinuation(Task<JsonNode?> handlerTask, MrtrContext mrtrContext, IReadOnlyList<MrtrExchange> pendingExchanges)
99+
public MrtrContinuation(Task<JsonNode?> handlerTask, MrtrContext mrtrContext, MrtrExchange pendingExchange)
97100
{
98101
HandlerTask = handlerTask;
99102
MrtrContext = mrtrContext;
100-
PendingExchanges = pendingExchanges;
103+
PendingExchange = pendingExchange;
101104
}
102105

103106
/// <summary>
@@ -111,7 +114,7 @@ public MrtrContinuation(Task<JsonNode?> handlerTask, MrtrContext mrtrContext, IR
111114
public MrtrContext MrtrContext { get; }
112115

113116
/// <summary>
114-
/// The exchanges that are awaiting responses from the client.
117+
/// The exchange that is awaiting a response from the client.
115118
/// </summary>
116-
public IReadOnlyList<MrtrExchange> PendingExchanges { get; }
119+
public MrtrExchange PendingExchange { get; }
117120
}

0 commit comments

Comments
 (0)