Skip to content

Commit 1860100

Browse files
authored
fix(vsock-host): poison interrupted frame writes (#12247)
* fix(vsock-host): poison interrupted frame writes * test(vsock-host): stabilize partial frame cancellation check * test(vsock-host): await cancellation test guests * test(vsock-host): avoid polling delay in cancellation test * test(vsock-host): cover bounded exec write cancellation * docs(vsock-host): explain frame write guard ordering
1 parent dc0f917 commit 1860100

1 file changed

Lines changed: 269 additions & 2 deletions

File tree

crates/vsock-host/src/lib.rs

Lines changed: 269 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
2121
use std::collections::HashMap;
2222
use std::io;
23+
use std::os::fd::RawFd;
2324
use std::sync::Arc;
2425
use std::sync::atomic::{AtomicU32, Ordering};
2526
use std::time::Duration;
@@ -172,6 +173,9 @@ enum ConnectionState {
172173
struct Shared {
173174
/// Serialises writes to the stream.
174175
writer: tokio::sync::Mutex<tokio::net::unix::OwnedWriteHalf>,
176+
/// Raw fd of the underlying socket, used to poison a connection after an
177+
/// interrupted frame write. Ownership remains with the split stream halves.
178+
fd: RawFd,
175179
/// Monotonically increasing sequence number (starts at 2, skips 0).
176180
/// Handshake uses seq=1 before Shared is created, so post-handshake
177181
/// sequences start at 2 to avoid collisions.
@@ -310,6 +314,30 @@ impl Drop for PendingBoundedOutputGuard {
310314
}
311315
}
312316

317+
struct FrameWriteGuard {
318+
shared: Option<Arc<Shared>>,
319+
}
320+
321+
impl FrameWriteGuard {
322+
fn new(shared: Arc<Shared>) -> Self {
323+
Self {
324+
shared: Some(shared),
325+
}
326+
}
327+
328+
fn disarm(&mut self) {
329+
self.shared = None;
330+
}
331+
}
332+
333+
impl Drop for FrameWriteGuard {
334+
fn drop(&mut self) {
335+
if let Some(shared) = self.shared.take() {
336+
shared.poison_connection();
337+
}
338+
}
339+
}
340+
313341
impl Shared {
314342
/// Get next sequence number, skipping 0 (reserved for unsolicited messages).
315343
fn next_seq(&self) -> u32 {
@@ -372,6 +400,11 @@ impl Shared {
372400
}
373401
}
374402

403+
fn poison_connection(&self) {
404+
let _ = nix::sys::socket::shutdown(self.fd, nix::sys::socket::Shutdown::Both);
405+
self.close();
406+
}
407+
375408
fn remove_pending(&self, seq: u32) {
376409
let mut guard = self.state.lock().unwrap_or_else(|e| e.into_inner());
377410
if let ConnectionState::Connected { pending, .. } = &mut *guard {
@@ -686,7 +719,7 @@ async fn request_raw_on_shared(
686719

687720
// The guard removes the pending entry on write failure, timeout, or
688721
// cancellation before reader_loop dispatches a response.
689-
shared.writer.lock().await.write_all(&data).await?;
722+
write_frame_on_shared(shared, &data).await?;
690723

691724
// `rx` returns `Ok(msg)` when the reader dispatches a response and
692725
// `Err(RecvError)` when `close()` drops the `Connected` variant. The
@@ -705,6 +738,26 @@ async fn request_raw_on_shared(
705738
}
706739
}
707740

741+
async fn write_frame_on_shared(shared: &Arc<Shared>, data: &[u8]) -> io::Result<()> {
742+
let mut writer = shared.writer.lock().await;
743+
{
744+
let guard = shared.state.lock().unwrap_or_else(|e| e.into_inner());
745+
if matches!(&*guard, ConnectionState::Closed { .. }) {
746+
return Err(io::Error::new(
747+
io::ErrorKind::ConnectionReset,
748+
"connection closed",
749+
));
750+
}
751+
}
752+
753+
// Declare after `writer` so cancellation drops the guard before the writer
754+
// lock, preventing another request from writing before the poison close.
755+
let mut write_guard = FrameWriteGuard::new(Arc::clone(shared));
756+
writer.write_all(data).await?;
757+
write_guard.disarm();
758+
Ok(())
759+
}
760+
708761
async fn exec_on_shared(
709762
shared: &Arc<Shared>,
710763
command: &str,
@@ -835,6 +888,7 @@ impl VsockHost {
835888

836889
let shared = Arc::new(Shared {
837890
writer: tokio::sync::Mutex::new(write_half),
891+
fd,
838892
seq: AtomicU32::new(2),
839893
state: std::sync::Mutex::new(ConnectionState::Connected {
840894
pending: HashMap::new(),
@@ -1030,7 +1084,7 @@ impl VsockHost {
10301084
.then(|| PendingBoundedOutputGuard::new(Arc::clone(&self.shared), seq))
10311085
});
10321086

1033-
self.shared.writer.lock().await.write_all(&data).await?;
1087+
write_frame_on_shared(&self.shared, &data).await?;
10341088

10351089
let timeout = Duration::from_millis(request.timeout_ms as u64 + 5000);
10361090
let resp = tokio::select! {
@@ -1376,12 +1430,44 @@ impl VsockHost {
13761430
#[cfg(test)]
13771431
mod tests {
13781432
use super::*;
1433+
use std::future::Future;
1434+
use std::os::fd::AsRawFd;
1435+
use std::pin::Pin;
1436+
use std::task::{Context, Poll, Wake, Waker};
13791437
use tokio::io::{AsyncReadExt, AsyncWriteExt};
13801438

1439+
struct NoopWake;
1440+
1441+
impl Wake for NoopWake {
1442+
fn wake(self: std::sync::Arc<Self>) {}
1443+
}
1444+
1445+
fn noop_waker() -> Waker {
1446+
Waker::from(std::sync::Arc::new(NoopWake))
1447+
}
1448+
13811449
fn make_pair() -> (UnixStream, UnixStream) {
13821450
UnixStream::pair().unwrap()
13831451
}
13841452

1453+
fn set_send_buffer(stream: &UnixStream, size: nix::libc::c_int) -> io::Result<()> {
1454+
// SAFETY: setsockopt receives a valid socket fd and a pointer to a
1455+
// properly sized integer option value for the duration of the call.
1456+
let ret = unsafe {
1457+
nix::libc::setsockopt(
1458+
stream.as_raw_fd(),
1459+
nix::libc::SOL_SOCKET,
1460+
nix::libc::SO_SNDBUF,
1461+
(&size as *const nix::libc::c_int).cast(),
1462+
std::mem::size_of_val(&size) as nix::libc::socklen_t,
1463+
)
1464+
};
1465+
if ret < 0 {
1466+
return Err(io::Error::last_os_error());
1467+
}
1468+
Ok(())
1469+
}
1470+
13851471
/// Perform mock guest handshake: send ready, receive ping, send pong.
13861472
async fn mock_handshake(stream: &mut UnixStream, decoder: &mut Decoder) {
13871473
// Send ready
@@ -2015,6 +2101,187 @@ mod tests {
20152101
release_guest.notify_one();
20162102
}
20172103

2104+
#[tokio::test]
2105+
async fn test_cancel_while_waiting_for_writer_lock_does_not_close_connection() {
2106+
let (host_stream, mut guest) = make_pair();
2107+
2108+
let guest_task = tokio::spawn(async move {
2109+
let mut decoder = Decoder::new();
2110+
mock_handshake(&mut guest, &mut decoder).await;
2111+
2112+
let mut buf = [0u8; 4096];
2113+
let n = guest.read(&mut buf).await.unwrap();
2114+
let msgs = decoder.decode(&buf[..n]).unwrap();
2115+
assert_eq!(msgs.len(), 1);
2116+
assert_eq!(msgs[0].msg_type, MSG_EXEC);
2117+
let decoded = vsock_proto::decode_exec(&msgs[0].payload).unwrap();
2118+
assert_eq!(decoded.command, "after-cancel");
2119+
2120+
let payload = vsock_proto::encode_exec_result(0, b"ok", b"");
2121+
let resp = vsock_proto::encode(MSG_EXEC_RESULT, msgs[0].seq, &payload).unwrap();
2122+
guest.write_all(&resp).await.unwrap();
2123+
});
2124+
2125+
let host = std::sync::Arc::new(host_from_stream(host_stream).await.unwrap());
2126+
let writer_guard = host.shared.writer.lock().await;
2127+
2128+
let request_host = std::sync::Arc::clone(&host);
2129+
let mut request =
2130+
Box::pin(async move { request_host.exec("blocked-on-lock", 5000, &[], false).await });
2131+
let waker = noop_waker();
2132+
let mut cx = Context::from_waker(&waker);
2133+
assert!(matches!(
2134+
Future::poll(Pin::as_mut(&mut request), &mut cx),
2135+
Poll::Pending
2136+
));
2137+
assert_eq!(registration_counts(&host), (1, 0, 0, 0));
2138+
drop(request);
2139+
assert_eq!(registration_counts(&host), (0, 0, 0, 0));
2140+
2141+
drop(writer_guard);
2142+
2143+
let result = host.exec("after-cancel", 5000, &[], false).await.unwrap();
2144+
assert_eq!(result.exit_code, 0);
2145+
assert_eq!(result.stdout, b"ok");
2146+
guest_task.await.unwrap();
2147+
}
2148+
2149+
#[tokio::test]
2150+
async fn test_cancel_during_frame_write_closes_connection() {
2151+
let (host_stream, mut guest) = make_pair();
2152+
set_send_buffer(&host_stream, 4096).unwrap();
2153+
2154+
let frame_started = std::sync::Arc::new(Notify::new());
2155+
let release_guest = std::sync::Arc::new(Notify::new());
2156+
2157+
let guest_task = {
2158+
let frame_started = std::sync::Arc::clone(&frame_started);
2159+
let release_guest = std::sync::Arc::clone(&release_guest);
2160+
tokio::spawn(async move {
2161+
let mut decoder = Decoder::new();
2162+
mock_handshake(&mut guest, &mut decoder).await;
2163+
2164+
let mut buf = [0u8; 1024];
2165+
let mut n = 0usize;
2166+
while n < vsock_proto::HEADER_SIZE {
2167+
let read = guest.read(&mut buf[n..]).await.unwrap();
2168+
assert_ne!(read, 0, "connection closed before frame header arrived");
2169+
n += read;
2170+
}
2171+
let frame_body_len =
2172+
u32::from_be_bytes(buf[..vsock_proto::HEADER_SIZE].try_into().unwrap())
2173+
as usize;
2174+
assert!(
2175+
frame_body_len + vsock_proto::HEADER_SIZE > n,
2176+
"guest should observe only a partial frame before it stops reading",
2177+
);
2178+
frame_started.notify_one();
2179+
2180+
release_guest.notified().await;
2181+
})
2182+
};
2183+
2184+
let host = std::sync::Arc::new(host_from_stream(host_stream).await.unwrap());
2185+
let task_host = std::sync::Arc::clone(&host);
2186+
let task = tokio::spawn(async move {
2187+
let content = vec![b'x'; 8 * 1024 * 1024];
2188+
task_host
2189+
.write_file("/tmp/large-frame.bin", &content, false)
2190+
.await
2191+
});
2192+
2193+
tokio::time::timeout(Duration::from_secs(5), frame_started.notified())
2194+
.await
2195+
.expect("guest should receive the beginning of the large frame");
2196+
2197+
task.abort();
2198+
let _ = task.await;
2199+
2200+
host.wait_until_closed(Duration::from_secs(5))
2201+
.await
2202+
.unwrap();
2203+
assert_eq!(registration_counts(&host), (0, 0, 0, 0));
2204+
2205+
let err = host
2206+
.exec("after-cancelled-write", 5000, &[], false)
2207+
.await
2208+
.unwrap_err();
2209+
assert_eq!(err.kind(), io::ErrorKind::ConnectionReset);
2210+
2211+
release_guest.notify_one();
2212+
guest_task.await.unwrap();
2213+
}
2214+
2215+
#[tokio::test]
2216+
async fn test_cancel_during_bounded_exec_frame_write_cleans_up_registrations() {
2217+
let (host_stream, mut guest) = make_pair();
2218+
set_send_buffer(&host_stream, 4096).unwrap();
2219+
2220+
let frame_started = std::sync::Arc::new(Notify::new());
2221+
let release_guest = std::sync::Arc::new(Notify::new());
2222+
2223+
let guest_task = {
2224+
let frame_started = std::sync::Arc::clone(&frame_started);
2225+
let release_guest = std::sync::Arc::clone(&release_guest);
2226+
tokio::spawn(async move {
2227+
let mut decoder = Decoder::new();
2228+
mock_handshake(&mut guest, &mut decoder).await;
2229+
2230+
let mut buf = [0u8; 1024];
2231+
let mut n = 0usize;
2232+
while n < vsock_proto::HEADER_SIZE {
2233+
let read = guest.read(&mut buf[n..]).await.unwrap();
2234+
assert_ne!(read, 0, "connection closed before frame header arrived");
2235+
n += read;
2236+
}
2237+
let frame_body_len =
2238+
u32::from_be_bytes(buf[..vsock_proto::HEADER_SIZE].try_into().unwrap())
2239+
as usize;
2240+
assert!(
2241+
frame_body_len + vsock_proto::HEADER_SIZE > n,
2242+
"guest should observe only a partial frame before it stops reading",
2243+
);
2244+
frame_started.notify_one();
2245+
2246+
release_guest.notified().await;
2247+
})
2248+
};
2249+
2250+
let host = std::sync::Arc::new(host_from_stream(host_stream).await.unwrap());
2251+
let task_host = std::sync::Arc::clone(&host);
2252+
let task = tokio::spawn(async move {
2253+
let (tx, _rx) = mpsc::unbounded_channel();
2254+
let stdin = vec![b'x'; 8 * 1024 * 1024];
2255+
let request = BoundedExecRequest {
2256+
command: "large-stdin",
2257+
timeout_ms: 5000,
2258+
env: &[],
2259+
sudo: false,
2260+
stdin: Some(&stdin),
2261+
stdout_limit_bytes: 1024,
2262+
stderr_limit_bytes: 1024,
2263+
stream: Some(bounded_stream_request(tx)),
2264+
};
2265+
task_host.bounded_exec(&request).await
2266+
});
2267+
2268+
tokio::time::timeout(Duration::from_secs(5), frame_started.notified())
2269+
.await
2270+
.expect("guest should receive the beginning of the bounded exec frame");
2271+
2272+
assert_eq!(registration_counts(&host), (1, 0, 0, 1));
2273+
task.abort();
2274+
let _ = task.await;
2275+
2276+
host.wait_until_closed(Duration::from_secs(5))
2277+
.await
2278+
.unwrap();
2279+
assert_eq!(registration_counts(&host), (0, 0, 0, 0));
2280+
2281+
release_guest.notify_one();
2282+
guest_task.await.unwrap();
2283+
}
2284+
20182285
#[tokio::test]
20192286
async fn test_bounded_exec_connection_close_cleans_up_registrations() {
20202287
let (host_stream, mut guest) = make_pair();

0 commit comments

Comments
 (0)