Skip to content

Commit baff615

Browse files
authored
Cancellation works when talking to a silent server (#647)
1 parent f2188e0 commit baff615

7 files changed

Lines changed: 238 additions & 11 deletions

File tree

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import NIOConcurrencyHelpers
2+
import NIOCore
3+
4+
final class ConnectCancelHandler: Sendable {
5+
private enum State {
6+
case waiting
7+
case cancelledBeforeConnect
8+
case tcpConnected(any Channel)
9+
case done
10+
}
11+
12+
private let state: NIOLockedValueBox<State>
13+
14+
init() {
15+
self.state = NIOLockedValueBox(.waiting)
16+
}
17+
18+
enum ChannelConnectedAction {
19+
case none
20+
case close(any Channel)
21+
}
22+
23+
func channelConnected(_ channel: any Channel) -> EventLoopFuture<Void>? {
24+
let action = self.state.withLockedValue { state -> ChannelConnectedAction in
25+
switch state {
26+
case .waiting:
27+
state = .tcpConnected(channel)
28+
return .none
29+
case .cancelledBeforeConnect:
30+
state = .done
31+
return .close(channel)
32+
case .tcpConnected, .done:
33+
preconditionFailure("channelConnected called in invalid state")
34+
}
35+
}
36+
37+
switch action {
38+
case .none:
39+
return nil
40+
case .close(let channel):
41+
channel.close(mode: .all, promise: nil)
42+
return channel.closeFuture
43+
}
44+
}
45+
46+
enum CancelAction {
47+
case none
48+
case close(any Channel)
49+
}
50+
51+
func cancel() -> EventLoopFuture<Void>? {
52+
let action = self.state.withLockedValue { state -> CancelAction in
53+
switch state {
54+
case .waiting:
55+
state = .cancelledBeforeConnect
56+
return .none
57+
case .tcpConnected(let channel):
58+
state = .done
59+
return .close(channel)
60+
case .cancelledBeforeConnect, .done:
61+
return .none
62+
}
63+
}
64+
65+
switch action {
66+
case .none:
67+
return nil
68+
case .close(let channel):
69+
channel.close(mode: .all, promise: nil)
70+
return channel.closeFuture
71+
}
72+
}
73+
74+
func postgresHandshakeDone() {
75+
self.state.withLockedValue { state in
76+
state = .done
77+
}
78+
}
79+
}

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import Atomics
2+
import NIOConcurrencyHelpers
23
import NIOCore
34
import NIOPosix
45
#if canImport(Network)
@@ -129,33 +130,37 @@ public final class PostgresConnection: @unchecked Sendable {
129130
id connectionID: ID,
130131
logger: Logger
131132
) -> EventLoopFuture<PostgresConnection> {
132-
self.connect(
133+
let (future, _) = self.connect(
133134
connectionID: connectionID,
134135
configuration: .init(configuration),
135136
logger: logger,
136137
on: eventLoop
137138
)
139+
return future
138140
}
139141

140142
static func connect(
141143
connectionID: ID,
142144
configuration: PostgresConnection.InternalConfiguration,
143145
logger: Logger,
144146
on eventLoop: any EventLoop
145-
) -> EventLoopFuture<PostgresConnection> {
147+
) -> (EventLoopFuture<PostgresConnection>, ConnectCancelHandler) {
146148

147149
var mlogger = logger
148150
mlogger[postgresMetadataKey: .connectionID] = "\(connectionID)"
149151
let logger = mlogger
150152

153+
let cancelHandler = ConnectCancelHandler()
154+
let deadline = NIODeadline.now() + configuration.options.connectTimeout
155+
151156
// Here we dispatch to the `eventLoop` first before we setup the EventLoopFuture chain, to
152157
// ensure all `flatMap`s are executed on the EventLoop (this means the enqueuing of the
153158
// callbacks).
154159
//
155160
// This saves us a number of context switches between the thread the Connection is created
156161
// on and the EventLoop. In addition, it eliminates all potential races between the creating
157162
// thread and the EventLoop.
158-
return eventLoop.flatSubmit { () -> EventLoopFuture<PostgresConnection> in
163+
let future = eventLoop.flatSubmit { () -> EventLoopFuture<PostgresConnection> in
159164
let connectFuture: EventLoopFuture<any Channel>
160165

161166
switch configuration.connection {
@@ -176,17 +181,48 @@ public final class PostgresConnection: @unchecked Sendable {
176181
}
177182

178183
return connectFuture.flatMap { channel -> EventLoopFuture<PostgresConnection> in
184+
// 1. check if the connection request was cancelled in the mean time.
185+
if let closeFuture = cancelHandler.channelConnected(channel) {
186+
return closeFuture.flatMapThrowing { throw CancellationError() }
187+
}
188+
189+
// 2. check if the deadline has elapsed
190+
let remaining = deadline - .now()
191+
guard remaining > .nanoseconds(0) else {
192+
channel.close(mode: .all, promise: nil)
193+
return channel.closeFuture.flatMapThrowing {
194+
throw PSQLError.connectionError(underlying: ChannelError.connectTimeout(configuration.options.connectTimeout))
195+
}
196+
}
197+
198+
// 3. setup time to enforce connect deadline
199+
let timeoutTask = eventLoop.scheduleTask(deadline: deadline) {
200+
channel.pipeline.fireErrorCaught(
201+
ChannelError.connectTimeout(configuration.options.connectTimeout)
202+
)
203+
}
204+
179205
let connection = PostgresConnection(channel: channel, connectionID: connectionID, logger: logger)
180-
return connection.start(configuration: configuration).map { _ in connection }
206+
return connection.start(configuration: configuration).map { _ in
207+
timeoutTask.cancel()
208+
return connection
209+
}.flatMapError { error in
210+
timeoutTask.cancel()
211+
return eventLoop.makeFailedFuture(error)
212+
}
181213
}.flatMapErrorThrowing { error -> PostgresConnection in
182214
switch error {
183-
case is PSQLError:
215+
case is PSQLError, is CancellationError:
184216
throw error
185217
default:
186218
throw PSQLError.connectionError(underlying: error)
187219
}
188220
}
189221
}
222+
223+
future.whenComplete { _ in cancelHandler.postgresHandshakeDone() }
224+
225+
return (future, cancelHandler)
190226
}
191227

192228
static func makeBootstrap(
@@ -319,12 +355,13 @@ extension PostgresConnection {
319355
options: options
320356
)
321357

322-
return PostgresConnection.connect(
358+
let (future, _) = PostgresConnection.connect(
323359
connectionID: self.idGenerator.wrappingIncrementThenLoad(ordering: .relaxed),
324360
configuration: configuration,
325361
logger: logger,
326362
on: eventLoop
327363
)
364+
return future
328365
}.flatMapErrorThrowing { error in
329366
throw error.asAppropriatePostgresError
330367
}
@@ -373,12 +410,17 @@ extension PostgresConnection {
373410
id connectionID: ID,
374411
logger: Logger
375412
) async throws -> PostgresConnection {
376-
try await self.connect(
413+
let (future, cancelHandler) = self.connect(
377414
connectionID: connectionID,
378415
configuration: .init(configuration),
379416
logger: logger,
380417
on: eventLoop
381-
).get()
418+
)
419+
return try await withTaskCancellationHandler {
420+
try await future.get()
421+
} onCancel: {
422+
cancelHandler.cancel()
423+
}
382424
}
383425

384426
/// Closes the connection to the server.

Sources/PostgresNIO/New/PSQLEventsHandler.swift

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ final class PSQLEventsHandler: ChannelInboundHandler {
3838
case connected
3939
case readyForStartup
4040
case authenticated
41+
case closed
4142
}
4243

4344
private var readyForStartupPromise: EventLoopPromise<Void>!
@@ -65,7 +66,7 @@ final class PSQLEventsHandler: ChannelInboundHandler {
6566
// successful
6667
self.state = .authenticated
6768
self.authenticatePromise.succeed(Void())
68-
case .authenticated:
69+
case .authenticated, .closed:
6970
break
7071
}
7172
default:
@@ -89,16 +90,37 @@ final class PSQLEventsHandler: ChannelInboundHandler {
8990
context.fireChannelActive()
9091
}
9192

93+
func channelInactive(context: ChannelHandlerContext) {
94+
switch self.state {
95+
case .initialized:
96+
self.state = .closed
97+
self.readyForStartupPromise?.fail(PSQLError.clientClosedConnection(underlying: nil))
98+
self.authenticatePromise?.fail(PSQLError.clientClosedConnection(underlying: nil))
99+
case .connected:
100+
self.state = .closed
101+
self.readyForStartupPromise.fail(PSQLError.clientClosedConnection(underlying: nil))
102+
self.authenticatePromise.fail(PSQLError.clientClosedConnection(underlying: nil))
103+
case .readyForStartup:
104+
self.state = .closed
105+
self.authenticatePromise.fail(PSQLError.clientClosedConnection(underlying: nil))
106+
case .authenticated, .closed:
107+
break
108+
}
109+
context.fireChannelInactive()
110+
}
111+
92112
func errorCaught(context: ChannelHandlerContext, error: any Error) {
93113
switch self.state {
94114
case .initialized:
95115
preconditionFailure("Unexpected message for state")
96116
case .connected:
117+
self.state = .closed
97118
self.readyForStartupPromise.fail(error)
98119
self.authenticatePromise.fail(error)
99120
case .readyForStartup:
121+
self.state = .closed
100122
self.authenticatePromise.fail(error)
101-
case .authenticated:
123+
case .authenticated, .closed:
102124
break
103125
}
104126

Sources/PostgresNIO/Pool/ConnectionFactory.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ final class ConnectionFactory: Sendable {
4646
configuration: config,
4747
id: connectionID,
4848
logger: connectionLogger
49-
).get()
49+
)
5050
}
5151

5252
func makeConnectionConfig() async throws -> PostgresConnection.Configuration {
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import Logging
2+
import NIOCore
3+
import PostgresNIO
4+
import Testing
5+
6+
@Suite struct PostgresClientTests {
7+
8+
@Test(.timeLimit(.minutes(1)))
9+
func clientRunExitsPromptlyOnCancellationWhileConnecting() async throws {
10+
try await withSilentServer { port in
11+
var config = PostgresClient.Configuration(
12+
host: "127.0.0.1",
13+
port: port,
14+
username: "postgres",
15+
password: "irrelevant",
16+
database: "test",
17+
tls: .disable
18+
)
19+
// Long connect timeout so the hang is obvious if the fix regresses.
20+
config.options.connectTimeout = .seconds(30)
21+
config.options.minimumConnections = 1
22+
23+
let client = PostgresClient(
24+
configuration: config,
25+
eventLoopGroup: NIOSingletons.posixEventLoopGroup,
26+
backgroundLogger: Logger(label: "test")
27+
)
28+
29+
await withTaskGroup(of: Void.self) { group in
30+
group.addTask { await client.run() }
31+
32+
// Give the pool enough time to start the connection attempt (enter .starting state).
33+
try? await Task.sleep(for: .milliseconds(200))
34+
35+
// Cancelling the group must make pool.run() return quickly — not in 30 seconds.
36+
group.cancelAll()
37+
}
38+
// If we reach here the test passed (the .timeLimit above enforces the deadline).
39+
}
40+
}
41+
}

Tests/PostgresNIOTests/New/PostgresConnectionTests.swift

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,34 @@ import Synchronization
10421042
}
10431043
}
10441044

1045+
@Test(.timeLimit(.minutes(1))) func connectEnforcesDeadlineWithSilentServer() async throws {
1046+
try await withSilentServer { port in
1047+
var config = PostgresConnection.Configuration(
1048+
host: "127.0.0.1",
1049+
port: port,
1050+
username: "postgres",
1051+
password: "irrelevant",
1052+
database: "test",
1053+
tls: .disable
1054+
)
1055+
config.options.connectTimeout = .milliseconds(500)
1056+
1057+
let start = ContinuousClock.now
1058+
1059+
await #expect(throws: PSQLError.self) {
1060+
_ = try await PostgresConnection.connect(
1061+
configuration: config,
1062+
id: 1,
1063+
logger: Logger(label: "test")
1064+
)
1065+
}
1066+
1067+
let elapsed = ContinuousClock.now - start
1068+
#expect(elapsed < .seconds(5))
1069+
#expect(elapsed >= .milliseconds(400))
1070+
}
1071+
}
1072+
10451073
func withAsyncTestingChannel(_ body: (PostgresConnection, NIOAsyncTestingChannel) async throws -> ()) async throws {
10461074
let eventLoop = NIOAsyncTestingEventLoop()
10471075
let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in

Tests/PostgresNIOTests/Utilities.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import NIOCore
2+
import NIOPosix
13
import Logging
24

35
extension Logger {
@@ -7,3 +9,16 @@ extension Logger {
79
return logger
810
}
911
}
12+
13+
func withSilentServer<Success: ~Copyable>(_ body: (_ port: Int) async throws -> Success) async throws -> Success{
14+
let server = try await ServerBootstrap(group: NIOSingletons.posixEventLoopGroup)
15+
.bind(to: .init(ipAddress: "127.0.0.1", port: 0)).get()
16+
let result: Result<Success, any Error>
17+
do {
18+
result = .success(try await body(server.localAddress!.port!))
19+
} catch {
20+
result = .failure(error)
21+
}
22+
try? await server.close().get()
23+
return try result.get()
24+
}

0 commit comments

Comments
 (0)