2020
2121use std:: collections:: HashMap ;
2222use std:: io;
23+ use std:: os:: fd:: RawFd ;
2324use std:: sync:: Arc ;
2425use std:: sync:: atomic:: { AtomicU32 , Ordering } ;
2526use std:: time:: Duration ;
@@ -172,6 +173,9 @@ enum ConnectionState {
172173struct Shared {
173174 /// Serialises writes to the stream.
174175 writer : tokio:: sync:: Mutex < tokio:: net:: unix:: OwnedWriteHalf > ,
176+ /// Raw fd of the underlying socket, used to poison a connection after an
177+ /// interrupted frame write. Ownership remains with the split stream halves.
178+ fd : RawFd ,
175179 /// Monotonically increasing sequence number (starts at 2, skips 0).
176180 /// Handshake uses seq=1 before Shared is created, so post-handshake
177181 /// sequences start at 2 to avoid collisions.
@@ -310,6 +314,30 @@ impl Drop for PendingBoundedOutputGuard {
310314 }
311315}
312316
317+ struct FrameWriteGuard {
318+ shared : Option < Arc < Shared > > ,
319+ }
320+
321+ impl FrameWriteGuard {
322+ fn new ( shared : Arc < Shared > ) -> Self {
323+ Self {
324+ shared : Some ( shared) ,
325+ }
326+ }
327+
328+ fn disarm ( & mut self ) {
329+ self . shared = None ;
330+ }
331+ }
332+
333+ impl Drop for FrameWriteGuard {
334+ fn drop ( & mut self ) {
335+ if let Some ( shared) = self . shared . take ( ) {
336+ shared. poison_connection ( ) ;
337+ }
338+ }
339+ }
340+
313341impl Shared {
314342 /// Get next sequence number, skipping 0 (reserved for unsolicited messages).
315343 fn next_seq ( & self ) -> u32 {
@@ -372,6 +400,11 @@ impl Shared {
372400 }
373401 }
374402
403+ fn poison_connection ( & self ) {
404+ let _ = nix:: sys:: socket:: shutdown ( self . fd , nix:: sys:: socket:: Shutdown :: Both ) ;
405+ self . close ( ) ;
406+ }
407+
375408 fn remove_pending ( & self , seq : u32 ) {
376409 let mut guard = self . state . lock ( ) . unwrap_or_else ( |e| e. into_inner ( ) ) ;
377410 if let ConnectionState :: Connected { pending, .. } = & mut * guard {
@@ -686,7 +719,7 @@ async fn request_raw_on_shared(
686719
687720 // The guard removes the pending entry on write failure, timeout, or
688721 // cancellation before reader_loop dispatches a response.
689- shared . writer . lock ( ) . await . write_all ( & data) . await ?;
722+ write_frame_on_shared ( shared , & data) . await ?;
690723
691724 // `rx` returns `Ok(msg)` when the reader dispatches a response and
692725 // `Err(RecvError)` when `close()` drops the `Connected` variant. The
@@ -705,6 +738,26 @@ async fn request_raw_on_shared(
705738 }
706739}
707740
741+ async fn write_frame_on_shared ( shared : & Arc < Shared > , data : & [ u8 ] ) -> io:: Result < ( ) > {
742+ let mut writer = shared. writer . lock ( ) . await ;
743+ {
744+ let guard = shared. state . lock ( ) . unwrap_or_else ( |e| e. into_inner ( ) ) ;
745+ if matches ! ( & * guard, ConnectionState :: Closed { .. } ) {
746+ return Err ( io:: Error :: new (
747+ io:: ErrorKind :: ConnectionReset ,
748+ "connection closed" ,
749+ ) ) ;
750+ }
751+ }
752+
753+ // Declare after `writer` so cancellation drops the guard before the writer
754+ // lock, preventing another request from writing before the poison close.
755+ let mut write_guard = FrameWriteGuard :: new ( Arc :: clone ( shared) ) ;
756+ writer. write_all ( data) . await ?;
757+ write_guard. disarm ( ) ;
758+ Ok ( ( ) )
759+ }
760+
708761async fn exec_on_shared (
709762 shared : & Arc < Shared > ,
710763 command : & str ,
@@ -835,6 +888,7 @@ impl VsockHost {
835888
836889 let shared = Arc :: new ( Shared {
837890 writer : tokio:: sync:: Mutex :: new ( write_half) ,
891+ fd,
838892 seq : AtomicU32 :: new ( 2 ) ,
839893 state : std:: sync:: Mutex :: new ( ConnectionState :: Connected {
840894 pending : HashMap :: new ( ) ,
@@ -1030,7 +1084,7 @@ impl VsockHost {
10301084 . then ( || PendingBoundedOutputGuard :: new ( Arc :: clone ( & self . shared ) , seq) )
10311085 } ) ;
10321086
1033- self . shared . writer . lock ( ) . await . write_all ( & data) . await ?;
1087+ write_frame_on_shared ( & self . shared , & data) . await ?;
10341088
10351089 let timeout = Duration :: from_millis ( request. timeout_ms as u64 + 5000 ) ;
10361090 let resp = tokio:: select! {
@@ -1376,12 +1430,44 @@ impl VsockHost {
13761430#[ cfg( test) ]
13771431mod tests {
13781432 use super :: * ;
1433+ use std:: future:: Future ;
1434+ use std:: os:: fd:: AsRawFd ;
1435+ use std:: pin:: Pin ;
1436+ use std:: task:: { Context , Poll , Wake , Waker } ;
13791437 use tokio:: io:: { AsyncReadExt , AsyncWriteExt } ;
13801438
/// Wake target whose notifications are discarded.
struct NoopWake;

impl Wake for NoopWake {
    /// Intentionally does nothing.
    fn wake(self: std::sync::Arc<Self>) {}
}

/// Builds a [`Waker`] that ignores every wake-up, for polling a future by
/// hand without a runtime driving it.
fn noop_waker() -> Waker {
    let inert = std::sync::Arc::new(NoopWake);
    inert.into()
}
1448+
13811449 fn make_pair ( ) -> ( UnixStream , UnixStream ) {
13821450 UnixStream :: pair ( ) . unwrap ( )
13831451 }
13841452
1453+ fn set_send_buffer ( stream : & UnixStream , size : nix:: libc:: c_int ) -> io:: Result < ( ) > {
1454+ // SAFETY: setsockopt receives a valid socket fd and a pointer to a
1455+ // properly sized integer option value for the duration of the call.
1456+ let ret = unsafe {
1457+ nix:: libc:: setsockopt (
1458+ stream. as_raw_fd ( ) ,
1459+ nix:: libc:: SOL_SOCKET ,
1460+ nix:: libc:: SO_SNDBUF ,
1461+ ( & size as * const nix:: libc:: c_int ) . cast ( ) ,
1462+ std:: mem:: size_of_val ( & size) as nix:: libc:: socklen_t ,
1463+ )
1464+ } ;
1465+ if ret < 0 {
1466+ return Err ( io:: Error :: last_os_error ( ) ) ;
1467+ }
1468+ Ok ( ( ) )
1469+ }
1470+
13851471 /// Perform mock guest handshake: send ready, receive ping, send pong.
13861472 async fn mock_handshake ( stream : & mut UnixStream , decoder : & mut Decoder ) {
13871473 // Send ready
@@ -2015,6 +2101,187 @@ mod tests {
20152101 release_guest. notify_one ( ) ;
20162102 }
20172103
2104+ #[ tokio:: test]
2105+ async fn test_cancel_while_waiting_for_writer_lock_does_not_close_connection ( ) {
2106+ let ( host_stream, mut guest) = make_pair ( ) ;
2107+
2108+ let guest_task = tokio:: spawn ( async move {
2109+ let mut decoder = Decoder :: new ( ) ;
2110+ mock_handshake ( & mut guest, & mut decoder) . await ;
2111+
2112+ let mut buf = [ 0u8 ; 4096 ] ;
2113+ let n = guest. read ( & mut buf) . await . unwrap ( ) ;
2114+ let msgs = decoder. decode ( & buf[ ..n] ) . unwrap ( ) ;
2115+ assert_eq ! ( msgs. len( ) , 1 ) ;
2116+ assert_eq ! ( msgs[ 0 ] . msg_type, MSG_EXEC ) ;
2117+ let decoded = vsock_proto:: decode_exec ( & msgs[ 0 ] . payload ) . unwrap ( ) ;
2118+ assert_eq ! ( decoded. command, "after-cancel" ) ;
2119+
2120+ let payload = vsock_proto:: encode_exec_result ( 0 , b"ok" , b"" ) ;
2121+ let resp = vsock_proto:: encode ( MSG_EXEC_RESULT , msgs[ 0 ] . seq , & payload) . unwrap ( ) ;
2122+ guest. write_all ( & resp) . await . unwrap ( ) ;
2123+ } ) ;
2124+
2125+ let host = std:: sync:: Arc :: new ( host_from_stream ( host_stream) . await . unwrap ( ) ) ;
2126+ let writer_guard = host. shared . writer . lock ( ) . await ;
2127+
2128+ let request_host = std:: sync:: Arc :: clone ( & host) ;
2129+ let mut request =
2130+ Box :: pin ( async move { request_host. exec ( "blocked-on-lock" , 5000 , & [ ] , false ) . await } ) ;
2131+ let waker = noop_waker ( ) ;
2132+ let mut cx = Context :: from_waker ( & waker) ;
2133+ assert ! ( matches!(
2134+ Future :: poll( Pin :: as_mut( & mut request) , & mut cx) ,
2135+ Poll :: Pending
2136+ ) ) ;
2137+ assert_eq ! ( registration_counts( & host) , ( 1 , 0 , 0 , 0 ) ) ;
2138+ drop ( request) ;
2139+ assert_eq ! ( registration_counts( & host) , ( 0 , 0 , 0 , 0 ) ) ;
2140+
2141+ drop ( writer_guard) ;
2142+
2143+ let result = host. exec ( "after-cancel" , 5000 , & [ ] , false ) . await . unwrap ( ) ;
2144+ assert_eq ! ( result. exit_code, 0 ) ;
2145+ assert_eq ! ( result. stdout, b"ok" ) ;
2146+ guest_task. await . unwrap ( ) ;
2147+ }
2148+
2149+ #[ tokio:: test]
2150+ async fn test_cancel_during_frame_write_closes_connection ( ) {
2151+ let ( host_stream, mut guest) = make_pair ( ) ;
2152+ set_send_buffer ( & host_stream, 4096 ) . unwrap ( ) ;
2153+
2154+ let frame_started = std:: sync:: Arc :: new ( Notify :: new ( ) ) ;
2155+ let release_guest = std:: sync:: Arc :: new ( Notify :: new ( ) ) ;
2156+
2157+ let guest_task = {
2158+ let frame_started = std:: sync:: Arc :: clone ( & frame_started) ;
2159+ let release_guest = std:: sync:: Arc :: clone ( & release_guest) ;
2160+ tokio:: spawn ( async move {
2161+ let mut decoder = Decoder :: new ( ) ;
2162+ mock_handshake ( & mut guest, & mut decoder) . await ;
2163+
2164+ let mut buf = [ 0u8 ; 1024 ] ;
2165+ let mut n = 0usize ;
2166+ while n < vsock_proto:: HEADER_SIZE {
2167+ let read = guest. read ( & mut buf[ n..] ) . await . unwrap ( ) ;
2168+ assert_ne ! ( read, 0 , "connection closed before frame header arrived" ) ;
2169+ n += read;
2170+ }
2171+ let frame_body_len =
2172+ u32:: from_be_bytes ( buf[ ..vsock_proto:: HEADER_SIZE ] . try_into ( ) . unwrap ( ) )
2173+ as usize ;
2174+ assert ! (
2175+ frame_body_len + vsock_proto:: HEADER_SIZE > n,
2176+ "guest should observe only a partial frame before it stops reading" ,
2177+ ) ;
2178+ frame_started. notify_one ( ) ;
2179+
2180+ release_guest. notified ( ) . await ;
2181+ } )
2182+ } ;
2183+
2184+ let host = std:: sync:: Arc :: new ( host_from_stream ( host_stream) . await . unwrap ( ) ) ;
2185+ let task_host = std:: sync:: Arc :: clone ( & host) ;
2186+ let task = tokio:: spawn ( async move {
2187+ let content = vec ! [ b'x' ; 8 * 1024 * 1024 ] ;
2188+ task_host
2189+ . write_file ( "/tmp/large-frame.bin" , & content, false )
2190+ . await
2191+ } ) ;
2192+
2193+ tokio:: time:: timeout ( Duration :: from_secs ( 5 ) , frame_started. notified ( ) )
2194+ . await
2195+ . expect ( "guest should receive the beginning of the large frame" ) ;
2196+
2197+ task. abort ( ) ;
2198+ let _ = task. await ;
2199+
2200+ host. wait_until_closed ( Duration :: from_secs ( 5 ) )
2201+ . await
2202+ . unwrap ( ) ;
2203+ assert_eq ! ( registration_counts( & host) , ( 0 , 0 , 0 , 0 ) ) ;
2204+
2205+ let err = host
2206+ . exec ( "after-cancelled-write" , 5000 , & [ ] , false )
2207+ . await
2208+ . unwrap_err ( ) ;
2209+ assert_eq ! ( err. kind( ) , io:: ErrorKind :: ConnectionReset ) ;
2210+
2211+ release_guest. notify_one ( ) ;
2212+ guest_task. await . unwrap ( ) ;
2213+ }
2214+
2215+ #[ tokio:: test]
2216+ async fn test_cancel_during_bounded_exec_frame_write_cleans_up_registrations ( ) {
2217+ let ( host_stream, mut guest) = make_pair ( ) ;
2218+ set_send_buffer ( & host_stream, 4096 ) . unwrap ( ) ;
2219+
2220+ let frame_started = std:: sync:: Arc :: new ( Notify :: new ( ) ) ;
2221+ let release_guest = std:: sync:: Arc :: new ( Notify :: new ( ) ) ;
2222+
2223+ let guest_task = {
2224+ let frame_started = std:: sync:: Arc :: clone ( & frame_started) ;
2225+ let release_guest = std:: sync:: Arc :: clone ( & release_guest) ;
2226+ tokio:: spawn ( async move {
2227+ let mut decoder = Decoder :: new ( ) ;
2228+ mock_handshake ( & mut guest, & mut decoder) . await ;
2229+
2230+ let mut buf = [ 0u8 ; 1024 ] ;
2231+ let mut n = 0usize ;
2232+ while n < vsock_proto:: HEADER_SIZE {
2233+ let read = guest. read ( & mut buf[ n..] ) . await . unwrap ( ) ;
2234+ assert_ne ! ( read, 0 , "connection closed before frame header arrived" ) ;
2235+ n += read;
2236+ }
2237+ let frame_body_len =
2238+ u32:: from_be_bytes ( buf[ ..vsock_proto:: HEADER_SIZE ] . try_into ( ) . unwrap ( ) )
2239+ as usize ;
2240+ assert ! (
2241+ frame_body_len + vsock_proto:: HEADER_SIZE > n,
2242+ "guest should observe only a partial frame before it stops reading" ,
2243+ ) ;
2244+ frame_started. notify_one ( ) ;
2245+
2246+ release_guest. notified ( ) . await ;
2247+ } )
2248+ } ;
2249+
2250+ let host = std:: sync:: Arc :: new ( host_from_stream ( host_stream) . await . unwrap ( ) ) ;
2251+ let task_host = std:: sync:: Arc :: clone ( & host) ;
2252+ let task = tokio:: spawn ( async move {
2253+ let ( tx, _rx) = mpsc:: unbounded_channel ( ) ;
2254+ let stdin = vec ! [ b'x' ; 8 * 1024 * 1024 ] ;
2255+ let request = BoundedExecRequest {
2256+ command : "large-stdin" ,
2257+ timeout_ms : 5000 ,
2258+ env : & [ ] ,
2259+ sudo : false ,
2260+ stdin : Some ( & stdin) ,
2261+ stdout_limit_bytes : 1024 ,
2262+ stderr_limit_bytes : 1024 ,
2263+ stream : Some ( bounded_stream_request ( tx) ) ,
2264+ } ;
2265+ task_host. bounded_exec ( & request) . await
2266+ } ) ;
2267+
2268+ tokio:: time:: timeout ( Duration :: from_secs ( 5 ) , frame_started. notified ( ) )
2269+ . await
2270+ . expect ( "guest should receive the beginning of the bounded exec frame" ) ;
2271+
2272+ assert_eq ! ( registration_counts( & host) , ( 1 , 0 , 0 , 1 ) ) ;
2273+ task. abort ( ) ;
2274+ let _ = task. await ;
2275+
2276+ host. wait_until_closed ( Duration :: from_secs ( 5 ) )
2277+ . await
2278+ . unwrap ( ) ;
2279+ assert_eq ! ( registration_counts( & host) , ( 0 , 0 , 0 , 0 ) ) ;
2280+
2281+ release_guest. notify_one ( ) ;
2282+ guest_task. await . unwrap ( ) ;
2283+ }
2284+
20182285 #[ tokio:: test]
20192286 async fn test_bounded_exec_connection_close_cleans_up_registrations ( ) {
20202287 let ( host_stream, mut guest) = make_pair ( ) ;
0 commit comments