11import Atomics
2+ import NIOConcurrencyHelpers
23import NIOCore
34import NIOPosix
45#if canImport(Network)
@@ -129,33 +130,37 @@ public final class PostgresConnection: @unchecked Sendable {
129130 id connectionID: ID ,
130131 logger: Logger
131132 ) -> EventLoopFuture < PostgresConnection > {
132- self . connect (
133+ let ( future , _ ) = self . connect (
133134 connectionID: connectionID,
134135 configuration: . init( configuration) ,
135136 logger: logger,
136137 on: eventLoop
137138 )
139+ return future
138140 }
139141
140142 static func connect(
141143 connectionID: ID ,
142144 configuration: PostgresConnection . InternalConfiguration ,
143145 logger: Logger ,
144146 on eventLoop: any EventLoop
145- ) -> EventLoopFuture < PostgresConnection > {
147+ ) -> ( EventLoopFuture < PostgresConnection > , ConnectCancelHandler ) {
146148
147149 var mlogger = logger
148150 mlogger [ postgresMetadataKey: . connectionID] = " \( connectionID) "
149151 let logger = mlogger
150152
153+ let cancelHandler = ConnectCancelHandler ( )
154+ let deadline = NIODeadline . now ( ) + configuration. options. connectTimeout
155+
151156 // Here we dispatch to the `eventLoop` first before we setup the EventLoopFuture chain, to
152157 // ensure all `flatMap`s are executed on the EventLoop (this means the enqueuing of the
153158 // callbacks).
154159 //
155160 // This saves us a number of context switches between the thread the Connection is created
156161 // on and the EventLoop. In addition, it eliminates all potential races between the creating
157162 // thread and the EventLoop.
158- return eventLoop. flatSubmit { ( ) -> EventLoopFuture < PostgresConnection > in
163+ let future = eventLoop. flatSubmit { ( ) -> EventLoopFuture < PostgresConnection > in
159164 let connectFuture : EventLoopFuture < any Channel >
160165
161166 switch configuration. connection {
@@ -176,17 +181,48 @@ public final class PostgresConnection: @unchecked Sendable {
176181 }
177182
178183 return connectFuture. flatMap { channel -> EventLoopFuture < PostgresConnection > in
184+ // 1. check if the connection request was cancelled in the mean time.
185+ if let closeFuture = cancelHandler. channelConnected ( channel) {
186+ return closeFuture. flatMapThrowing { throw CancellationError ( ) }
187+ }
188+
189+ // 2. check if the deadline has elapsed
190+ let remaining = deadline - . now( )
191+ guard remaining > . nanoseconds( 0 ) else {
192+ channel. close ( mode: . all, promise: nil )
193+ return channel. closeFuture. flatMapThrowing {
194+ throw PSQLError . connectionError ( underlying: ChannelError . connectTimeout ( configuration. options. connectTimeout) )
195+ }
196+ }
197+
198+ // 3. setup time to enforce connect deadline
199+ let timeoutTask = eventLoop. scheduleTask ( deadline: deadline) {
200+ channel. pipeline. fireErrorCaught (
201+ ChannelError . connectTimeout ( configuration. options. connectTimeout)
202+ )
203+ }
204+
179205 let connection = PostgresConnection ( channel: channel, connectionID: connectionID, logger: logger)
180- return connection. start ( configuration: configuration) . map { _ in connection }
206+ return connection. start ( configuration: configuration) . map { _ in
207+ timeoutTask. cancel ( )
208+ return connection
209+ } . flatMapError { error in
210+ timeoutTask. cancel ( )
211+ return eventLoop. makeFailedFuture ( error)
212+ }
181213 } . flatMapErrorThrowing { error -> PostgresConnection in
182214 switch error {
183- case is PSQLError :
215+ case is PSQLError , is CancellationError :
184216 throw error
185217 default :
186218 throw PSQLError . connectionError ( underlying: error)
187219 }
188220 }
189221 }
222+
223+ future. whenComplete { _ in cancelHandler. postgresHandshakeDone ( ) }
224+
225+ return ( future, cancelHandler)
190226 }
191227
192228 static func makeBootstrap(
@@ -319,12 +355,13 @@ extension PostgresConnection {
319355 options: options
320356 )
321357
322- return PostgresConnection . connect (
358+ let ( future , _ ) = PostgresConnection . connect (
323359 connectionID: self . idGenerator. wrappingIncrementThenLoad ( ordering: . relaxed) ,
324360 configuration: configuration,
325361 logger: logger,
326362 on: eventLoop
327363 )
364+ return future
328365 } . flatMapErrorThrowing { error in
329366 throw error. asAppropriatePostgresError
330367 }
@@ -373,12 +410,17 @@ extension PostgresConnection {
373410 id connectionID: ID ,
374411 logger: Logger
375412 ) async throws -> PostgresConnection {
376- try await self . connect (
413+ let ( future , cancelHandler ) = self . connect (
377414 connectionID: connectionID,
378415 configuration: . init( configuration) ,
379416 logger: logger,
380417 on: eventLoop
381- ) . get ( )
418+ )
419+ return try await withTaskCancellationHandler {
420+ try await future. get ( )
421+ } onCancel: {
422+ cancelHandler. cancel ( )
423+ }
382424 }
383425
384426 /// Closes the connection to the server.
0 commit comments