Skip to content

Commit b157630

Browse files
committed
Test flakiness fixes
1 parent a1e9e41 commit b157630

4 files changed

Lines changed: 44 additions & 13 deletions

File tree

src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,6 @@ await WriteJsonRpcErrorAsync(context,
154154
{
155155
await using var _ = await session.AcquireReferenceAsync(cancellationToken);
156156
InitializeSseResponse(context);
157-
158-
// We should flush headers to indicate a 200 success quickly, because the initialization response
159-
// will be sent in response to a different POST request. It might be a while before we send a message
160-
// over this response body.
161-
await context.Response.Body.FlushAsync(cancellationToken);
162157
await session.Transport.HandleGetRequestAsync(context.Response.Body, cancellationToken);
163158
}
164159
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)

src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationTo
125125
var primingItem = await _storeSseWriter.WriteEventAsync(SseItem.Prime<JsonRpcMessage>(), cancellationToken).ConfigureAwait(false);
126126
await _httpSseWriter.WriteAsync(primingItem, cancellationToken).ConfigureAwait(false);
127127
}
128+
129+
// We should flush to indicate a 200 success quickly, because the initialization response
130+
// will be sent in response to a different POST request. It might be a while before we send a message
131+
// over this response body.
132+
await sseResponseStream.FlushAsync(cancellationToken).ConfigureAwait(false);
128133
}
129134

130135
// Wait for the response to be written before returning from the handler.

tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ public async Task Client_CanResumePostResponseStream_AfterDisconnection()
279279
[Fact]
280280
public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection()
281281
{
282+
var timeout = TimeSpan.FromSeconds(10);
282283
using var faultingStreamHandler = new FaultingStreamHandler()
283284
{
284285
InnerHandler = SocketsHttpHandler,
@@ -304,12 +305,12 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection()
304305
await using var client = await ConnectClientAsync();
305306

306307
// Get the server instance
307-
var server = await serverTcs.Task.WaitAsync(TestContext.Current.CancellationToken);
308+
var server = await serverTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken);
308309

309310
// Set up notification tracking with unique messages
310-
var clientReceivedInitialNotificationTcs = new TaskCompletionSource();
311-
var clientReceivedReplayedNotificationTcs = new TaskCompletionSource();
312-
var clientReceivedReconnectNotificationTcs = new TaskCompletionSource();
311+
var clientReceivedInitialNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
312+
var clientReceivedReplayedNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
313+
var clientReceivedReconnectNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
313314

314315
const string CustomNotificationMethod = "test/custom_notification";
315316
const string InitialMessage = "Initial notification";
@@ -343,11 +344,14 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection()
343344
return default;
344345
});
345346

347+
// Wait for the client's unsolicited message stream to be established before sending notifications
348+
await faultingStreamHandler.WaitForUnsolicitedMessageStreamAsync(TestContext.Current.CancellationToken);
349+
346350
// Send a custom notification to the client on the unsolicited message stream
347351
await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = InitialMessage }, cancellationToken: TestContext.Current.CancellationToken);
348352

349353
// Wait for client to receive the first notification
350-
await clientReceivedInitialNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken);
354+
await clientReceivedInitialNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken);
351355

352356
// Fault the unsolicited message stream (GET SSE)
353357
var reconnectAttempt = await faultingStreamHandler.TriggerFaultAsync(TestContext.Current.CancellationToken);
@@ -359,13 +363,13 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection()
359363
reconnectAttempt.Continue();
360364

361365
// Wait for client to receive the notification via replay
362-
await clientReceivedReplayedNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken);
366+
await clientReceivedReplayedNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken);
363367

364368
// Send a final notification while the client has reconnected - this should be handled by the transport
365369
await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = ReconnectMessage }, cancellationToken: TestContext.Current.CancellationToken);
366370

367371
// Wait for the client to receive the final notification
368-
await clientReceivedReconnectNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken);
372+
await clientReceivedReconnectNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken);
369373

370374
// Assert each notification was received exactly once
371375
Assert.Equal(1, initialNotificationReceivedCount);
@@ -531,7 +535,7 @@ public async Task PostResponse_EndsAndSseEventStreamWriterIsDisposed_WhenWriteEv
531535
timeoutCts.CancelAfter(TimeSpan.FromSeconds(10));
532536

533537
// The call task should throw an OCE due to cancellation
534-
await Assert.ThrowsAsync<OperationCanceledException>(() => callTask).WaitAsync(timeoutCts.Token);
538+
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => callTask).WaitAsync(timeoutCts.Token);
535539

536540
// Wait for the writer to be disposed
537541
await blockingStore.DisposedTask.WaitAsync(timeoutCts.Token);

tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ internal sealed class FaultingStreamHandler : DelegatingHandler
1111
{
1212
private FaultingStream? _lastStream;
1313
private TaskCompletionSource? _reconnectTcs;
14+
private TaskCompletionSource _unsolicitedMessageStreamReadyTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
15+
16+
public Task WaitForUnsolicitedMessageStreamAsync(CancellationToken cancellationToken = default)
17+
=> _unsolicitedMessageStreamReadyTcs.Task.WaitAsync(cancellationToken);
18+
19+
internal void SignalUnsolicitedMessageStreamReady() => _unsolicitedMessageStreamReadyTcs.TrySetResult();
1420

1521
public async Task<ReconnectAttempt> TriggerFaultAsync(CancellationToken cancellationToken)
1622
{
@@ -24,6 +30,9 @@ public async Task<ReconnectAttempt> TriggerFaultAsync(CancellationToken cancella
2430
throw new InvalidOperationException("Cannot trigger a fault while already waiting for reconnection.");
2531
}
2632

33+
// Reset the TCS so we can wait for the reconnected unsolicited message stream
34+
_unsolicitedMessageStreamReadyTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
35+
2736
_reconnectTcs = new();
2837
await _lastStream.TriggerFaultAsync(cancellationToken);
2938

@@ -46,6 +55,7 @@ protected override async Task<HttpResponseMessage> SendAsync(
4655
_reconnectTcs = null;
4756
}
4857

58+
var isGetRequest = request.Method == HttpMethod.Get;
4959
var response = await base.SendAsync(request, cancellationToken);
5060

5161
// Only wrap SSE streams (text/event-stream)
@@ -63,6 +73,13 @@ protected override async Task<HttpResponseMessage> SendAsync(
6373
}
6474

6575
response.Content = newContent;
76+
77+
// For GET requests (unsolicited message stream), set up the stream to signal
78+
// when first data is read. This ensures the server's transport handler is ready.
79+
if (isGetRequest)
80+
{
81+
_lastStream.SetReadyCallback(SignalUnsolicitedMessageStreamReady);
82+
}
6683
}
6784

6885
return response;
@@ -89,10 +106,14 @@ private sealed class FaultingStream(Stream innerStream) : Stream
89106
{
90107
private readonly CancellationTokenSource _cts = new();
91108
private TaskCompletionSource? _faultTcs;
109+
private Action? _readyCallback;
110+
private bool _readySignaled;
92111
private bool _disposed;
93112

94113
public bool IsDisposed => _disposed;
95114

115+
public void SetReadyCallback(Action callback) => _readyCallback = callback;
116+
96117
public async Task TriggerFaultAsync(CancellationToken cancellationToken)
97118
{
98119
if (_faultTcs is not null)
@@ -131,6 +152,12 @@ public override async ValueTask<int> ReadAsync(Memory<byte> buffer, Cancellation
131152

132153
_cts.Token.ThrowIfCancellationRequested();
133154

155+
if (bytesRead > 0 && !_readySignaled)
156+
{
157+
_readySignaled = true;
158+
_readyCallback?.Invoke();
159+
}
160+
134161
return bytesRead;
135162
}
136163
catch (OperationCanceledException) when (_cts.IsCancellationRequested)

0 commit comments

Comments
 (0)