@@ -86,17 +86,11 @@ package final class SocketRelay: Sendable {
8686 private let state : Mutex < State >
8787
8888 private struct State {
89- var relaySources : [ String : ConnectionSources ] = [ : ]
89+ var activeRelays : [ String : BidirectionalRelay ] = [ : ]
9090 var t : Task < ( ) , Never > ? = nil
9191 var listener : VsockListener ? = nil
9292 }
9393
94- // `DispatchSourceRead` is thread-safe.
95- private struct ConnectionSources : @unchecked Sendable {
96- let hostSource : DispatchSourceRead
97- let guestSource : DispatchSourceRead
98- }
99-
10094 init (
10195 port: UInt32 ,
10296 socket: UnixSocketConfiguration ,
@@ -137,7 +131,10 @@ extension SocketRelay {
137131 }
138132 t. cancel ( )
139133 $0. t = nil
140- $0. relaySources. removeAll ( )
134+ for (_, relay) in $0. activeRelays {
135+ relay. stop ( )
136+ }
137+ $0. activeRelays. removeAll ( )
141138
142139 switch configuration. direction {
143140 case . outOf:
@@ -227,7 +224,7 @@ extension SocketRelay {
227224 ) async throws {
228225 do {
229226 let guestConn = try await vm. dial ( port)
230- log? . info (
227+ log? . debug (
231228 " initiating connection from host to guest " ,
232229 metadata: [
233230 " vport " : " \( port) " ,
@@ -256,7 +253,7 @@ extension SocketRelay {
256253 type: socketType,
257254 closeOnDeinit: false
258255 )
259- log? . info (
256+ log? . debug (
260257 " initiating connection from guest to host " ,
261258 metadata: [
262259 " vport " : " \( port) " ,
@@ -279,210 +276,20 @@ extension SocketRelay {
279276 hostConn: Socket ,
280277 guestFd: Int32
281278 ) async throws {
282- // set up the source for host to guest transfers
283- let connSource = DispatchSource . makeReadSource (
284- fileDescriptor: hostConn. fileDescriptor,
285- queue: self . q
286- )
279+ let hostFd = hostConn. fileDescriptor
287280
288- // set up the source for guest to host transfers
289- let vsockConnectionSource = DispatchSource . makeReadSource (
290- fileDescriptor: guestFd,
291- queue: self . q
281+ let relayID = UUID ( ) . uuidString
282+ let relay = BidirectionalRelay (
283+ fd1: hostFd,
284+ fd2: guestFd,
285+ queue: self . q,
286+ log: self . log
292287 )
293288
294- // add the sources to the connection map
295- let pairID = UUID ( ) . uuidString
296289 self . state. withLock {
297- $0. relaySources [ pairID] = ConnectionSources (
298- hostSource: connSource,
299- guestSource: vsockConnectionSource
300- )
301- }
302-
303- // `buf1` is thread-safe because it is only used when servicing a serial dispatch queue
304- nonisolated ( unsafe) let buf1 = UnsafeMutableBufferPointer< UInt8> . allocate( capacity: Int ( getpagesize ( ) ) )
305- connSource. setEventHandler {
306- Self . fdCopyHandler (
307- buffer: buf1,
308- source: connSource,
309- from: hostConn. fileDescriptor,
310- to: guestFd,
311- log: self . log
312- )
313- }
314-
315- // `buf2` is thread-safe because it is only used when servicing a serial dispatch queue
316- nonisolated ( unsafe) let buf2 = UnsafeMutableBufferPointer< UInt8> . allocate( capacity: Int ( getpagesize ( ) ) )
317- vsockConnectionSource. setEventHandler {
318- Self . fdCopyHandler (
319- buffer: buf2,
320- source: vsockConnectionSource,
321- from: guestFd,
322- to: hostConn. fileDescriptor,
323- log: self . log
324- )
325- }
326-
327- connSource. setCancelHandler {
328- self . log? . debug (
329- " host cancel received " ,
330- metadata: [
331- " hostFd " : " \( hostConn. fileDescriptor) " ,
332- " guestFd " : " \( guestFd) " ,
333- ] )
334-
335- // only close underlying fds when both sources are at EOF
336- // ensure that one of the cancel handlers will see both sources cancelled
337- self . state. withLock { _ in
338- connSource. cancel ( )
339- if vsockConnectionSource. isCancelled {
340- self . log? . info (
341- " close file descriptors " ,
342- metadata: [
343- " hostFd " : " \( hostConn. fileDescriptor) " ,
344- " guestFd " : " \( guestFd) " ,
345- ] )
346- try ? hostConn. close ( )
347- close ( guestFd)
348- }
349- }
350- }
351-
352- vsockConnectionSource. setCancelHandler {
353- self . log? . debug (
354- " guest cancel received " ,
355- metadata: [
356- " hostFd " : " \( hostConn. fileDescriptor) " ,
357- " guestFd " : " \( guestFd) " ,
358- ] )
359-
360- // only close underlying fds when both sources are at EOF
361- // ensure that one of the cancel handlers will see both sources cancelled
362- self . state. withLock { _ in
363- vsockConnectionSource. cancel ( )
364- if connSource. isCancelled {
365- self . log? . info (
366- " close file descriptors " ,
367- metadata: [
368- " hostFd " : " \( hostConn. fileDescriptor) " ,
369- " guestFd " : " \( guestFd) " ,
370- ] )
371- try ? hostConn. close ( )
372- close ( guestFd)
373- }
374- }
290+ $0. activeRelays [ relayID] = relay
375291 }
376292
377- connSource. activate ( )
378- vsockConnectionSource. activate ( )
379- }
380-
381- private static func fdCopyHandler(
382- buffer: UnsafeMutableBufferPointer < UInt8 > ,
383- source: DispatchSourceRead ,
384- from sourceFd: Int32 ,
385- to destinationFd: Int32 ,
386- log: Logger ? = nil
387- ) {
388- if source. data == 0 {
389- log? . debug (
390- " source EOF " ,
391- metadata: [
392- " sourceFd " : " \( sourceFd) " ,
393- " dstFd " : " \( destinationFd) " ,
394- ] )
395- if !source. isCancelled {
396- log? . debug (
397- " canceling DispatchSourceRead " ,
398- metadata: [
399- " sourceFd " : " \( sourceFd) " ,
400- " dstFd " : " \( destinationFd) " ,
401- ] )
402- source. cancel ( )
403- if shutdown ( destinationFd, Int32 ( SHUT_WR) ) != 0 {
404- log? . warning (
405- " failed to shut down reads " ,
406- metadata: [
407- " errno " : " \( errno) " ,
408- " sourceFd " : " \( sourceFd) " ,
409- " dstFd " : " \( destinationFd) " ,
410- ]
411- )
412- }
413- }
414- return
415- }
416-
417- do {
418- log? . trace (
419- " source copy " ,
420- metadata: [
421- " sourceFd " : " \( sourceFd) " ,
422- " dstFd " : " \( destinationFd) " ,
423- " size " : " \( source. data) " ,
424- ] )
425- try self . fileDescriptorCopy (
426- buffer: buffer,
427- size: source. data,
428- from: sourceFd,
429- to: destinationFd
430- )
431- } catch {
432- log? . error ( " file descriptor copy failed \( error) " )
433- if !source. isCancelled {
434- source. cancel ( )
435- if shutdown ( destinationFd, Int32 ( SHUT_RDWR) ) != 0 {
436- log? . warning (
437- " failed to shut down destination after I/O error " ,
438- metadata: [
439- " errno " : " \( errno) " ,
440- " sourceFd " : " \( sourceFd) " ,
441- " dstFd " : " \( destinationFd) " ,
442- ]
443- )
444- }
445- }
446- }
447- }
448-
449- private static func fileDescriptorCopy(
450- buffer: UnsafeMutableBufferPointer < UInt8 > ,
451- size: UInt ,
452- from sourceFd: Int32 ,
453- to destinationFd: Int32
454- ) throws {
455- let bufferSize = buffer. count
456- var readBytesRemaining = min ( Int ( size) , bufferSize)
457-
458- guard let baseAddr = buffer. baseAddress else {
459- throw ContainerizationError (
460- . invalidState,
461- message: " buffer has no base address "
462- )
463- }
464-
465- while readBytesRemaining > 0 {
466- let readResult = read ( sourceFd, baseAddr, min ( bufferSize, readBytesRemaining) )
467- if readResult <= 0 {
468- throw ContainerizationError (
469- . internalError,
470- message: " missing pointer base address "
471- )
472- }
473- readBytesRemaining -= readResult
474-
475- var writeBytesRemaining = readResult
476- while writeBytesRemaining > 0 {
477- let writeResult = write ( destinationFd, baseAddr, writeBytesRemaining)
478- if writeResult <= 0 {
479- throw ContainerizationError (
480- . internalError,
481- message: " zero byte write or error in socket relay: fd \( destinationFd) , result \( writeResult) "
482- )
483- }
484- writeBytesRemaining -= writeResult
485- }
486- }
293+ relay. start ( )
487294 }
488295}
0 commit comments