Skip to content

Commit aa17b19

Browse files
committed
Fix open and close bugs
1 parent 3b56a89 commit aa17b19

4 files changed

Lines changed: 203 additions & 27 deletions

File tree

Sources/WebSocket/SystemURLSession.swift

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -53,21 +53,30 @@ private func configuration(with options: WebSocketOptions) -> URLSessionConfigur
5353
return config
5454
}
5555

56-
private final class Delegate: NSObject, URLSessionWebSocketDelegate, Sendable {
56+
final class Delegate: NSObject, URLSessionWebSocketDelegate, Sendable {
5757
private struct Callbacks: Sendable {
5858
let onOpen: @Sendable () async -> Void
5959
let onClose: @Sendable (WebSocketCloseCode, Data?) async -> Void
6060
}
6161

62-
// `Dictionary<ObjectIdentifier(URLWebSocketTask): Callbacks>`
63-
private let state: Locked<[ObjectIdentifier: Callbacks]> = .init([:])
62+
private struct State: Sendable {
63+
var callbacks: [ObjectIdentifier: Callbacks] = [:]
64+
var callbackTasks: [ObjectIdentifier: Task<Void, Never>] = [:]
65+
}
66+
67+
private let state = Locked(State())
6468

6569
func set(
6670
onOpen: @escaping @Sendable () async -> Void,
6771
onClose: @escaping @Sendable (WebSocketCloseCode, Data?) async -> Void,
6872
for taskID: ObjectIdentifier
6973
) {
70-
state.access { $0[taskID] = .init(onOpen: onOpen, onClose: onClose) }
74+
state.access { state in
75+
state.callbacks[taskID] = .init(
76+
onOpen: onOpen,
77+
onClose: onClose
78+
)
79+
}
7180
}
7281

7382
func urlSession(
@@ -76,9 +85,8 @@ private final class Delegate: NSObject, URLSessionWebSocketDelegate, Sendable {
7685
didOpenWithProtocol _: String?
7786
) {
7887
let taskID = ObjectIdentifier(webSocketTask)
79-
80-
if let onOpen = state.access({ $0[taskID]?.onOpen }) {
81-
Task { await onOpen() }
88+
enqueue(for: taskID) { callbacks in
89+
await callbacks.onOpen()
8290
}
8391
}
8492

@@ -89,9 +97,8 @@ private final class Delegate: NSObject, URLSessionWebSocketDelegate, Sendable {
8997
reason: Data?
9098
) {
9199
let taskID = ObjectIdentifier(webSocketTask)
92-
93-
if let onClose = state.access({ $0[taskID]?.onClose }) {
94-
Task { await onClose(WebSocketCloseCode(closeCode), reason) }
100+
enqueue(for: taskID) { callbacks in
101+
await callbacks.onClose(WebSocketCloseCode(closeCode), reason)
95102
}
96103
}
97104

@@ -101,20 +108,36 @@ private final class Delegate: NSObject, URLSessionWebSocketDelegate, Sendable {
101108
didCompleteWithError error: Error?
102109
) {
103110
let taskID = ObjectIdentifier(task)
111+
let closeCode: WebSocketCloseCode = error == nil ? .normalClosure : .abnormalClosure
112+
let reason = error.map { Data($0.localizedDescription.utf8) }
113+
enqueue(for: taskID, removeAfterwards: true) { callbacks in
114+
await callbacks.onClose(closeCode, reason)
115+
}
116+
}
104117

105-
if let onClose = state.access({ $0[taskID]?.onClose }) {
106-
Task { [weak self] in
107-
if let error {
108-
await onClose(
109-
.abnormalClosure,
110-
Data(error.localizedDescription.utf8)
111-
)
112-
} else {
113-
await onClose(.normalClosure, nil)
114-
}
118+
private func enqueue(
119+
for taskID: ObjectIdentifier,
120+
removeAfterwards: Bool = false,
121+
_ operation: @escaping @Sendable (Callbacks) async -> Void
122+
) {
123+
state.access { state in
124+
guard let callbacks = state.callbacks[taskID] else {
125+
return
126+
}
115127

116-
self?.state.access { _ = $0.removeValue(forKey: taskID) }
128+
let previousTask = state.callbackTasks[taskID]
129+
let task = Task { [weak self] in
130+
_ = await previousTask?.result
131+
await operation(callbacks)
132+
133+
guard removeAfterwards else { return }
134+
self?.state.access { state in
135+
_ = state.callbacks.removeValue(forKey: taskID)
136+
_ = state.callbackTasks.removeValue(forKey: taskID)
137+
}
117138
}
139+
140+
state.callbackTasks[taskID] = task
118141
}
119142
}
120143
}

Sources/WebSocket/SystemWebSocket.swift

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ final actor SystemWebSocket: Publisher {
8686
do {
8787
try await didOpen.value
8888
} catch is CancellationError {
89-
doClose(closeCode: .cancelled, reason: Data("cancelled".utf8))
89+
throw CancellationError()
9090
} catch is TimeoutError {
9191
doClose(closeCode: .timeout, reason: Data("timeout".utf8))
9292
throw TimeoutError()
@@ -250,9 +250,12 @@ private extension SystemWebSocket {
250250
}
251251

252252
func doClose(closeCode: WebSocketCloseCode, reason: Data?) {
253+
let close = WebSocketClose(closeCode, reason)
254+
didOpen.fail(WebSocketError(closeCode, reason))
255+
253256
switch state {
254257
case .unopened:
255-
state = .closed(.init(closeCode, reason))
258+
state = .closed(close)
256259

257260
case let .connecting(ws), let .open(ws):
258261
os_log(
@@ -271,7 +274,6 @@ private extension SystemWebSocket {
271274
}
272275
}
273276

274-
let close = WebSocketClose(closeCode, nil)
275277
state = .closed(close)
276278
onClose(close)
277279
didClose?.resolve((code: closeCode, reason: reason))

Tests/WebSocketTests/Server/WebSocketServer.swift

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import Combine
22
import Foundation
33
import NIO
4+
import NIOWebSocket
45
import WebSocket
56
import WebSocketKit
67

7-
enum WebSocketServerOutput: Hashable {
8+
enum WebSocketServerOutput {
89
case message(WebSocketMessage)
910
case remoteClose
11+
case remoteCloseWithReason(WebSocketErrorCode, Data)
1012
}
1113

1214
final class WebSocketServer {
@@ -71,6 +73,15 @@ final class WebSocketServer {
7173
do { try ws.close(code: .goingAway).wait() }
7274
catch {}
7375

76+
case let .remoteCloseWithReason(code, reason):
77+
var buffer = ByteBufferAllocator().buffer(capacity: 2 + reason.count)
78+
buffer.write(webSocketErrorCode: code)
79+
buffer.writeBytes(reason)
80+
ws.send(
81+
raw: buffer.readableBytesView,
82+
opcode: .connectionClose
83+
)
84+
7485
case let .message(message):
7586
switch message {
7687
case let .data(data):

Tests/WebSocketTests/SystemWebSocketTests.swift

Lines changed: 142 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import AsyncExtensions
12
import Combine
3+
import NIO
4+
import NIOWebSocket
25
import Synchronized
36
@testable import WebSocket
47
import XCTest
@@ -40,7 +43,7 @@ class SystemWebSocketTests: XCTestCase {
4043
onOpen: { XCTFail("Should not have opened") },
4144
onClose: { close in
4245
XCTAssertEqual(.abnormalClosure, close.code)
43-
XCTAssertNil(close.reason)
46+
XCTAssertNotNil(close.reason)
4447
ex.fulfill()
4548
}
4649
)
@@ -52,6 +55,52 @@ class SystemWebSocketTests: XCTestCase {
5255
XCTAssertTrue(isClosed)
5356
}
5457

58+
func testOpenCancellationThrowsCancellationError() async throws {
59+
let server = try HangingServer()
60+
defer { server.shutDown() }
61+
62+
let client = try await SystemWebSocket(
63+
request: request(server.port),
64+
options: .init(timeoutIntervalForRequest: 5)
65+
)
66+
67+
let openTask = Task {
68+
try await client.open()
69+
}
70+
71+
try await Task.sleep(nanoseconds: 50 * NSEC_PER_MSEC)
72+
openTask.cancel()
73+
74+
switch await openTask.result {
75+
case .success:
76+
XCTFail("Expected `open()` to throw `CancellationError`")
77+
78+
case let .failure(error):
79+
XCTAssertTrue(
80+
error is CancellationError,
81+
"Received wrong error: \(String(reflecting: error))"
82+
)
83+
}
84+
}
85+
86+
func testOpenThrowsConnectionErrorWhenServerIsUnreachable() async throws {
87+
let (server, client) = try await makeOfflineServerAndClient(
88+
timeoutIntervalForRequest: 0.2
89+
)
90+
defer { server.shutDown() }
91+
92+
do {
93+
try await client.open()
94+
XCTFail("Should not have opened")
95+
} catch is TimeoutError {
96+
XCTFail("Should surface the connection failure instead of timing out")
97+
} catch let error as WebSocketError {
98+
XCTAssertEqual(.abnormalClosure, error.closeCode)
99+
} catch {
100+
XCTFail("Received wrong error: \(error)")
101+
}
102+
}
103+
55104
func _testErrorWhenRemoteCloses() async throws {
56105
let errorEx = expectation(description: "Should have closed")
57106
let (server, client) = try await makeServerAndClient(
@@ -114,6 +163,53 @@ class SystemWebSocketTests: XCTestCase {
114163
await fulfillment(of: [secondCloseEx], timeout: 0.1)
115164
}
116165

166+
func testDelegateDoesNotReorderOpenAndCloseCallbacks() async throws {
167+
let delegate = Delegate()
168+
let session = URLSession(configuration: .ephemeral)
169+
defer { session.invalidateAndCancel() }
170+
171+
let task = session.webSocketTask(with: URL(string: "ws://127.0.0.1/socket")!)
172+
let openStarted = AsyncThrowingFuture<Void>(timeout: 2)
173+
let allowOpenToFinish = AsyncThrowingFuture<Void>(timeout: 2)
174+
let records = Locked([String]())
175+
176+
delegate.set(
177+
onOpen: {
178+
records.access { $0.append("open-started") }
179+
openStarted.resolve()
180+
do { try await allowOpenToFinish.value }
181+
catch { XCTFail() }
182+
records.access { $0.append("open-finished") }
183+
},
184+
onClose: { _, _ in
185+
records.access { $0.append("close") }
186+
},
187+
for: ObjectIdentifier(task)
188+
)
189+
190+
delegate.urlSession(session, webSocketTask: task, didOpenWithProtocol: nil)
191+
try await openStarted.value
192+
193+
delegate.urlSession(
194+
session,
195+
webSocketTask: task,
196+
didCloseWith: .goingAway,
197+
reason: nil
198+
)
199+
200+
try await Task.sleep(nanoseconds: 10 * NSEC_PER_MSEC)
201+
let eventsBeforeOpenFinishes = records.access { $0 }
202+
XCTAssertEqual(["open-started"], eventsBeforeOpenFinishes)
203+
204+
allowOpenToFinish.resolve()
205+
try await Task.sleep(nanoseconds: 10 * NSEC_PER_MSEC)
206+
let eventsAfterOpenFinishes = records.access { $0 }
207+
XCTAssertEqual(
208+
["open-started", "open-finished", "close"],
209+
eventsAfterOpenFinishes
210+
)
211+
}
212+
117213
func testPushAndReceiveText() async throws {
118214
let (server, client) = try await makeServerAndClient()
119215
defer { server.shutDown() }
@@ -338,9 +434,27 @@ class SystemWebSocketTests: XCTestCase {
338434
}
339435
}
340436

437+
await fulfillment(of: [closeEx], timeout: 2)
438+
341439
XCTAssertEqual(3, messagesReceivedByClient)
342440
XCTAssertEqual(3, messagesReceivedByServer)
441+
}
343442

443+
func testRemoteCloseReasonIsPassedToOnClose() async throws {
444+
let closeEx = expectation(description: "Should expose the close reason")
445+
let reason = Data("server said goodbye".utf8)
446+
447+
let (server, client) = try await makeServerAndClient(
448+
onClose: { close in
449+
XCTAssertEqual(.goingAway, close.code)
450+
XCTAssertEqual(reason, close.reason)
451+
closeEx.fulfill()
452+
}
453+
)
454+
defer { server.shutDown() }
455+
456+
try await client.open()
457+
subject.send(.remoteCloseWithReason(.goingAway, reason))
344458
await fulfillment(of: [closeEx], timeout: 2)
345459
}
346460
}
@@ -373,13 +487,14 @@ private extension SystemWebSocketTests {
373487
}
374488

375489
func makeOfflineServerAndClient(
490+
timeoutIntervalForRequest: TimeInterval = 2,
376491
onOpen: @escaping @Sendable () -> Void = {},
377492
onClose: @escaping @Sendable (WebSocketClose) -> Void = { _ in }
378493
) async throws -> (WebSocketServer, SystemWebSocket) {
379494
let server = try WebSocketServer(outputPublisher: empty)
380495
let client = try! await SystemWebSocket(
381496
request: request(19),
382-
options: .init(timeoutIntervalForRequest: 2),
497+
options: .init(timeoutIntervalForRequest: timeoutIntervalForRequest),
383498
onOpen: onOpen,
384499
onClose: onClose
385500
)
@@ -400,3 +515,28 @@ private extension SystemWebSocketTests {
400515
return (server, try! await .system(client))
401516
}
402517
}
518+
519+
private final class HangingServer {
520+
var port: Int { channel!.localAddress!.port! }
521+
522+
private let eventLoopGroup: EventLoopGroup
523+
private var channel: Channel?
524+
525+
init() throws {
526+
eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
527+
channel = try ServerBootstrap(group: eventLoopGroup)
528+
.serverChannelOption(ChannelOptions.backlog, value: 256)
529+
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
530+
.childChannelInitializer { channel in
531+
channel.eventLoop.makeSucceededFuture(())
532+
}
533+
.childChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
534+
.bind(host: "127.0.0.1", port: 0)
535+
.wait()
536+
}
537+
538+
func shutDown() {
539+
try? channel?.close(mode: .all).wait()
540+
try? eventLoopGroup.syncShutdownGracefully()
541+
}
542+
}

0 commit comments

Comments
 (0)