Skip to content

Commit ebccf1d

Browse files
authored
Merge pull request #339 from fsprojects/repo-assist/fix-issue-167-withCancellation-128d379345c09895
[Repo Assist] adds TaskSeq.withCancellation
2 parents c6e7e37 + de3e834 commit ebccf1d

File tree

5 files changed

+199
-0
lines changed

5 files changed

+199
-0
lines changed

release-notes.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Release notes:
1010
- adds TaskSeq.unfold and TaskSeq.unfoldAsync, #289
1111
- adds TaskSeq.chunkBySize (closes #258) and TaskSeq.windowed, #289
1212
- fixes: CancellationToken passed to GetAsyncEnumerator is now honored in MoveNextAsync, #179
13+
- adds TaskSeq.withCancellation, #167
1314

1415
0.6.0
1516
- fixes: async { for item in taskSeq do ... } no longer wraps exceptions in AggregateException, #129

src/FSharp.Control.TaskSeq.Test/FSharp.Control.TaskSeq.Test.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
<Compile Include="TaskSeq.Let.Tests.fs" />
7272
<Compile Include="TaskSeq.Using.Tests.fs" />
7373
<Compile Include="TaskSeq.CancellationToken.Tests.fs" />
74+
<Compile Include="TaskSeq.WithCancellation.Tests.fs" />
7475
</ItemGroup>
7576

7677
<ItemGroup>
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
module TaskSeq.Tests.``WithCancellation``
2+
3+
open System
4+
open System.Collections.Generic
5+
open System.Threading
6+
open System.Threading.Tasks
7+
8+
open Xunit
9+
open FsUnit.Xunit
10+
11+
open FSharp.Control
12+
13+
/// A simple IAsyncEnumerable whose GetAsyncEnumerator records the token it was called with.
14+
type TokenCapturingSeq<'T>(items: 'T list) =
15+
let mutable capturedToken = CancellationToken.None
16+
17+
member _.CapturedToken = capturedToken
18+
19+
interface IAsyncEnumerable<'T> with
20+
member _.GetAsyncEnumerator(ct) =
21+
capturedToken <- ct
22+
23+
let source = taskSeq {
24+
for x in items do
25+
yield x
26+
}
27+
28+
source.GetAsyncEnumerator(ct)
29+
30+
module ``Null check`` =
31+
32+
[<Fact>]
33+
let ``TaskSeq-withCancellation: null source throws ArgumentNullException`` () =
34+
assertNullArg
35+
<| fun () -> TaskSeq.withCancellation CancellationToken.None null
36+
37+
module ``Token threading`` =
38+
39+
[<Fact>]
40+
let ``TaskSeq-withCancellation: passes supplied token to GetAsyncEnumerator`` () = task {
41+
let source = TokenCapturingSeq([ 1; 2; 3 ])
42+
use cts = new CancellationTokenSource()
43+
44+
let wrapped = TaskSeq.withCancellation cts.Token (source :> IAsyncEnumerable<_>)
45+
let! _ = TaskSeq.toArrayAsync wrapped
46+
source.CapturedToken |> should equal cts.Token
47+
}
48+
49+
[<Fact>]
50+
let ``TaskSeq-withCancellation: overrides any token passed to GetAsyncEnumerator`` () = task {
51+
let source = TokenCapturingSeq([ 1; 2; 3 ])
52+
use cts = new CancellationTokenSource()
53+
54+
let wrapped = TaskSeq.withCancellation cts.Token (source :> IAsyncEnumerable<_>)
55+
56+
// Consume with a different token; withCancellation should win
57+
use outerCts = new CancellationTokenSource()
58+
let enum = wrapped.GetAsyncEnumerator(outerCts.Token)
59+
60+
while! enum.MoveNextAsync() do
61+
()
62+
63+
source.CapturedToken |> should equal cts.Token
64+
}
65+
66+
[<Fact>]
67+
let ``TaskSeq-withCancellation: CancellationToken.None passes through correctly`` () = task {
68+
let source = TokenCapturingSeq([ 10; 20 ])
69+
70+
let wrapped = TaskSeq.withCancellation CancellationToken.None (source :> IAsyncEnumerable<_>)
71+
let! _ = TaskSeq.toArrayAsync wrapped
72+
source.CapturedToken |> should equal CancellationToken.None
73+
}
74+
75+
module ``Cancellation behaviour`` =
76+
77+
[<Fact>]
78+
let ``TaskSeq-withCancellation: pre-cancelled token causes OperationCanceledException on iteration`` () = task {
79+
use cts = new CancellationTokenSource()
80+
cts.Cancel()
81+
82+
let source = taskSeq {
83+
while true do
84+
yield 1
85+
}
86+
87+
let wrapped = TaskSeq.withCancellation cts.Token source
88+
89+
fun () -> TaskSeq.iter ignore wrapped |> Task.ignore
90+
|> should throwAsync typeof<OperationCanceledException>
91+
}
92+
93+
[<Fact>]
94+
let ``TaskSeq-withCancellation: token cancelled mid-iteration raises OperationCanceledException`` () = task {
95+
use cts = new CancellationTokenSource()
96+
97+
let source = taskSeq {
98+
for i in 1..100 do
99+
yield i
100+
}
101+
102+
let wrapped = TaskSeq.withCancellation cts.Token source
103+
104+
fun () ->
105+
task {
106+
let mutable count = 0
107+
use enum = wrapped.GetAsyncEnumerator(CancellationToken.None)
108+
109+
while! enum.MoveNextAsync() do
110+
count <- count + 1
111+
112+
if count = 3 then
113+
cts.Cancel()
114+
}
115+
|> Task.ignore
116+
|> should throwAsync typeof<OperationCanceledException>
117+
}
118+
119+
module ``Sequence contents`` =
120+
121+
[<Fact>]
122+
let ``TaskSeq-withCancellation: empty source produces empty sequence`` () =
123+
TaskSeq.empty<int>
124+
|> TaskSeq.withCancellation CancellationToken.None
125+
|> verifyEmpty
126+
127+
[<Fact>]
128+
let ``TaskSeq-withCancellation: finite source produces all items`` () = task {
129+
let! result =
130+
taskSeq {
131+
for i in 1..10 do
132+
yield i
133+
}
134+
|> TaskSeq.withCancellation CancellationToken.None
135+
|> TaskSeq.toArrayAsync
136+
137+
result |> should equal [| 1..10 |]
138+
}
139+
140+
[<Fact>]
141+
let ``TaskSeq-withCancellation: can be used with TaskSeq combinators`` () = task {
142+
use cts = new CancellationTokenSource()
143+
144+
let! result =
145+
taskSeq {
146+
for i in 1..5 do
147+
yield i
148+
}
149+
|> TaskSeq.withCancellation cts.Token
150+
|> TaskSeq.map (fun x -> x * 2)
151+
|> TaskSeq.toArrayAsync
152+
153+
result |> should equal [| 2; 4; 6; 8; 10 |]
154+
}
155+
156+
[<Fact>]
157+
let ``TaskSeq-withCancellation: can be piped like .WithCancellation usage pattern`` () = task {
158+
use cts = new CancellationTokenSource()
159+
let mutable collected = ResizeArray()
160+
161+
let source = taskSeq {
162+
for i in 1..5 do
163+
yield i
164+
}
165+
166+
do!
167+
source
168+
|> TaskSeq.withCancellation cts.Token
169+
|> TaskSeq.iterAsync (fun x -> task { collected.Add(x) })
170+
171+
collected |> Seq.toArray |> should equal [| 1..5 |]
172+
}

src/FSharp.Control.TaskSeq/TaskSeq.fs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,13 @@ type TaskSeq private () =
258258
yield c
259259
}
260260

261+
static member withCancellation (cancellationToken: CancellationToken) (source: TaskSeq<'T>) =
262+
Internal.checkNonNull (nameof source) source
263+
264+
{ new IAsyncEnumerable<'T> with
265+
member _.GetAsyncEnumerator(_ct) = source.GetAsyncEnumerator(cancellationToken)
266+
}
267+
261268
//
262269
// Utility functions
263270
//

src/FSharp.Control.TaskSeq/TaskSeq.fsi

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
namespace FSharp.Control
22

33
open System.Collections.Generic
4+
open System.Threading
45
open System.Threading.Tasks
56

67
[<AutoOpen>]
@@ -602,6 +603,23 @@ type TaskSeq =
602603
/// <exception cref="T:ArgumentNullException">Thrown when the input sequence is null.</exception>
603604
static member ofAsyncArray: source: Async<'T> array -> TaskSeq<'T>
604605

606+
/// <summary>
607+
/// Returns a task sequence that, when iterated, passes the given <paramref name="cancellationToken" /> to the
608+
/// underlying <see cref="IAsyncEnumerable&lt;'T&gt;" />. This is the equivalent of calling
609+
/// <c>.WithCancellation(cancellationToken)</c> on an <see cref="IAsyncEnumerable&lt;'T&gt;" />.
610+
/// </summary>
611+
/// <remarks>
612+
/// The <paramref name="cancellationToken" /> supplied to this function overrides any token that would otherwise
613+
/// be passed to the enumerator. This is useful when consuming sequences from libraries such as Entity Framework,
614+
/// which accept a <see cref="CancellationToken" /> through <c>GetAsyncEnumerator</c>.
615+
/// </remarks>
616+
///
617+
/// <param name="cancellationToken">The cancellation token to pass to <c>GetAsyncEnumerator</c>.</param>
618+
/// <param name="source">The input task sequence.</param>
619+
/// <returns>A task sequence that uses the given <paramref name="cancellationToken" /> when iterated.</returns>
620+
/// <exception cref="T:ArgumentNullException">Thrown when the input task sequence is null.</exception>
621+
static member withCancellation: cancellationToken: CancellationToken -> source: TaskSeq<'T> -> TaskSeq<'T>
622+
605623
/// <summary>
606624
/// Views each item in the input task sequence as <see cref="obj" />, boxing value types.
607625
/// </summary>

0 commit comments

Comments
 (0)