-
Notifications
You must be signed in to change notification settings - Fork 698
Expand file tree
/
Copy pathResumabilityIntegrationTestsBase.cs
More file actions
562 lines (470 loc) · 23.7 KB
/
ResumabilityIntegrationTestsBase.cs
File metadata and controls
562 lines (470 loc) · 23.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
using System.ComponentModel;
using System.Diagnostics;
using System.Net;
using System.Net.ServerSentEvents;
using System.Text;
using System.Text.Json.Nodes;
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.DependencyInjection;
using ModelContextProtocol.AspNetCore.Tests.Utils;
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;
namespace ModelContextProtocol.AspNetCore.Tests;
/// <summary>
/// Base class for SSE resumability integration tests that can be run against different
/// <see cref="ISseEventStreamStore"/> implementations.
/// </summary>
/// <remarks>
/// <para>
/// Tests in this class verify resumability behavior without relying on implementation-specific
/// internals of the event store. Derived classes can override virtual tests to add additional
/// assertions specific to their event store implementation.
/// </para>
/// <para>
/// The <see cref="CreateEventStreamStoreAsync"/> method must be implemented by derived classes
/// to provide the specific event store implementation to test.
/// </para>
/// </remarks>
public abstract class ResumabilityIntegrationTestsBase(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper)
{
    /// <summary>
    /// The initialize request JSON for the current protocol version.
    /// </summary>
    protected const string InitializeRequest = """
        {"jsonrpc":"2.0","id":"1","method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"TestClient","version":"1.0.0"}}}
        """;

    /// <summary>
    /// Gets the event stream store created for the current test.
    /// </summary>
    /// <remarks>
    /// This is set after <see cref="CreateServerAsync"/> is called.
    /// </remarks>
    protected ISseEventStreamStore? EventStreamStore { get; private set; }

    /// <summary>
    /// Creates the event stream store implementation to use for this test.
    /// </summary>
    /// <returns>The event stream store instance.</returns>
    protected abstract ValueTask<ISseEventStreamStore> CreateEventStreamStoreAsync();

    /// <summary>
    /// Verifies that a basic tool call succeeds when an event store is configured.
    /// </summary>
    [Fact]
    public virtual async Task Server_StoresEvents_WhenEventStoreConfigured()
    {
        // Arrange
        await using var app = await CreateServerAsync();
        await using var client = await ConnectClientAsync();

        // Act - Make a tool call which generates events
        var result = await client.CallToolAsync("echo",
            new Dictionary<string, object?> { ["message"] = "test" },
            cancellationToken: TestContext.Current.CancellationToken);

        // Assert - The call succeeded
        Assert.NotNull(result);
        var textContent = Assert.Single(result.Content.OfType<TextContentBlock>());
        Assert.Equal("Echo: test", textContent.Text);
    }

    /// <summary>
    /// Verifies that several sequential tool calls on the same session all succeed.
    /// </summary>
    [Fact]
    public virtual async Task Client_CanMakeMultipleRequests_WithResumabilityEnabled()
    {
        // Arrange
        await using var app = await CreateServerAsync();
        await using var client = await ConnectClientAsync();

        // Act - Make many requests to verify stability
        for (int i = 0; i < 5; i++)
        {
            var result = await client.CallToolAsync("echo",
                new Dictionary<string, object?> { ["message"] = $"test{i}" },
                cancellationToken: TestContext.Current.CancellationToken);

            var textContent = Assert.Single(result.Content.OfType<TextContentBlock>());
            Assert.Equal($"Echo: test{i}", textContent.Text);
        }
    }

    /// <summary>
    /// Verifies that ping completes without error when an event store is configured.
    /// </summary>
    [Fact]
    public virtual async Task Ping_WorksWithResumabilityEnabled()
    {
        // Arrange
        await using var app = await CreateServerAsync();
        await using var client = await ConnectClientAsync();

        // Act & Assert - Ping should work
        await client.PingAsync(cancellationToken: TestContext.Current.CancellationToken);
    }

    /// <summary>
    /// Verifies that tools/list works; only the <c>echo</c> tool from
    /// <see cref="ResumabilityTestTools"/> is registered by default.
    /// </summary>
    [Fact]
    public virtual async Task ListTools_WorksWithResumabilityEnabled()
    {
        // Arrange
        await using var app = await CreateServerAsync();
        await using var client = await ConnectClientAsync();

        // Act
        var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);

        // Assert
        Assert.NotNull(tools);
        Assert.Single(tools);
    }

    /// <summary>
    /// Verifies that a client still receives progress notifications sent after the tool calls
    /// <c>EnablePollingAsync</c>, and that the final tool result is delivered.
    /// </summary>
    [Fact]
    public virtual async Task Client_CanPollResponse_FromServer()
    {
        const string ProgressToolName = "progress_tool";
        var timeout = TestConstants.DefaultTimeout;

        // RunContinuationsAsynchronously prevents the server-side tool from resuming inline on the
        // client's progress-callback thread when these sources are completed.
        var clientReceivedInitialValueTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
        var clientReceivedPolledValueTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

        var progressTool = McpServerTool.Create(async (RequestContext<CallToolRequestParams> context, IProgress<ProgressNotificationValue> progress) =>
        {
            progress.Report(new() { Progress = 0, Message = "Initial value" });
            await clientReceivedInitialValueTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken);

            await context.EnablePollingAsync(retryInterval: TimeSpan.FromSeconds(1));

            // This notification is sent after polling was enabled; the client must still observe it.
            progress.Report(new() { Progress = 50, Message = "Polled value" });
            await clientReceivedPolledValueTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken);

            return "Complete";
        }, options: new() { Name = ProgressToolName });

        await using var app = await CreateServerAsync(configureServer: builder =>
        {
            builder.WithTools([progressTool]);
        });

        await using var client = await ConnectClientAsync();

        var progressHandler = new Progress<ProgressNotificationValue>(value =>
        {
            switch (value.Message)
            {
                case "Initial value":
                    Assert.True(clientReceivedInitialValueTcs.TrySetResult(), "Received the initial value more than once.");
                    break;
                case "Polled value":
                    Assert.True(clientReceivedPolledValueTcs.TrySetResult(), "Received the polled value more than once.");
                    break;
                default:
                    throw new UnreachableException($"Unknown progress message '{value.Message}'");
            }
        });

        var result = await client.CallToolAsync(ProgressToolName, progress: progressHandler, cancellationToken: TestContext.Current.CancellationToken);

        Assert.False(result.IsError is true);
        Assert.Equal("Complete", result.Content.OfType<TextContentBlock>().Single().Text);
    }

    /// <summary>
    /// Verifies that when the POST response stream is faulted mid-request, a notification sent
    /// while disconnected is delivered exactly once after the client reconnects, and the final
    /// result still arrives.
    /// </summary>
    [Fact]
    public virtual async Task Client_CanResumePostResponseStream_AfterDisconnection()
    {
        var timeout = TestConstants.DefaultTimeout;

        using var faultingStreamHandler = new FaultingStreamHandler()
        {
            InnerHandler = SocketsHttpHandler,
        };

        HttpClient = new(faultingStreamHandler);
        ConfigureHttpClient(HttpClient);

        const string ProgressToolName = "progress_tool";
        const string InitialMessage = "Initial notification";
        const string ReplayedMessage = "Replayed notification";
        const string ResultMessage = "Complete";

        // RunContinuationsAsynchronously prevents the server-side tool from resuming inline on the
        // client's progress-callback thread when these sources are completed.
        var clientReceivedInitialValueTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
        var clientReceivedReconnectValueTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

        var progressTool = McpServerTool.Create(async (RequestContext<CallToolRequestParams> context, IProgress<ProgressNotificationValue> progress, CancellationToken cancellationToken) =>
        {
            progress.Report(new() { Progress = 0, Message = InitialMessage });

            // Make sure the client receives one message before we disconnect.
            await clientReceivedInitialValueTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken);

            // Simulate a network disconnection by faulting the response stream.
            var reconnectAttempt = await faultingStreamHandler.TriggerFaultAsync(TestContext.Current.CancellationToken);

            // Send another message that the client should receive after reconnecting.
            progress.Report(new() { Progress = 50, Message = ReplayedMessage });

            reconnectAttempt.Continue();

            // Wait for the client to receive the message via replay.
            await clientReceivedReconnectValueTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken);

            // Return the final result with the client still connected.
            return ResultMessage;
        }, options: new() { Name = ProgressToolName });

        await using var app = await CreateServerAsync(configureServer: builder =>
        {
            builder.WithTools([progressTool]);
        });

        await using var client = await ConnectClientAsync();

        // Progress<T> callbacks may run concurrently on the thread pool, so the counters are updated
        // atomically; otherwise a duplicated delivery could lose an increment and hide the bug.
        var initialNotificationReceivedCount = 0;
        var replayedNotificationReceivedCount = 0;
        var progressHandler = new Progress<ProgressNotificationValue>(value =>
        {
            switch (value.Message)
            {
                case InitialMessage:
                    Interlocked.Increment(ref initialNotificationReceivedCount);
                    clientReceivedInitialValueTcs.TrySetResult();
                    break;
                case ReplayedMessage:
                    Interlocked.Increment(ref replayedNotificationReceivedCount);
                    clientReceivedReconnectValueTcs.TrySetResult();
                    break;
                default:
                    throw new UnreachableException($"Unknown progress message '{value.Message}'");
            }
        });

        var result = await client.CallToolAsync(ProgressToolName, progress: progressHandler, cancellationToken: TestContext.Current.CancellationToken);

        Assert.False(result.IsError is true);
        Assert.Equal(1, Volatile.Read(ref initialNotificationReceivedCount));
        Assert.Equal(1, Volatile.Read(ref replayedNotificationReceivedCount));
        Assert.Equal(ResultMessage, result.Content.OfType<TextContentBlock>().Single().Text);
    }

    /// <summary>
    /// Verifies that when the unsolicited (GET) SSE stream is faulted, a notification sent while
    /// disconnected is replayed exactly once, and notifications sent after reconnection are
    /// delivered exactly once.
    /// </summary>
    [Fact]
    public virtual async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection()
    {
        var timeout = TestConstants.DefaultTimeout;

        using var faultingStreamHandler = new FaultingStreamHandler()
        {
            InnerHandler = SocketsHttpHandler,
        };

        HttpClient = new(faultingStreamHandler);
        ConfigureHttpClient(HttpClient);

        // Capture the server instance via RunSessionHandler
        var serverTcs = new TaskCompletionSource<McpServer>();

        await using var app = await CreateServerAsync(configureTransport: options =>
        {
#pragma warning disable MCPEXP002 // RunSessionHandler is experimental
            options.RunSessionHandler = (httpContext, mcpServer, cancellationToken) =>
            {
                serverTcs.TrySetResult(mcpServer);
                return mcpServer.RunAsync(cancellationToken);
            };
#pragma warning restore MCPEXP002
        });

        await using var client = await ConnectClientAsync();

        // Get the server instance
        var server = await serverTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken);

        // Set up notification tracking with unique messages
        var clientReceivedInitialNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
        var clientReceivedReplayedNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
        var clientReceivedReconnectNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

        const string CustomNotificationMethod = "test/custom_notification";
        const string InitialMessage = "Initial notification";
        const string ReplayedMessage = "Replayed notification";
        const string ReconnectMessage = "Reconnect notification";

        // Counters are written from the notification callback and read from the test thread,
        // so they are updated/read atomically.
        var initialNotificationReceivedCount = 0;
        var replayedNotificationReceivedCount = 0;
        var reconnectNotificationReceivedCount = 0;

        await using var _ = client.RegisterNotificationHandler(CustomNotificationMethod, (notification, cancellationToken) =>
        {
            var message = notification.Params?["message"]?.GetValue<string>();
            switch (message)
            {
                case InitialMessage:
                    Interlocked.Increment(ref initialNotificationReceivedCount);
                    clientReceivedInitialNotificationTcs.TrySetResult();
                    break;
                case ReplayedMessage:
                    Interlocked.Increment(ref replayedNotificationReceivedCount);
                    clientReceivedReplayedNotificationTcs.TrySetResult();
                    break;
                case ReconnectMessage:
                    Interlocked.Increment(ref reconnectNotificationReceivedCount);
                    clientReceivedReconnectNotificationTcs.TrySetResult();
                    break;
                default:
                    throw new UnreachableException($"Unknown notification message '{message}'");
            }
            return default;
        });

        // Wait for the client's unsolicited message stream to be established before sending notifications
        await faultingStreamHandler.WaitForUnsolicitedMessageStreamAsync(TestContext.Current.CancellationToken);

        // Send a custom notification to the client on the unsolicited message stream
        await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = InitialMessage }, cancellationToken: TestContext.Current.CancellationToken);

        // Wait for client to receive the first notification
        await clientReceivedInitialNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken);

        // Fault the unsolicited message stream (GET SSE)
        var reconnectAttempt = await faultingStreamHandler.TriggerFaultAsync(TestContext.Current.CancellationToken);

        // Send another notification while the client is disconnected - this should be stored
        await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = ReplayedMessage }, cancellationToken: TestContext.Current.CancellationToken);

        // Allow the client to reconnect
        reconnectAttempt.Continue();

        // Wait for client to receive the notification via replay
        await clientReceivedReplayedNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken);

        // Send a final notification while the client has reconnected - this should be handled by the transport
        await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = ReconnectMessage }, cancellationToken: TestContext.Current.CancellationToken);

        // Wait for the client to receive the final notification
        await clientReceivedReconnectNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken);

        // Assert each notification was received exactly once
        Assert.Equal(1, Volatile.Read(ref initialNotificationReceivedCount));
        Assert.Equal(1, Volatile.Read(ref replayedNotificationReceivedCount));
        Assert.Equal(1, Volatile.Read(ref reconnectNotificationReceivedCount));
    }

    /// <summary>
    /// Verifies the server's response to a <c>Last-Event-ID</c> header that does not belong to
    /// the session named in <c>Mcp-Session-Id</c>: 404 for an unknown session, and 400 with a
    /// specific JSON-RPC error message for an existing session with a foreign event ID.
    /// </summary>
    [Fact]
    public virtual async Task Server_Returns400_WhenLastEventIdRefersToWrongSession()
    {
        // Arrange - Create server with event store
        await using var app = await CreateServerAsync();

        // First, initialize a session and make a call to generate some events
        using var initRequest = new HttpRequestMessage(HttpMethod.Post, "/")
        {
            Headers =
            {
                Accept = { new("application/json"), new("text/event-stream") }
            },
            Content = new StringContent(InitializeRequest, Encoding.UTF8, "application/json"),
        };

        // Responses are read with ResponseHeadersRead, so each one is disposed to release its
        // underlying connection/stream.
        using var initResponse = await HttpClient.SendAsync(initRequest, HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken);
        initResponse.EnsureSuccessStatusCode();

        // Get the session ID from the response
        var sessionId = initResponse.Headers.GetValues("Mcp-Session-Id").First();

        // Read the SSE response to get an event ID (the last non-empty one wins)
        await using var initStream = await initResponse.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken);
        string? eventId = null;
        await foreach (var sseItem in SseParser.Create(initStream).EnumerateAsync(TestContext.Current.CancellationToken))
        {
            if (!string.IsNullOrEmpty(sseItem.EventId))
            {
                eventId = sseItem.EventId;
            }
        }

        Assert.NotNull(eventId);

        // Act - Try to resume with a different session ID but the same event ID
        var wrongSessionId = "wrong-session-id";
        using var resumeRequest = new HttpRequestMessage(HttpMethod.Get, "/")
        {
            Headers =
            {
                Accept = { new("text/event-stream") },
            }
        };
        resumeRequest.Headers.Add("Mcp-Session-Id", wrongSessionId);
        resumeRequest.Headers.Add("Last-Event-ID", eventId);

        using var resumeResponse = await HttpClient.SendAsync(resumeRequest, HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken);

        // Assert - First we get 404 because the wrong session doesn't exist
        Assert.Equal(HttpStatusCode.NotFound, resumeResponse.StatusCode);

        // Now test with an existing session but event ID from a different session
        // Create a second session
        using var initRequest2 = new HttpRequestMessage(HttpMethod.Post, "/")
        {
            Headers =
            {
                Accept = { new("application/json"), new("text/event-stream") }
            },
            Content = new StringContent(InitializeRequest, Encoding.UTF8, "application/json"),
        };

        using var initResponse2 = await HttpClient.SendAsync(initRequest2, HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken);
        initResponse2.EnsureSuccessStatusCode();
        var sessionId2 = initResponse2.Headers.GetValues("Mcp-Session-Id").First();
        Assert.NotEqual(sessionId, sessionId2);

        // Read the second session's response
        await using var initStream2 = await initResponse2.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken);
        await foreach (var _ in SseParser.Create(initStream2).EnumerateAsync(TestContext.Current.CancellationToken))
        {
            // Consume the stream
        }

        // Try to use session 2's ID but with an event ID from session 1
        using var mismatchRequest = new HttpRequestMessage(HttpMethod.Get, "/")
        {
            Headers =
            {
                Accept = { new("text/event-stream") },
            }
        };
        mismatchRequest.Headers.Add("Mcp-Session-Id", sessionId2);
        mismatchRequest.Headers.Add("Last-Event-ID", eventId); // This event ID belongs to session 1

        using var mismatchResponse = await HttpClient.SendAsync(mismatchRequest, HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken);

        // Assert - Should get 400 Bad Request because the event ID doesn't match the session
        Assert.Equal(HttpStatusCode.BadRequest, mismatchResponse.StatusCode);

        // Verify the error message
        var responseBody = await mismatchResponse.Content.ReadAsStringAsync(TestContext.Current.CancellationToken);
        var errorResponse = JsonNode.Parse(responseBody);
        Assert.NotNull(errorResponse);
        var errorMessage = errorResponse["error"]?["message"]?.GetValue<string>();
        Assert.Equal("Bad Request: The Last-Event-ID header refers to a session with a different session ID.", errorMessage);
    }

    /// <summary>
    /// Calls a tool that invokes <c>EnablePollingAsync</c> with an explicit retry interval.
    /// </summary>
    /// <remarks>
    /// NOTE(review): as written this only asserts that the tool call completes successfully; it
    /// does not inspect the raw SSE stream for the retry field implied by the test name.
    /// </remarks>
    [Fact]
    public virtual async Task EnablePollingAsync_SendsSseItemWithRetryField()
    {
        // Arrange
        const string PollingToolName = "polling_tool";
        var expectedRetryInterval = TimeSpan.FromSeconds(5);

        var pollingTool = McpServerTool.Create(async (RequestContext<CallToolRequestParams> context) =>
        {
            await context.EnablePollingAsync(retryInterval: expectedRetryInterval);
            return "Polling enabled";
        }, options: new() { Name = PollingToolName });

        await using var app = await CreateServerAsync(configureServer: builder =>
        {
            builder.WithTools([pollingTool]);
        });

        await using var client = await ConnectClientAsync();

        // Act - Call the tool that enables polling
        var result = await client.CallToolAsync(PollingToolName, cancellationToken: TestContext.Current.CancellationToken);

        // Assert - The result should be successful
        Assert.False(result.IsError is true);
        Assert.Equal("Polling enabled", result.Content.OfType<TextContentBlock>().Single().Text);
    }

    /// <summary>
    /// Tools registered on every server created by <see cref="CreateServerAsync"/>.
    /// </summary>
    [McpServerToolType]
    protected class ResumabilityTestTools
    {
        /// <summary>Returns <paramref name="message"/> prefixed with <c>"Echo: "</c>.</summary>
        [McpServerTool(Name = "echo"), Description("Echoes the message back")]
        public static string Echo(string message) => $"Echo: {message}";
    }

    /// <summary>
    /// Creates a server with the event stream store from <see cref="CreateEventStreamStoreAsync"/>.
    /// </summary>
    /// <param name="configureServer">Optional extra configuration for the MCP server builder (e.g. additional tools).</param>
    /// <param name="configureTransport">Optional extra configuration for the HTTP transport options.</param>
    /// <returns>The started web application.</returns>
    protected async Task<WebApplication> CreateServerAsync(
        Action<IMcpServerBuilder>? configureServer = null,
        Action<HttpServerTransportOptions>? configureTransport = null)
    {
        EventStreamStore = await CreateEventStreamStoreAsync();
        return await CreateServerAsync(EventStreamStore, configureServer, configureTransport);
    }

    /// <summary>
    /// Creates a server with the specified event stream store.
    /// </summary>
    /// <param name="eventStreamStore">The store assigned to <see cref="HttpServerTransportOptions.EventStreamStore"/>; may be <see langword="null"/>.</param>
    /// <param name="configureServer">Optional extra configuration for the MCP server builder, applied after the default tools are registered.</param>
    /// <param name="configureTransport">Optional extra configuration for the HTTP transport options, applied after the event store is set.</param>
    /// <returns>The started web application with MCP endpoints mapped.</returns>
    protected async Task<WebApplication> CreateServerAsync(
        ISseEventStreamStore? eventStreamStore,
        Action<IMcpServerBuilder>? configureServer = null,
        Action<HttpServerTransportOptions>? configureTransport = null)
    {
        var serverBuilder = Builder.Services.AddMcpServer()
            .WithHttpTransport(options =>
            {
                options.EventStreamStore = eventStreamStore;
                configureTransport?.Invoke(options);
            })
            .WithTools<ResumabilityTestTools>();

        configureServer?.Invoke(serverBuilder);

        var app = Builder.Build();
        app.MapMcp();
        await app.StartAsync(TestContext.Current.CancellationToken);
        return app;
    }

    /// <summary>
    /// Connects a client to the server using the Streamable HTTP transport over the test's
    /// current <c>HttpClient</c>.
    /// </summary>
    /// <returns>The connected client.</returns>
    protected async Task<McpClient> ConnectClientAsync()
    {
        var transport = new HttpClientTransport(new HttpClientTransportOptions
        {
            Endpoint = new Uri("http://localhost:5000/"),
            TransportMode = HttpTransportMode.StreamableHttp,
        }, HttpClient, LoggerFactory);

        return await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory,
            cancellationToken: TestContext.Current.CancellationToken);
    }

    /// <summary>
    /// Sends an initialize request and reads the SSE response to completion.
    /// </summary>
    /// <param name="initializeRequest">The raw JSON-RPC initialize request body.</param>
    /// <returns>The last non-empty SSE event ID observed, wrapped in an <see cref="SseResponse"/>.</returns>
    protected async Task<SseResponse> SendInitializeAndReadSseResponseAsync(string initializeRequest)
    {
        using var requestContent = new StringContent(initializeRequest, Encoding.UTF8, "application/json");
        using var request = new HttpRequestMessage(HttpMethod.Post, "/")
        {
            Headers =
            {
                Accept = { new("application/json"), new("text/event-stream") }
            },
            Content = requestContent,
        };

        // Disposed so the connection held open by ResponseHeadersRead is released.
        using var response = await HttpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead,
            TestContext.Current.CancellationToken);
        response.EnsureSuccessStatusCode();

        var sseResponse = new SseResponse();
        await using var stream = await response.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken);
        await foreach (var sseItem in SseParser.Create(stream).EnumerateAsync(TestContext.Current.CancellationToken))
        {
            if (!string.IsNullOrEmpty(sseItem.EventId))
            {
                sseResponse.LastEventId = sseItem.EventId;
            }
        }

        return sseResponse;
    }

    /// <summary>
    /// Response data from an SSE stream.
    /// </summary>
    protected struct SseResponse
    {
        /// <summary>The last non-empty event ID seen on the stream, or <see langword="null"/> if none.</summary>
        public string? LastEventId { get; set; }
    }
}