Skip to content

Commit 3ca72ea

Browse files
Thibault Wittembergtwittemb
authored andcommitted
codex: fix cancellation and sendability
1 parent c2b32ca commit 3ca72ea

6 files changed

Lines changed: 147 additions & 28 deletions

File tree

Sources/AsyncChannels/AsyncBufferedChannel.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import OrderedCollections
2929
/// sut.send(3)
3030
/// sut.finish()
3131
/// ```
32-
public final class AsyncBufferedChannel<Element>: AsyncSequence, Sendable {
32+
public final class AsyncBufferedChannel<Element: Sendable>: AsyncSequence, Sendable {
3333
public typealias Element = Element
3434
public typealias AsyncIterator = Iterator
3535

Sources/AsyncSubjects/AsyncPassthroughSubject.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
/// passthrough.send(2)
3131
/// passthrough.send(.finished)
3232
/// ```
33-
public final class AsyncPassthroughSubject<Element>: AsyncSubject {
33+
public final class AsyncPassthroughSubject<Element: Sendable>: AsyncSubject {
3434
public typealias Element = Element
3535
public typealias Failure = Never
3636
public typealias AsyncIterator = Iterator

Sources/Combiners/Merge/MergeStateMachine.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ struct MergeStateMachine<Element>: Sendable {
240240
}
241241
}
242242

243-
if case .termination = regulatedElement, case .element(.failure) = regulatedElement {
243+
if case .element(.failure) = regulatedElement {
244244
self.task.cancel()
245245
}
246246

Sources/Operators/AsyncMulticastSequence.swift

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -105,37 +105,43 @@ where Base.Element == Subject.Element, Subject.Failure == Error, Base.AsyncItera
105105
}
106106

107107
func next() async {
108-
await Task {
109-
let (canAccessBase, iterator) = self.state.withCriticalRegion { state -> (Bool, Base.AsyncIterator?) in
110-
switch state {
111-
case .available(let iterator):
112-
state = .busy
113-
return (true, iterator)
114-
case .busy:
115-
return (false, nil)
116-
}
117-
}
118-
119-
guard canAccessBase, var iterator = iterator else { return }
120-
121-
let toSend: Result<Element?, Error>
122-
do {
123-
let element = try await iterator.next()
124-
toSend = .success(element)
125-
} catch {
126-
toSend = .failure(error)
108+
guard !Task.isCancelled else { return }
109+
110+
let (canAccessBase, iterator) = self.state.withCriticalRegion { state -> (Bool, Base.AsyncIterator?) in
111+
switch state {
112+
case .available(let iterator):
113+
state = .busy
114+
return (true, iterator)
115+
case .busy:
116+
return (false, nil)
127117
}
118+
}
128119

120+
guard canAccessBase, var iterator = iterator else { return }
121+
defer {
129122
self.state.withCriticalRegion { state in
130123
state = .available(iterator)
131124
}
125+
}
132126

133-
switch toSend {
134-
case .success(.some(let element)): self.subject.send(element)
135-
case .success(.none): self.subject.send(.finished)
136-
case .failure(let error): self.subject.send(.failure(error))
137-
}
138-
}.value
127+
guard !Task.isCancelled else { return }
128+
129+
let toSend: Result<Element?, Error>
130+
do {
131+
let element = try await iterator.next()
132+
toSend = .success(element)
133+
} catch {
134+
guard !Task.isCancelled else { return }
135+
toSend = .failure(error)
136+
}
137+
138+
guard !Task.isCancelled else { return }
139+
140+
switch toSend {
141+
case .success(.some(let element)): self.subject.send(element)
142+
case .success(.none): self.subject.send(.finished)
143+
case .failure(let error): self.subject.send(.failure(error))
144+
}
139145
}
140146

141147
public func makeAsyncIterator() -> AsyncIterator {

Tests/Combiners/Merge/AsyncMergeSequenceTests.swift

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,34 @@ private struct TimedAsyncSequence<Element>: AsyncSequence, AsyncIteratorProtocol
4141
}
4242
}
4343

44+
private struct CancellationAwareSequence<Element>: AsyncSequence, AsyncIteratorProtocol {
45+
typealias Element = Element
46+
typealias AsyncIterator = CancellationAwareSequence
47+
48+
let onStart: @Sendable () -> Void
49+
let onCancel: @Sendable () -> Void
50+
var hasStarted = false
51+
52+
mutating func next() async throws -> Element? {
53+
if !hasStarted {
54+
hasStarted = true
55+
onStart()
56+
}
57+
58+
do {
59+
try await Task.sleep(nanoseconds: 5_000_000_000)
60+
return nil
61+
} catch {
62+
onCancel()
63+
return nil
64+
}
65+
}
66+
67+
func makeAsyncIterator() -> AsyncIterator {
68+
self
69+
}
70+
}
71+
4472
final class AsyncMergeSequenceTests: XCTestCase {
4573
func testMerge_merges_sequences_according_to_the_timeline_using_asyncSequences() async throws {
4674
// -- 0 ------------------------------- 1000 ----------------------------- 2000 -
@@ -306,4 +334,34 @@ final class AsyncMergeSequenceTests: XCTestCase {
306334

307335
task.cancel()
308336
}
337+
338+
func testMerge_cancels_other_bases_on_error() async {
339+
let baseStartedExpectation = expectation(description: "The blocking base has started")
340+
let baseCancelledExpectation = expectation(description: "The blocking base has been cancelled")
341+
342+
let blockingBase = CancellationAwareSequence<Int>(
343+
onStart: { baseStartedExpectation.fulfill() },
344+
onCancel: { baseCancelledExpectation.fulfill() }
345+
)
346+
let failingBase = TimedAsyncSequence(intervalInMills: [0, 0], sequence: [1, 2], indexOfError: 1)
347+
348+
let sut = merge(failingBase, blockingBase)
349+
var iterator = sut.makeAsyncIterator()
350+
351+
do {
352+
_ = try await iterator.next()
353+
} catch {
354+
XCTFail("The first element should not fail")
355+
}
356+
await fulfillment(of: [baseStartedExpectation], timeout: 1)
357+
358+
do {
359+
_ = try await iterator.next()
360+
XCTFail("The iteration should fail")
361+
} catch {
362+
XCTAssertEqual(error as? MockError, MockError(code: 1))
363+
}
364+
365+
await fulfillment(of: [baseCancelledExpectation], timeout: 1)
366+
}
309367
}

Tests/Operators/AsyncMulticastSequenceTests.swift

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,39 @@ private class SpyAsyncSequenceForNumberOfIterators<Element>: AsyncSequence {
3939
}
4040
}
4141

42+
private struct CancellationAwareSequence: AsyncSequence {
43+
typealias Element = Int
44+
typealias AsyncIterator = Iterator
45+
46+
let onStart: @Sendable () -> Void
47+
let onCancel: @Sendable () -> Void
48+
49+
func makeAsyncIterator() -> AsyncIterator {
50+
Iterator(onStart: self.onStart, onCancel: self.onCancel)
51+
}
52+
53+
struct Iterator: AsyncIteratorProtocol {
54+
let onStart: @Sendable () -> Void
55+
let onCancel: @Sendable () -> Void
56+
var hasStarted = false
57+
58+
mutating func next() async throws -> Int? {
59+
if !hasStarted {
60+
hasStarted = true
61+
onStart()
62+
}
63+
64+
do {
65+
try await Task.sleep(nanoseconds: 5_000_000_000)
66+
return nil
67+
} catch {
68+
onCancel()
69+
return nil
70+
}
71+
}
72+
}
73+
}
74+
4275
final class AsyncMulticastSequenceTests: XCTestCase {
4376
func test_multiple_loops_receive_elements_from_single_baseIterator() {
4477
let taskHaveIterators = expectation(description: "All tasks have their iterator")
@@ -156,4 +189,26 @@ final class AsyncMulticastSequenceTests: XCTestCase {
156189
XCTAssertEqual(error as? MockError, expectedError)
157190
}
158191
}
192+
193+
func test_multicast_cancels_upstream_when_consumer_cancels() async {
194+
let upstreamStartedExpectation = expectation(description: "Upstream started")
195+
let upstreamCancelledExpectation = expectation(description: "Upstream cancelled")
196+
197+
let base = CancellationAwareSequence(
198+
onStart: { upstreamStartedExpectation.fulfill() },
199+
onCancel: { upstreamCancelledExpectation.fulfill() }
200+
)
201+
let stream = AsyncThrowingPassthroughSubject<Int, Error>()
202+
let sut = base.multicast(stream).autoconnect()
203+
204+
let task = Task {
205+
var iterator = sut.makeAsyncIterator()
206+
_ = try? await iterator.next()
207+
}
208+
209+
await fulfillment(of: [upstreamStartedExpectation], timeout: 1)
210+
task.cancel()
211+
212+
await fulfillment(of: [upstreamCancelledExpectation], timeout: 1)
213+
}
159214
}

0 commit comments

Comments
 (0)