@@ -29,8 +29,13 @@ package final class UnixSocketRelay: Sendable {
2929 private let log : Logger ?
3030 private let state : Mutex < State >
3131
32+ private struct ActiveRelay : Sendable {
33+ let relay : BidirectionalRelay
34+ let guestConnection : VsockConnection
35+ }
36+
3237 private struct State {
33- var activeRelays : [ String : BidirectionalRelay ] = [ : ]
38+ var activeRelays : [ String : ActiveRelay ] = [ : ]
3439 var t : Task < ( ) , Never > ? = nil
3540 var listener : VsockListener ? = nil
3641 }
@@ -75,10 +80,9 @@ extension UnixSocketRelay {
7580 }
7681 t. cancel ( )
7782 $0. t = nil
78- for (_, relay ) in $0. activeRelays {
79- relay. stop ( )
83+ for (_, activeRelay ) in $0. activeRelays {
84+ activeRelay . relay. stop ( )
8085 }
81- $0. activeRelays. removeAll ( )
8286
8387 switch configuration. direction {
8488 case . outOf:
@@ -170,12 +174,12 @@ extension UnixSocketRelay {
170174 " initiating connection from host to guest " ,
171175 metadata: [
172176 " vport " : " \( port) " ,
173- " hostFd " : " \( guestConn . fileDescriptor) " ,
174- " guestFd " : " \( hostConn . fileDescriptor) " ,
177+ " hostFd " : " \( hostConn . fileDescriptor) " ,
178+ " guestFd " : " \( guestConn . fileDescriptor) " ,
175179 ] )
176180 try await self . relay (
177181 hostConn: hostConn,
178- guestFd : guestConn. fileDescriptor
182+ guestConn : guestConn
179183 )
180184 } catch {
181185 log? . error ( " failed to relay between vsock \( port) and \( hostConn) " )
@@ -184,7 +188,7 @@ extension UnixSocketRelay {
184188 }
185189
186190 private func handleGuestVsockConn(
187- vsockConn: FileHandle ,
191+ vsockConn: VsockConnection ,
188192 hostConnectionPath: URL ,
189193 port: UInt32 ,
190194 log: Logger ?
@@ -207,7 +211,7 @@ extension UnixSocketRelay {
207211 do {
208212 try await self . relay (
209213 hostConn: hostSocket,
210- guestFd : vsockConn. fileDescriptor
214+ guestConn : vsockConn
211215 )
212216 } catch {
213217 log? . error ( " failed to relay between vsock \( port) and \( hostPath) " )
@@ -216,9 +220,13 @@ extension UnixSocketRelay {
216220
217221 private func relay(
218222 hostConn: Socket ,
219- guestFd : Int32
223+ guestConn : VsockConnection
220224 ) async throws {
221225 let hostFd = hostConn. fileDescriptor
226+ let guestFd = dup ( guestConn. fileDescriptor)
227+ if guestFd == - 1 {
228+ throw POSIXError . fromErrno ( )
229+ }
222230
223231 let relayID = UUID ( ) . uuidString
224232 let relay = BidirectionalRelay (
@@ -229,9 +237,21 @@ extension UnixSocketRelay {
229237 )
230238
231239 state. withLock {
232- $0. activeRelays [ relayID] = relay
240+ // Retain the original connection until the relay has fully completed.
241+ // The relay owns its duplicated fd and will close it itself.
242+ $0. activeRelays [ relayID] = ActiveRelay (
243+ relay: relay,
244+ guestConnection: guestConn
245+ )
233246 }
234247
235248 relay. start ( )
249+
250+ Task {
251+ await relay. waitForCompletion ( )
252+ let _ = self . state. withLock {
253+ $0. activeRelays. removeValue ( forKey: relayID)
254+ }
255+ }
236256 }
237257}
0 commit comments