Skip to content

Commit d5e0b7b

Browse files
committed
- implement internal cancellation for SCAN via WithCancellation
1 parent 8a52783 commit d5e0b7b

4 files changed

Lines changed: 94 additions & 21 deletions

File tree

docs/AsyncTimeouts.md

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,18 @@ using var cts = CancellationTokenSource.CreateLinkedTokenSource(token); // or mu
6262
cts.CancelAfter(timeout);
6363
await database.StringSetAsync("key", "value").WaitAsync(cts.Token);
6464
var value = await database.StringGetAsync("key").WaitAsync(cts.Token);
65-
``````
65+
```
66+
67+
### Cancelling keys enumeration
68+
69+
Keys being enumerated (via `SCAN`) can *also* be cancelled, using the inbuilt `.WithCancellation(...)` method:
70+
71+
```csharp
72+
CancellationToken token = ...; // for example, from HttpContext.RequestAborted
73+
await foreach (var key in server.KeysAsync(pattern: "*foo*").WithCancellation(token))
74+
{
75+
...
76+
}
77+
```
78+
79+
To use a timeout instead, you can use the `CancellationTokenSource` approach shown above.

src/StackExchange.Redis/CursorEnumerable.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ private bool SimpleNext()
141141
{
142142
if (_pageOffset + 1 < _pageCount)
143143
{
144+
cancellationToken.ThrowIfCancellationRequested();
144145
_pageOffset++;
145146
return true;
146147
}
@@ -274,7 +275,7 @@ private async ValueTask<bool> AwaitedNextAsync(bool isInitial)
274275
ScanResult scanResult;
275276
try
276277
{
277-
scanResult = await pending.ForAwait();
278+
scanResult = await pending.WaitAsync(cancellationToken).ForAwait();
278279
}
279280
catch (Exception ex)
280281
{

src/StackExchange.Redis/TaskExtensions.cs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,44 @@ internal static Task<T> ObserveErrors<T>(this Task<T> task)
2525
return task;
2626
}
2727

28+
#if !NET6_0_OR_GREATER
29+
// suboptimal polyfill version of the .NET 6+ API, but reasonable for light use
30+
internal static Task<T> WaitAsync<T>(this Task<T> task, CancellationToken cancellationToken)
31+
{
32+
if (task.IsCompleted || !cancellationToken.CanBeCanceled) return task;
33+
return Wrap(task, cancellationToken);
34+
35+
static async Task<T> Wrap(Task<T> task, CancellationToken cancellationToken)
36+
{
37+
var tcs = new TaskSourceWithToken<T>(cancellationToken);
38+
using var reg = cancellationToken.Register(
39+
static state => ((TaskSourceWithToken<T>)state!).Cancel(), tcs);
40+
_ = task.ContinueWith(
41+
static (t, state) =>
42+
{
43+
var tcs = (TaskSourceWithToken<T>)state!;
44+
if (t.IsCanceled) tcs.TrySetCanceled();
45+
else if (t.IsFaulted) tcs.TrySetException(t.Exception!);
46+
else tcs.TrySetResult(t.Result);
47+
},
48+
tcs);
49+
return await tcs.Task;
50+
}
51+
}
52+
53+
// the point of this type is to combine TCS and CT so that we can use a static
54+
// registration via Register
55+
private sealed class TaskSourceWithToken<T> : TaskCompletionSource<T>
56+
{
57+
public TaskSourceWithToken(CancellationToken cancellationToken)
58+
=> _cancellationToken = cancellationToken;
59+
60+
private readonly CancellationToken _cancellationToken;
61+
62+
public void Cancel() => TrySetCanceled(_cancellationToken);
63+
}
64+
#endif
65+
2866
[MethodImpl(MethodImplOptions.AggressiveInlining)]
2967
internal static ConfiguredTaskAwaitable ForAwait(this Task task) => task.ConfigureAwait(false);
3068
[MethodImpl(MethodImplOptions.AggressiveInlining)]

tests/StackExchange.Redis.Tests/CancellationTests.cs

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,6 @@ internal static class TaskExtensions
1212
{
1313
// suboptimal polyfill version of the .NET 6+ API; I'm not recommending this for production use,
1414
// but it's good enough for tests
15-
public static Task<T> WaitAsync<T>(this Task<T> task, CancellationToken cancellationToken)
16-
{
17-
if (task.IsCompleted || !cancellationToken.CanBeCanceled) return task;
18-
return Wrap(task, cancellationToken);
19-
20-
static async Task<T> Wrap(Task<T> task, CancellationToken cancellationToken)
21-
{
22-
var tcs = new TaskCompletionSource<T>();
23-
using var reg = cancellationToken.Register(() => tcs.TrySetCanceled(cancellationToken));
24-
_ = task.ContinueWith(t =>
25-
{
26-
if (t.IsCanceled) tcs.TrySetCanceled();
27-
else if (t.IsFaulted) tcs.TrySetException(t.Exception!);
28-
else tcs.TrySetResult(t.Result);
29-
});
30-
return await tcs.Task;
31-
}
32-
}
33-
3415
public static Task<T> WaitAsync<T>(this Task<T> task, TimeSpan timeout)
3516
{
3617
if (task.IsCompleted) return task;
@@ -92,6 +73,11 @@ private void Pause(IDatabase db)
9273
db.Execute("client", new object[] { "pause", ConnectionPauseMilliseconds }, CommandFlags.FireAndForget);
9374
}
9475

76+
private void Pause(IServer server)
77+
{
78+
server.Execute("client", new object[] { "pause", ConnectionPauseMilliseconds }, CommandFlags.FireAndForget);
79+
}
80+
9581
[Fact]
9682
public async Task WithTimeout_ShortTimeout_Async_ThrowsOperationCanceledException()
9783
{
@@ -195,4 +181,38 @@ public async Task CancellationDuringOperation_Async_CancelsGracefully(CancelStra
195181
Assert.Equal(cts.Token, oce.CancellationToken);
196182
}
197183
}
184+
185+
[Fact]
186+
public async Task ScanCancellable()
187+
{
188+
using var conn = Create();
189+
var db = conn.GetDatabase();
190+
var server = conn.GetServer(conn.GetEndPoints()[0]);
191+
192+
using var cts = new CancellationTokenSource();
193+
194+
var watch = Stopwatch.StartNew();
195+
Pause(server);
196+
try
197+
{
198+
db.StringSet(Me(), "value", TimeSpan.FromMinutes(5), flags: CommandFlags.FireAndForget);
199+
await using var iter = server.KeysAsync(pageSize: 1000).WithCancellation(cts.Token).GetAsyncEnumerator();
200+
var pending = iter.MoveNextAsync();
201+
Assert.False(cts.Token.IsCancellationRequested);
202+
cts.CancelAfter(ShortDelayMilliseconds); // start this *after* we've got past the initial check
203+
while (await pending)
204+
{
205+
pending = iter.MoveNextAsync();
206+
}
207+
Assert.Fail($"{ExpectedCancel}: {watch.ElapsedMilliseconds}ms");
208+
}
209+
catch (OperationCanceledException oce)
210+
{
211+
var taken = watch.ElapsedMilliseconds;
212+
// Expected if cancellation happens during operation
213+
Log($"Cancelled after {taken}ms");
214+
Assert.True(taken < ConnectionPauseMilliseconds / 2, "Should have cancelled much sooner");
215+
Assert.Equal(cts.Token, oce.CancellationToken);
216+
}
217+
}
198218
}

0 commit comments

Comments
 (0)