Skip to content

Commit acb4cc9

Browse files
authored
Release SSE response stream reference when GET request ends (#1519)
1 parent df0c102 commit acb4cc9

2 files changed

Lines changed: 141 additions & 29 deletions

File tree

src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public sealed partial class StreamableHttpServerTransport : ITransport
4444
private TaskCompletionSource<bool>? _httpResponseTcs;
4545
private string? _negotiatedProtocolVersion;
4646
private bool _getHttpRequestStarted;
47-
private bool _getHttpResponseCompleted;
47+
private bool _disposed;
4848

4949
/// <summary>
5050
/// Initializes a new instance of the <see cref="StreamableHttpServerTransport"/> class.
@@ -137,33 +137,53 @@ public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationTo
137137
throw new InvalidOperationException("GET requests are not supported in stateless mode.");
138138
}
139139

140-
using (await _unsolicitedMessageLock.LockAsync(cancellationToken).ConfigureAwait(false))
140+
try
141141
{
142-
if (_getHttpRequestStarted)
142+
using (await _unsolicitedMessageLock.LockAsync(cancellationToken).ConfigureAwait(false))
143143
{
144-
throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session.");
145-
}
144+
if (_getHttpRequestStarted)
145+
{
146+
throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session.");
147+
}
146148

147-
_getHttpRequestStarted = true;
148-
_httpSseWriter = new SseEventWriter(sseResponseStream);
149-
_httpResponseTcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
150-
_storeSseWriter = await TryCreateEventStreamAsync(streamId: UnsolicitedMessageStreamId, cancellationToken).ConfigureAwait(false);
151-
if (_storeSseWriter is not null)
152-
{
153-
var primingItem = await _storeSseWriter.WriteEventAsync(SseItem.Prime<JsonRpcMessage>(), cancellationToken).ConfigureAwait(false);
154-
await _httpSseWriter.WriteAsync(primingItem, cancellationToken).ConfigureAwait(false);
149+
_getHttpRequestStarted = true;
150+
_httpSseWriter = new SseEventWriter(sseResponseStream);
151+
_httpResponseTcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
152+
_storeSseWriter = await TryCreateEventStreamAsync(streamId: UnsolicitedMessageStreamId, cancellationToken).ConfigureAwait(false);
153+
if (_storeSseWriter is not null)
154+
{
155+
var primingItem = await _storeSseWriter.WriteEventAsync(SseItem.Prime<JsonRpcMessage>(), cancellationToken).ConfigureAwait(false);
156+
await _httpSseWriter.WriteAsync(primingItem, cancellationToken).ConfigureAwait(false);
157+
}
158+
else
159+
{
160+
// If there's no priming write, flush the stream to ensure HTTP response headers are
161+
// sent to the client now that the transport is ready to accept messages via SendMessageAsync.
162+
await sseResponseStream.FlushAsync(cancellationToken).ConfigureAwait(false);
163+
}
155164
}
156-
else
165+
166+
// Wait for the response to be written before returning from the handler.
167+
// This keeps the HTTP response open until the final response message is sent.
168+
await _httpResponseTcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false);
169+
}
170+
finally
171+
{
172+
// Release the SseEventWriter's reference to the response stream promptly when the GET
173+
// request ends, regardless of how it exits. Otherwise the response stream (and the
174+
// underlying Kestrel connection and associated memory pool buffers) remains pinned
175+
// in memory until the session itself is disposed (via explicit DELETE or idle timeout).
176+
// Clients that disconnect without sending DELETE — common with long-lived SSE — would
177+
// otherwise accumulate significant unmanaged memory per session during that interval.
178+
using (await _unsolicitedMessageLock.LockAsync(CancellationToken.None).ConfigureAwait(false))
157179
{
158-
// If there's no priming write, flush the stream to ensure HTTP response headers are
159-
// sent to the client now that the transport is ready to accept messages via SendMessageAsync.
160-
await sseResponseStream.FlushAsync(cancellationToken).ConfigureAwait(false);
180+
if (_httpSseWriter is { } writer)
181+
{
182+
_httpSseWriter = null;
183+
writer.Dispose();
184+
}
161185
}
162186
}
163-
164-
// Wait for the response to be written before returning from the handler.
165-
// This keeps the HTTP response open until the final response message is sent.
166-
await _httpResponseTcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false);
167187
}
168188

169189
/// <summary>
@@ -219,23 +239,22 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can
219239
return;
220240
}
221241

222-
Debug.Assert(_httpSseWriter is not null);
223242
Debug.Assert(_httpResponseTcs is not null);
224243

225244
var item = SseItem.Message(message);
226245

227246
if (_storeSseWriter is not null)
228247
{
248+
// Always record the message in the event store (if configured) — even when the GET
249+
// response stream is gone — so a reconnecting client can replay it via Last-Event-ID.
229250
item = await _storeSseWriter.WriteEventAsync(item, cancellationToken).ConfigureAwait(false);
230251
}
231252

232-
if (!_getHttpResponseCompleted)
253+
if (_httpSseWriter is { } writer)
233254
{
234-
// Only write the message to the response if the response has not completed.
235-
236255
try
237256
{
238-
await _httpSseWriter!.WriteAsync(item, cancellationToken).ConfigureAwait(false);
257+
await writer.WriteAsync(item, cancellationToken).ConfigureAwait(false);
239258
}
240259
catch (Exception ex) when (!cancellationToken.IsCancellationRequested)
241260
{
@@ -249,12 +268,12 @@ public async ValueTask DisposeAsync()
249268
{
250269
using var _ = await _unsolicitedMessageLock.LockAsync().ConfigureAwait(false);
251270

252-
if (_getHttpResponseCompleted)
271+
if (_disposed)
253272
{
254273
return;
255274
}
256275

257-
_getHttpResponseCompleted = true;
276+
_disposed = true;
258277

259278
try
260279
{
@@ -266,7 +285,11 @@ public async ValueTask DisposeAsync()
266285
try
267286
{
268287
_httpResponseTcs?.TrySetResult(true);
269-
_httpSseWriter?.Dispose();
288+
if (_httpSseWriter is { } writer)
289+
{
290+
_httpSseWriter = null;
291+
writer.Dispose();
292+
}
270293

271294
if (_storeSseWriter is not null)
272295
{
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
using ModelContextProtocol.Protocol;
2+
using ModelContextProtocol.Server;
3+
using ModelContextProtocol.Tests.Utils;
4+
5+
namespace ModelContextProtocol.Tests.Transport;
6+
7+
public class StreamableHttpServerTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper)
8+
{
9+
[Fact]
10+
public async Task SendMessageAsync_AfterGetRequestEnds_DoesNotWriteToResponseStream()
11+
{
12+
// Regression test for the SSE response stream being retained after the GET request
13+
// handler returns. Without releasing the stream reference, the Kestrel connection
14+
// and its associated memory pool buffers (~20MiB per SSE session) stay pinned in
15+
// unmanaged memory until the session is eventually disposed (via explicit DELETE or
16+
// idle timeout), causing steady memory growth for servers whose clients disconnect
17+
// without sending DELETE. After the GET handler returns, SendMessageAsync must not
18+
// attempt to write to the (now released) response stream.
19+
20+
await using var transport = new StreamableHttpServerTransport()
21+
{
22+
SessionId = "test-session",
23+
};
24+
25+
var responseStream = new RecordingStream();
26+
27+
using var cts = new CancellationTokenSource();
28+
var getTask = transport.HandleGetRequestAsync(responseStream, cts.Token);
29+
30+
// Wait until the GET handler has finished initialization (signaled by the initial
31+
// flush that sends HTTP response headers) so we know _httpSseWriter is set.
32+
await responseStream.FirstActivity.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken);
33+
34+
var writeCountBeforeCancel = responseStream.WriteCount;
35+
36+
cts.Cancel();
37+
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => getTask);
38+
39+
await transport.SendMessageAsync(
40+
new JsonRpcNotification { Method = "test" },
41+
TestContext.Current.CancellationToken);
42+
43+
Assert.Equal(writeCountBeforeCancel, responseStream.WriteCount);
44+
}
45+
46+
private sealed class RecordingStream : Stream
47+
{
48+
private readonly TaskCompletionSource<bool> _firstActivity = new(TaskCreationOptions.RunContinuationsAsynchronously);
49+
private int _writeCount;
50+
51+
public Task FirstActivity => _firstActivity.Task;
52+
public int WriteCount => Volatile.Read(ref _writeCount);
53+
54+
public override bool CanRead => false;
55+
public override bool CanSeek => false;
56+
public override bool CanWrite => true;
57+
public override long Length => throw new NotSupportedException();
58+
public override long Position
59+
{
60+
get => throw new NotSupportedException();
61+
set => throw new NotSupportedException();
62+
}
63+
64+
public override void Flush() => _firstActivity.TrySetResult(true);
65+
66+
public override Task FlushAsync(CancellationToken cancellationToken)
67+
{
68+
_firstActivity.TrySetResult(true);
69+
return Task.CompletedTask;
70+
}
71+
72+
public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
73+
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
74+
public override void SetLength(long value) => throw new NotSupportedException();
75+
76+
public override void Write(byte[] buffer, int offset, int count)
77+
{
78+
Interlocked.Increment(ref _writeCount);
79+
_firstActivity.TrySetResult(true);
80+
}
81+
82+
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
83+
{
84+
Interlocked.Increment(ref _writeCount);
85+
_firstActivity.TrySetResult(true);
86+
return Task.CompletedTask;
87+
}
88+
}
89+
}

0 commit comments

Comments
 (0)