Skip to content

Commit 7f8b04d

Browse files
authored
Fix SSE cancellation issue (#9158)
1 parent 1ec64b2 commit 7f8b04d

2 files changed

Lines changed: 165 additions & 2 deletions

File tree

src/HotChocolate/Fusion/src/Core/Clients/DefaultHttpGraphQLSubscriptionClient.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,12 @@ private async IAsyncEnumerable<GraphQLResponse> SubscribeInternalAsync(
4343
var request = new GraphQLHttpRequest(subgraphRequest, _config.EndpointUri);
4444
using var response = await _client.SendAsync(request, cancellationToken).ConfigureAwait(false);
4545

46-
await foreach (var result in response.ReadAsResultStreamAsync(cancellationToken).ConfigureAwait(false))
46+
var resultStream = response.ReadAsResultStreamAsync(cancellationToken);
47+
await using var resultEnumerator = resultStream.GetAsyncEnumerator(cancellationToken);
48+
49+
while (await resultEnumerator.MoveNextAsync().ConfigureAwait(false))
4750
{
48-
yield return new GraphQLResponse(result);
51+
yield return new GraphQLResponse(resultEnumerator.Current);
4952
}
5053
}
5154

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
using System.Net;
2+
using HotChocolate.Fusion.Clients;
3+
using HotChocolate.Fusion.Metadata;
4+
5+
namespace HotChocolate.Fusion;
6+
7+
public class DefaultHttpGraphQLSubscriptionClientTests
8+
{
9+
[Fact]
10+
public async Task SubscribeAsync_Passes_CancellationToken_To_Sse_Enumeration()
11+
{
12+
var sseStream = new ObservingSseStream();
13+
var response = new HttpResponseMessage(HttpStatusCode.OK)
14+
{
15+
Content = new StreamContent(sseStream),
16+
};
17+
response.Content.Headers.ContentType = new("text/event-stream");
18+
19+
using var httpClient = new HttpClient(new StaticResponseHandler(response));
20+
21+
var config = new HttpClientConfiguration(
22+
clientName: "test",
23+
subgraphName: "reviews",
24+
endpointUri: new Uri("http://localhost/graphql"));
25+
26+
await using var client = new DefaultHttpGraphQLSubscriptionClient(config, httpClient);
27+
28+
var request = new SubgraphGraphQLRequest(
29+
subgraph: "reviews",
30+
document: "subscription OnNewReview { onNewReview { body } }",
31+
variableValues: null,
32+
extensions: null);
33+
34+
using var cts = new CancellationTokenSource();
35+
await using var stream = client.SubscribeAsync(request, cts.Token).GetAsyncEnumerator();
36+
37+
var moveNext = stream.MoveNextAsync().AsTask();
38+
await sseStream.ReadStarted.Task.WaitAsync(TimeSpan.FromSeconds(2));
39+
40+
cts.Cancel();
41+
42+
var linked = await WaitUntilAsync(
43+
() => sseStream.CapturedToken.IsCancellationRequested,
44+
TimeSpan.FromSeconds(1));
45+
46+
Assert.True(linked, "SSE enumeration token is not linked to the caller cancellation token.");
47+
48+
sseStream.Release();
49+
await Task.WhenAny(moveNext, Task.Delay(TimeSpan.FromSeconds(2)));
50+
}
51+
52+
private static async Task<bool> WaitUntilAsync(Func<bool> condition, TimeSpan timeout)
53+
{
54+
var end = DateTime.UtcNow + timeout;
55+
56+
while (DateTime.UtcNow < end)
57+
{
58+
if (condition())
59+
{
60+
return true;
61+
}
62+
63+
await Task.Delay(20);
64+
}
65+
66+
return condition();
67+
}
68+
69+
private sealed class StaticResponseHandler(HttpResponseMessage response) : HttpMessageHandler
70+
{
71+
protected override Task<HttpResponseMessage> SendAsync(
72+
HttpRequestMessage request,
73+
CancellationToken cancellationToken)
74+
=> Task.FromResult(response);
75+
}
76+
77+
private sealed class ObservingSseStream : Stream
78+
{
79+
private readonly CancellationTokenSource _release = new();
80+
81+
public TaskCompletionSource ReadStarted { get; } =
82+
new(TaskCreationOptions.RunContinuationsAsynchronously);
83+
84+
public CancellationToken CapturedToken { get; private set; }
85+
86+
public override bool CanRead => true;
87+
88+
public override bool CanSeek => false;
89+
90+
public override bool CanWrite => false;
91+
92+
public override long Length => throw new NotSupportedException();
93+
94+
public override long Position
95+
{
96+
get => throw new NotSupportedException();
97+
set => throw new NotSupportedException();
98+
}
99+
100+
public override void Flush()
101+
{
102+
}
103+
104+
public override int Read(byte[] buffer, int offset, int count)
105+
=> throw new NotSupportedException();
106+
107+
public override long Seek(long offset, SeekOrigin origin)
108+
=> throw new NotSupportedException();
109+
110+
public override void SetLength(long value)
111+
=> throw new NotSupportedException();
112+
113+
public override void Write(byte[] buffer, int offset, int count)
114+
=> throw new NotSupportedException();
115+
116+
public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
117+
=> BlockUntilCanceledOrReleasedAsync(cancellationToken);
118+
119+
public override Task<int> ReadAsync(
120+
byte[] buffer,
121+
int offset,
122+
int count,
123+
CancellationToken cancellationToken)
124+
=> ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
125+
126+
public void Release() => _release.Cancel();
127+
128+
protected override void Dispose(bool disposing)
129+
{
130+
if (disposing)
131+
{
132+
_release.Cancel();
133+
_release.Dispose();
134+
}
135+
136+
base.Dispose(disposing);
137+
}
138+
139+
private async ValueTask<int> BlockUntilCanceledOrReleasedAsync(CancellationToken cancellationToken)
140+
{
141+
CapturedToken = cancellationToken;
142+
ReadStarted.TrySetResult();
143+
144+
using var linked = CancellationTokenSource.CreateLinkedTokenSource(
145+
cancellationToken,
146+
_release.Token);
147+
148+
try
149+
{
150+
await Task.Delay(Timeout.InfiniteTimeSpan, linked.Token);
151+
}
152+
catch (OperationCanceledException)
153+
{
154+
// Cancellation is expected in this test.
155+
}
156+
157+
return 0;
158+
}
159+
}
160+
}

0 commit comments

Comments
 (0)