diff --git a/CHANGELOG.md b/CHANGELOG.md index fdc35c74a8d..4125f23ba2f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,12 @@ and this project adheres to ### Added +- [#5595](https://github.com/firecracker-microvm/firecracker/pull/5595): Added + `vsock_type` field to the vsock device API to denote the type of the + underlying socket. Can be `stream` or `seqpacket` +- [#5595](https://github.com/firecracker-microvm/firecracker/pull/5595): Added + `conn_buffer_size` field to denote how many bytes we can internally buffer + during receiving large seqpacket packets from the host. - [#5323](https://github.com/firecracker-microvm/firecracker/pull/5323): Add support for Vsock Unix domain socket path overriding on snapshot restore. More information can be found in the diff --git a/resources/seccomp/aarch64-unknown-linux-musl.json b/resources/seccomp/aarch64-unknown-linux-musl.json index 0d8d4e51224..d9bcf90c11e 100644 --- a/resources/seccomp/aarch64-unknown-linux-musl.json +++ b/resources/seccomp/aarch64-unknown-linux-musl.json @@ -323,6 +323,32 @@ } ] }, + { + "syscall": "socket", + "comment": "Called to open the vsock seqpacket UDS (SeqpacketConn::connect and SeqpacketListener::bind)", + "args": [ + { + "index": 0, + "type": "dword", + "op": "eq", + "val": 1, + "comment": "libc::AF_UNIX" + }, + { + "index": 1, + "type": "dword", + "op": "eq", + "val": 524293, + "comment": "libc::SOCK_SEQPACKET | libc::SOCK_CLOEXEC" + }, + { + "index": 2, + "type": "dword", + "op": "eq", + "val": 0 + } + ] + }, { "syscall": "sendto", "comment": "Rust std uses it to write to unix socket" diff --git a/resources/seccomp/x86_64-unknown-linux-musl.json b/resources/seccomp/x86_64-unknown-linux-musl.json index 4ccbfbd8e50..e4b06826f39 100644 --- a/resources/seccomp/x86_64-unknown-linux-musl.json +++ b/resources/seccomp/x86_64-unknown-linux-musl.json @@ -323,6 +323,32 @@ } ] }, + { + "syscall": "socket", + "comment": "Called to open the vsock seqpacket UDS (SeqpacketConn::connect and 
SeqpacketListener::bind)", + "args": [ + { + "index": 0, + "type": "dword", + "op": "eq", + "val": 1, + "comment": "libc::AF_UNIX" + }, + { + "index": 1, + "type": "dword", + "op": "eq", + "val": 524293, + "comment": "libc::SOCK_SEQPACKET | libc::SOCK_CLOEXEC" + }, + { + "index": 2, + "type": "dword", + "op": "eq", + "val": 0 + } + ] + }, { "syscall": "sendto", "comment": "Rust std uses it to write to unix socket" diff --git a/src/firecracker/src/api_server/parsed_request.rs b/src/firecracker/src/api_server/parsed_request.rs index 31a10057c99..ae110db050d 100644 --- a/src/firecracker/src/api_server/parsed_request.rs +++ b/src/firecracker/src/api_server/parsed_request.rs @@ -946,7 +946,7 @@ pub mod tests { fn test_try_from_put_vsock() { let (mut sender, receiver) = UnixStream::pair().unwrap(); let mut connection = HttpConnection::new(receiver); - let body = "{ \"vsock_id\": \"string\", \"guest_cid\": 0, \"uds_path\": \"string\" }"; + let body = "{ \"vsock_id\": \"string\", \"guest_cid\": 0, \"uds_path\": \"string\", \"vsock_type\": \"stream\" }"; sender .write_all(http_request("PUT", "/vsock", Some(body)).as_bytes()) .unwrap(); diff --git a/src/firecracker/src/api_server/request/vsock.rs b/src/firecracker/src/api_server/request/vsock.rs index acf129d456c..66662c3e18e 100644 --- a/src/firecracker/src/api_server/request/vsock.rs +++ b/src/firecracker/src/api_server/request/vsock.rs @@ -41,7 +41,8 @@ mod tests { fn test_parse_put_vsock_request() { let body = r#"{ "guest_cid": 42, - "uds_path": "vsock.sock" + "uds_path": "vsock.sock", + "vsock_type": "stream" }"#; parse_put_vsock(&Body::new(body)).unwrap(); @@ -57,7 +58,8 @@ mod tests { let body = r#"{ "vsock_id": "foo", "guest_cid": 42, - "uds_path": "vsock.sock" + "uds_path": "vsock.sock", + "vsock_type": "stream" }"#; depr_action_from_req( parse_put_vsock(&Body::new(body)).unwrap(), @@ -66,7 +68,8 @@ mod tests { let body = r#"{ "guest_cid": 42, - "uds_path": "vsock.sock" + "uds_path": "vsock.sock", + "vsock_type": 
"stream" }"#; let (_, mut parsing_info) = parse_put_vsock(&Body::new(body)).unwrap().into_parts(); assert!(parsing_info.take_deprecation_message().is_none()); diff --git a/src/firecracker/swagger/firecracker.yaml b/src/firecracker/swagger/firecracker.yaml index 5e8f16f732c..639eecc362f 100644 --- a/src/firecracker/swagger/firecracker.yaml +++ b/src/firecracker/swagger/firecracker.yaml @@ -1858,6 +1858,22 @@ definitions: uds_path: type: string description: Path to UNIX domain socket, used to proxy vsock connections. + vsock_type: + description: Enumeration indicating the type of the underlying socket (stream or seqpacket) + type: string + enum: + - stream + - seqpacket + default: stream + conn_buffer_size: + type: integer + minimum: 4096 + maximum: 262144 + description: + The amount in bytes that can be buffered in firecracker if the data in the tx/rx queue is + too much to fit in a single descriptor. This parameter is ignored for stream sockets + because connection buffering is a seqpacket only concept. 
The minimum is 4096 (one + virtqueue descriptor) and the maximum is 256KB (kernel limit) vsock_id: type: string description: diff --git a/src/vmm/src/builder.rs b/src/vmm/src/builder.rs index 1f6110583c7..b5e1b22d5eb 100644 --- a/src/vmm/src/builder.rs +++ b/src/vmm/src/builder.rs @@ -984,7 +984,7 @@ pub(crate) mod tests { vsock_config: VsockDeviceConfig, ) { let vsock_dev_id = VSOCK_DEV_ID.to_owned(); - let vsock = VsockBuilder::create_unixsock_vsock(vsock_config).unwrap(); + let vsock = VsockBuilder::create_unixsock_vsock(&vsock_config).unwrap(); let vsock = Arc::new(Mutex::new(vsock)); attach_unixsock_vsock_device( diff --git a/src/vmm/src/device_manager/pci_mngr.rs b/src/vmm/src/device_manager/pci_mngr.rs index d2e41b45955..2c40146cfa9 100644 --- a/src/vmm/src/device_manager/pci_mngr.rs +++ b/src/vmm/src/device_manager/pci_mngr.rs @@ -634,7 +634,7 @@ mod tests { use crate::vmm_config::memory_hotplug::MemoryHotplugConfig; use crate::vmm_config::net::NetworkInterfaceConfig; use crate::vmm_config::pmem::PmemConfig; - use crate::vmm_config::vsock::VsockDeviceConfig; + use crate::vmm_config::vsock::{VsockDeviceConfig, VsockType}; #[test] fn test_device_manager_persistence() { @@ -693,6 +693,8 @@ mod tests { vsock_id: Some(vsock_dev_id.to_string()), guest_cid: 3, uds_path: tmp_sock_file.as_path().to_str().unwrap().to_string(), + vsock_type: VsockType::Stream, + conn_buffer_size: None, }; insert_vsock_device(&mut vmm, &mut cmdline, &mut event_manager, vsock_config); // Add an entropy device. 
@@ -803,7 +805,8 @@ mod tests { ], "vsock": {{ "guest_cid": 3, - "uds_path": "{}" + "uds_path": "{}", + "vsock_type": "stream" }}, "entropy": {{ "rate_limiter": null diff --git a/src/vmm/src/device_manager/persist.rs b/src/vmm/src/device_manager/persist.rs index e9e741555e2..7f9af695d16 100644 --- a/src/vmm/src/device_manager/persist.rs +++ b/src/vmm/src/device_manager/persist.rs @@ -624,7 +624,7 @@ mod tests { use crate::vmm_config::memory_hotplug::MemoryHotplugConfig; use crate::vmm_config::net::NetworkInterfaceConfig; use crate::vmm_config::pmem::PmemConfig; - use crate::vmm_config::vsock::VsockDeviceConfig; + use crate::vmm_config::vsock::{VsockDeviceConfig, VsockType}; impl PartialEq for VirtioDeviceState { fn eq(&self, other: &VirtioDeviceState) -> bool { @@ -723,6 +723,8 @@ mod tests { vsock_id: Some(vsock_dev_id.to_string()), guest_cid: 3, uds_path: tmp_sock_file.as_path().to_str().unwrap().to_string(), + vsock_type: VsockType::Stream, + conn_buffer_size: None, }; insert_vsock_device(&mut vmm, &mut cmdline, &mut event_manager, vsock_config); // Add an entropy device. 
@@ -830,7 +832,8 @@ mod tests { ], "vsock": {{ "guest_cid": 3, - "uds_path": "{}" + "uds_path": "{}", + "vsock_type": "stream" }}, "entropy": {{ "rate_limiter": null diff --git a/src/vmm/src/devices/virtio/generated/virtio_config.rs b/src/vmm/src/devices/virtio/generated/virtio_config.rs index 886bd07ac39..ce0042053fc 100644 --- a/src/vmm/src/devices/virtio/generated/virtio_config.rs +++ b/src/vmm/src/devices/virtio/generated/virtio_config.rs @@ -16,6 +16,7 @@ clippy::redundant_static_lifetimes )] +pub const VIRTIO_VSOCK_F_SEQPACKET: u32 = 1; pub const VIRTIO_F_NOTIFY_ON_EMPTY: u32 = 24; pub const VIRTIO_F_ANY_LAYOUT: u32 = 27; pub const VIRTIO_F_VERSION_1: u32 = 32; diff --git a/src/vmm/src/devices/virtio/persist.rs b/src/vmm/src/devices/virtio/persist.rs index 2a75945f617..4b69c93880a 100644 --- a/src/vmm/src/devices/virtio/persist.rs +++ b/src/vmm/src/devices/virtio/persist.rs @@ -268,6 +268,8 @@ mod tests { use crate::devices::virtio::test_utils::default_mem; use crate::devices::virtio::transport::mmio::tests::DummyDevice; use crate::devices::virtio::vsock::{Vsock, VsockUnixBackend}; + use crate::snapshot::Snapshot; + use crate::vmm_config::vsock::VsockType; const DEFAULT_QUEUE_MAX_SIZE: u16 = 256; impl Default for QueueState { @@ -481,7 +483,7 @@ mod tests { // Remove the file so the path can be used by the socket. 
temp_uds_path.remove().unwrap(); let uds_path = String::from(temp_uds_path.as_path().to_str().unwrap()); - let backend = VsockUnixBackend::new(guest_cid, uds_path).unwrap(); + let backend = VsockUnixBackend::new(guest_cid, uds_path, VsockType::Stream, None).unwrap(); let vsock = Vsock::new(guest_cid, backend).unwrap(); let vsock = Arc::new(Mutex::new(vsock)); let mmio_transport = diff --git a/src/vmm/src/devices/virtio/vsock/csm/connection.rs b/src/vmm/src/devices/virtio/vsock/csm/connection.rs index eee415ef65f..b32d668d9e5 100644 --- a/src/vmm/src/devices/virtio/vsock/csm/connection.rs +++ b/src/vmm/src/devices/virtio/vsock/csm/connection.rs @@ -77,23 +77,26 @@ use std::fmt::Debug; // 2. The receiver can be proactive, and send VSOCK_OP_CREDIT_UPDATE packet, whenever // it thinks its peer's information is out of date. // Our implementation uses the proactive approach. -use std::io::{ErrorKind, Write}; +use std::io::{Cursor, Error, ErrorKind, Write}; use std::num::Wrapping; use std::os::unix::io::{AsRawFd, RawFd}; use std::time::{Duration, Instant}; -use vm_memory::GuestMemoryError; use vm_memory::io::{ReadVolatile, WriteVolatile}; +use vm_memory::{GuestMemoryError, VolatileMemory, VolatileSlice}; use vmm_sys_util::epoll::EventSet; use super::super::defs::uapi; use super::super::{VsockChannel, VsockEpollListener, VsockError}; use super::txbuf::TxBuf; use super::{ConnState, PendingRx, PendingRxSet, VsockCsmError, defs}; +use crate::devices::virtio::vsock::VsockUnixBackendError; use crate::devices::virtio::vsock::metrics::METRICS; use crate::devices::virtio::vsock::packet::{VsockPacketHeader, VsockPacketRx, VsockPacketTx}; +use crate::devices::virtio::vsock::unix::{IncomingLength, ReadResult}; use crate::logger::{IncMetric, debug, error, info, warn}; use crate::utils::wrap_usize_to_u32; +use crate::vmm_config::vsock::VsockType; /// Trait that vsock connection backends need to implement. /// @@ -102,6 +105,8 @@ use crate::utils::wrap_usize_to_u32; /// ). 
pub trait VsockConnectionBackend: ReadVolatile + Write + WriteVolatile + AsRawFd {} +const DEFAULT_CONN_BUFFER_SIZE: usize = (64 * 1024); + /// A self-managing connection object, that handles communication between a guest-side AF_VSOCK /// socket and a host-side `ReadVolatile + Write + WriteVolatile + AsRawFd` stream. #[derive(Debug)] @@ -138,6 +143,116 @@ pub struct VsockConnection { /// Instant when this connection should be scheduled for immediate termination, due to some /// timeout condition having been fulfilled. expiry: Option, + /// The type of the underlying socket connection + vsock_type: VsockType, + /// Intermediate buffer for bytes received from the AF_UNIX + connection_buffer: Option>>, + /// The amount of bytes we wrote into the intermediate connection buffer + conn_buf_size: usize, +} + +impl VsockConnection { + fn recv_into( + &mut self, + pkt: &mut VsockPacketRx, + max_len: u32, + ) -> Result { + match self.vsock_type { + VsockType::Stream => { + let stream_bytes_read = pkt.read_at_offset_from(&mut self.stream, 0, max_len)?; + Ok(ReadResult::new(stream_bytes_read, false)) + } + VsockType::Seqpacket => { + if self.connection_buffer.is_none() { + let incoming_msg_size = self.stream.incoming_len().map_err(|e| { + VsockError::VsockUdsBackend(VsockUnixBackendError::UnixRead(e)) + })?; + + if incoming_msg_size > pkt.buf_size() as usize { + self.handle_new_packet_large( + pkt, + max_len, + u32::try_from(incoming_msg_size).unwrap_or(u32::MAX), + ) + } else { + self.handle_new_packet_small(pkt, max_len) + } + } else { + self.handle_connection_buffer_has_data(pkt, max_len) + } + } + } + } + + fn handle_new_packet_large( + &mut self, + pkt: &mut VsockPacketRx, + max_len: u32, + incoming_msg_len: u32, + ) -> Result { + if incoming_msg_len as usize > self.conn_buf_size { + return Err(VsockError::MessageTooLong( + u32::try_from(self.conn_buf_size).unwrap_or(u32::MAX), + incoming_msg_len, + )); + } + + let mut backing_vector = vec![0u8; incoming_msg_len as 
usize]; + { + // SAFETY: `backing_vector` is a valid Vec and we hold a mutable reference to it, + // guaranteeing exclusive access for the duration of this call. + let mut vol_slice = unsafe { + VolatileSlice::new(backing_vector.as_mut_ptr(), incoming_msg_len as usize) + }; + self.stream + .read_volatile(&mut vol_slice) + .map_err(VsockError::VolatileMemory)?; + } + + let mut cursor = Cursor::new(backing_vector); + let b = pkt.read_at_offset_from(&mut cursor, 0, max_len)?; + self.connection_buffer = Some(cursor); + + Ok(ReadResult::new(b, true)) + } + + fn handle_new_packet_small( + &mut self, + pkt: &mut VsockPacketRx, + max_len: u32, + ) -> Result { + let b = pkt.read_at_offset_from(&mut self.stream, 0, max_len)?; + // packet is small enough to fit into a single descriptor, set EOM/EOR directly. + pkt.hdr.set_msg_eom().set_msg_eor(); + Ok(ReadResult::new(b, false)) + } + + fn handle_connection_buffer_has_data( + &mut self, + pkt: &mut VsockPacketRx, + max_len: u32, + ) -> Result { + let Some(ref mut conn_buf) = self.connection_buffer else { + return Err(VsockError::PktBufMissing); + }; + + let conn_buffer_rem = u32::try_from(conn_buf.get_ref().len() as u64 - conn_buf.position()) + .unwrap_or(u32::MAX); + + let b = pkt.read_at_offset_from(conn_buf, 0, max_len.min(conn_buffer_rem))?; + + // set MSG_EOM/EOR if we finished the buffer and mark should + // retrigger as false. or mark should retrigger to true to + // make another read happen + let done = conn_buf.position() >= conn_buf.get_ref().len() as u64; + if done { + self.connection_buffer = None; + pkt.hdr.set_msg_eom().set_msg_eor(); + Ok(ReadResult::new(b, false)) + } else { + Ok(ReadResult::new(b, true)) + } + } } impl VsockChannel for VsockConnection @@ -159,7 +274,7 @@ where /// - `Err(VsockError::NoData)`: there was no data available with which to fill in the packet; /// - `Err(VsockError::PktBufMissing)`: the packet would've been filled in with data, but it is /// missing the data buffer. 
- fn recv_pkt(&mut self, pkt: &mut VsockPacketRx) -> Result<(), VsockError> { + fn recv_pkt(&mut self, pkt: &mut VsockPacketRx) -> Result { // Perform some generic initialization that is the same for any packet operation (e.g. // source, destination, credit, etc). self.init_pkt_hdr(&mut pkt.hdr); @@ -169,7 +284,7 @@ where // It's dead, Jim. if self.pending_rx.remove(PendingRx::Rst) { pkt.hdr.set_op(uapi::VSOCK_OP_RST); - return Ok(()); + return Ok(ReadResult::default()); } // Next up: if we're due a connection confirmation, that's all we need to know to fill @@ -177,7 +292,7 @@ where if self.pending_rx.remove(PendingRx::Response) { self.state = ConnState::Established; pkt.hdr.set_op(uapi::VSOCK_OP_RESPONSE); - return Ok(()); + return Ok(ReadResult::default()); } // Same thing goes for locally-initiated connections that need to yield a connection @@ -186,7 +301,7 @@ where self.expiry = Some(Instant::now() + Duration::from_millis(defs::CONN_REQUEST_TIMEOUT_MS)); pkt.hdr.set_op(uapi::VSOCK_OP_REQUEST); - return Ok(()); + return Ok(ReadResult::default()); } if self.pending_rx.remove(PendingRx::Rw) { @@ -201,7 +316,7 @@ where // Any other connection state is invalid at this point, and we need to kill it // with fire. pkt.hdr.set_op(uapi::VSOCK_OP_RST); - return Ok(()); + return Ok(ReadResult::default()); } } @@ -210,17 +325,17 @@ where if self.need_credit_update_from_peer() { self.last_fwd_cnt_to_peer = self.fwd_cnt; pkt.hdr.set_op(uapi::VSOCK_OP_CREDIT_REQUEST); - return Ok(()); + return Ok(ReadResult::default()); } // The maximum amount of data we can read in is limited by both the RX buffer size and // the peer available buffer space. let max_len = std::cmp::min(pkt.buf_size(), self.peer_avail_credit()); - // Read data from the stream straight to the RX buffer, for maximum throughput. 
- match pkt.read_at_offset_from(&mut self.stream, 0, max_len) { - Ok(read_cnt) => { - if read_cnt == 0 { + let recv_res = self.recv_into(pkt, max_len); + match recv_res { + Ok(res) => { + if res.bytes_read == 0 { // A 0-length read means the host stream was closed down. In that case, // we'll ask our peer to shut down the connection. We can neither send nor // receive any more data. @@ -237,12 +352,19 @@ where // length of the read data. // Safe to unwrap because read_cnt is no more than max_len, which is bounded // by self.peer_avail_credit(), a u32 internally. - pkt.hdr.set_op(uapi::VSOCK_OP_RW).set_len(read_cnt); - METRICS.rx_bytes_count.add(read_cnt as u64); + pkt.hdr.set_op(uapi::VSOCK_OP_RW).set_len(res.bytes_read); + METRICS.rx_bytes_count.add(res.bytes_read as u64); } self.rx_cnt += Wrapping(pkt.hdr.len()); self.last_fwd_cnt_to_peer = self.fwd_cnt; - return Ok(()); + + // the read was buffered into an intermediate vector and this + // means there is still data to process but no fd event will + // kick off. manually push a PendingRx queue entry + if res.should_retrigger { + self.pending_rx.insert(PendingRx::Rw); + } + return Ok(res); } Err(VsockError::GuestMemoryMmap(GuestMemoryError::IOError(err))) if err.kind() == ErrorKind::WouldBlock => @@ -265,7 +387,7 @@ where ); pkt.hdr.set_op(uapi::VSOCK_OP_RST); self.last_fwd_cnt_to_peer = self.fwd_cnt; - return Ok(()); + return Ok(ReadResult::default()); } }; } @@ -276,7 +398,7 @@ where if self.pending_rx.remove(PendingRx::CreditUpdate) && !self.has_pending_rx() { pkt.hdr.set_op(uapi::VSOCK_OP_CREDIT_UPDATE); self.last_fwd_cnt_to_peer = self.fwd_cnt; - return Ok(()); + return Ok(ReadResult::default()); } // We've already checked for all conditions that would have produced a packet, so @@ -501,6 +623,7 @@ where S: VsockConnectionBackend + Debug, { /// Create a new guest-initiated connection object.
+ #[allow(clippy::too_many_arguments)] pub fn new_peer_init( stream: S, local_cid: u64, @@ -508,7 +631,13 @@ where local_port: u32, peer_port: u32, peer_buf_alloc: u32, + vsock_type: VsockType, + conn_buffer_size: Option, ) -> Self { + let buf_size = match vsock_type { + VsockType::Seqpacket => conn_buffer_size.unwrap_or(DEFAULT_CONN_BUFFER_SIZE), + VsockType::Stream => 0, + }; Self { local_cid, peer_cid, @@ -524,6 +653,9 @@ where last_fwd_cnt_to_peer: Wrapping(0), pending_rx: PendingRxSet::from(PendingRx::Response), expiry: None, + vsock_type, + connection_buffer: None, + conn_buf_size: buf_size, } } @@ -534,7 +666,13 @@ where peer_cid: u64, local_port: u32, peer_port: u32, + vsock_type: VsockType, + conn_buffer_size: Option, ) -> Self { + let buf_size = match vsock_type { + VsockType::Seqpacket => conn_buffer_size.unwrap_or(DEFAULT_CONN_BUFFER_SIZE), + VsockType::Stream => 0, + }; Self { local_cid, peer_cid, @@ -550,6 +688,9 @@ where last_fwd_cnt_to_peer: Wrapping(0), pending_rx: PendingRxSet::from(PendingRx::Request), expiry: None, + vsock_type, + connection_buffer: None, + conn_buf_size: buf_size, } } @@ -670,9 +811,12 @@ where .set_dst_cid(self.peer_cid) .set_src_port(self.local_port) .set_dst_port(self.peer_port) - .set_type(uapi::VSOCK_TYPE_STREAM) .set_buf_alloc(defs::CONN_TX_BUF_SIZE) .set_fwd_cnt(self.fwd_cnt.0); + match self.vsock_type { + VsockType::Seqpacket => hdr.set_type(uapi::VSOCK_TYPE_SEQPACKET), + VsockType::Stream => hdr.set_type(uapi::VSOCK_TYPE_STREAM), + }; } } @@ -881,9 +1025,17 @@ mod tests { LOCAL_PORT, PEER_PORT, PEER_BUF_ALLOC, + VsockType::Stream, + None, ), ConnState::LocalInit => VsockConnection::::new_local_init( - stream, LOCAL_CID, PEER_CID, LOCAL_PORT, PEER_PORT, + stream, + LOCAL_CID, + PEER_CID, + LOCAL_PORT, + PEER_PORT, + VsockType::Stream, + None, ), ConnState::Established => { let mut conn = VsockConnection::::new_peer_init( @@ -893,6 +1045,8 @@ mod tests { LOCAL_PORT, PEER_PORT, PEER_BUF_ALLOC, + VsockType::Stream, + None, 
); assert!(conn.has_pending_rx()); conn.recv_pkt(&mut rx_pkt).unwrap(); @@ -1281,6 +1435,278 @@ mod tests { } } + // A real AF_UNIX SOCK_SEQPACKET socket pair used for seqpacket tests. + // The local fd is the connection's read end; the remote fd is the test's write end. + #[derive(Debug)] + struct SeqpacketTestStream { + local_fd: RawFd, + remote_fd: RawFd, + } + + impl SeqpacketTestStream { + fn new() -> Self { + let mut fds = [0i32; 2]; + // SAFETY: valid AF_UNIX socketpair call; fds is a valid 2-element array. + let ret = unsafe { + libc::socketpair( + libc::AF_UNIX, + libc::SOCK_SEQPACKET | libc::SOCK_NONBLOCK, + 0, + fds.as_mut_ptr(), + ) + }; + assert_eq!(ret, 0, "socketpair failed: {}", IoError::last_os_error()); + Self { + local_fd: fds[0], + remote_fd: fds[1], + } + } + + // Write one seqpacket message into the remote end. + fn push_message(&self, data: &[u8]) { + // SAFETY: `remote_fd` is valid; `data` is a valid slice for the duration of the call. + let ret = unsafe { + libc::write( + self.remote_fd, + data.as_ptr().cast::(), + data.len(), + ) + }; + assert_eq!(ret.cast_unsigned(), data.len(), "push_message write failed"); + } + } + + impl Drop for SeqpacketTestStream { + fn drop(&mut self) { + // SAFETY: Both fds are valid and owned by this struct; closing them on drop. + unsafe { + libc::close(self.local_fd); + libc::close(self.remote_fd); + } + } + } + + impl AsRawFd for SeqpacketTestStream { + fn as_raw_fd(&self) -> RawFd { + self.local_fd + } + } + + impl ReadVolatile for SeqpacketTestStream { + fn read_volatile( + &mut self, + buf: &mut VolatileSlice, + ) -> Result { + let mut tmp = vec![0u8; buf.len()]; + // SAFETY: `local_fd` is valid; `tmp` is a valid writable buffer for the duration of + // the call. 
+ let ret = unsafe { + libc::recv( + self.local_fd, + tmp.as_mut_ptr().cast::(), + tmp.len(), + 0, + ) + }; + if ret < 0 { + return Err(VolatileMemoryError::IOError(IoError::last_os_error())); + } + let n = ret.cast_unsigned(); + buf.copy_from(&tmp[..n]); + Ok(n) + } + } + + impl Write for SeqpacketTestStream { + fn write(&mut self, data: &[u8]) -> Result { + // SAFETY: `local_fd` is valid; `data` is a valid readable slice for the duration of + // the call. + let ret = unsafe { + libc::write( + self.local_fd, + data.as_ptr().cast::(), + data.len(), + ) + }; + if ret < 0 { + Err(IoError::last_os_error()) + } else { + Ok(ret.cast_unsigned()) + } + } + + fn flush(&mut self) -> Result<(), IoError> { + Ok(()) + } + } + + impl WriteVolatile for SeqpacketTestStream { + fn write_volatile( + &mut self, + buf: &VolatileSlice, + ) -> Result { + let mut tmp = vec![0u8; buf.len()]; + buf.copy_to(&mut tmp); + // SAFETY: `local_fd` is valid; `tmp` is a valid readable buffer for the duration of + // the call. + let ret = unsafe { + libc::write( + self.local_fd, + tmp.as_ptr().cast::(), + tmp.len(), + ) + }; + if ret < 0 { + Err(VolatileMemoryError::IOError(IoError::last_os_error())) + } else { + Ok(ret.cast_unsigned()) + } + } + } + + impl VsockConnectionBackend for SeqpacketTestStream {} + + // EOM bit as defined in packet.rs + const VIRTIO_VSOCK_SEQ_EOM: u32 = 1 << 0; + + // Creates an established seqpacket connection backed by `stream`. + // `conn_buffer_size` sets the intermediate buffer used for large messages. + // Returns (connection, rx_pkt); the caller must keep _ctx alive for the duration. 
+ fn make_established_seqpacket( + stream: SeqpacketTestStream, + conn_buffer_size: Option, + ) -> ( + VsockConnection, + VsockPacketRx, + TestContext, + ) { + let vsock_test_ctx = TestContext::new(); + let mut handler_ctx = vsock_test_ctx.create_event_handler_context(); + let mut rx_pkt = VsockPacketRx::new().unwrap(); + rx_pkt + .parse( + &vsock_test_ctx.mem, + handler_ctx.device.queues[RXQ_INDEX].pop().unwrap().unwrap(), + ) + .unwrap(); + + let mut conn = VsockConnection::::new_peer_init( + stream, + LOCAL_CID, + PEER_CID, + LOCAL_PORT, + PEER_PORT, + PEER_BUF_ALLOC, + VsockType::Seqpacket, + conn_buffer_size, + ); + // Drain the initial RESPONSE to reach Established state. + assert!(conn.has_pending_rx()); + conn.recv_pkt(&mut rx_pkt).unwrap(); + assert_eq!(rx_pkt.hdr.op(), uapi::VSOCK_OP_RESPONSE); + assert_eq!(conn.state, ConnState::Established); + + (conn, rx_pkt, vsock_test_ctx) + } + + // Seqpacket: a small message (fits in one RX descriptor) is delivered in a single recv_pkt + // call with EOM set and should_retrigger=false. + #[test] + fn test_seqpacket_recv_small_message() { + let stream = SeqpacketTestStream::new(); + stream.push_message(b"hello"); + let (mut conn, mut rx_pkt, _ctx) = make_established_seqpacket(stream, None); + + conn.notify(EventSet::IN); + assert!(conn.has_pending_rx()); + + let res = conn.recv_pkt(&mut rx_pkt).unwrap(); + + assert_eq!(rx_pkt.hdr.op(), uapi::VSOCK_OP_RW); + assert_eq!(rx_pkt.hdr.len(), 5); + assert_eq!(res.bytes_read, 5); + assert!(!res.should_retrigger); + // EOM flag must be set: this is the end of the seqpacket message. + assert_ne!(rx_pkt.hdr.flags() & VIRTIO_VSOCK_SEQ_EOM, 0); + // No further pending RX after a complete small message. + assert!(!conn.has_pending_rx()); + } + + // Seqpacket: a message larger than the RX descriptor buffer (4096 bytes) is split across + // two recv_pkt calls. The first call sets should_retrigger=true and leaves EOM clear; + // the second call delivers the remainder with EOM set. 
+ #[test] + fn test_seqpacket_recv_large_message() { + const BUF_SIZE: usize = 4096; // matches the test descriptor size + const MSG_LEN: usize = BUF_SIZE + 512; + + let stream = SeqpacketTestStream::new(); + stream.push_message(&vec![0xABu8; MSG_LEN]); + let (mut conn, mut rx_pkt, _ctx) = make_established_seqpacket(stream, None); + + conn.notify(EventSet::IN); + assert!(conn.has_pending_rx()); + + // First call: fills the descriptor (4096 bytes), does not set EOM. + let res1 = conn.recv_pkt(&mut rx_pkt).unwrap(); + assert_eq!(rx_pkt.hdr.op(), uapi::VSOCK_OP_RW); + assert_eq!(res1.bytes_read, u32::try_from(BUF_SIZE).unwrap()); + assert!(res1.should_retrigger); + assert_eq!(rx_pkt.hdr.flags() & VIRTIO_VSOCK_SEQ_EOM, 0); + // Connection must still have pending RX for the remainder. + assert!(conn.has_pending_rx()); + + // Second call: delivers the remaining 512 bytes with EOM set. + let res2 = conn.recv_pkt(&mut rx_pkt).unwrap(); + assert_eq!(rx_pkt.hdr.op(), uapi::VSOCK_OP_RW); + assert_eq!(res2.bytes_read, 512); + assert!(!res2.should_retrigger); + assert_ne!(rx_pkt.hdr.flags() & VIRTIO_VSOCK_SEQ_EOM, 0); + assert!(!conn.has_pending_rx()); + } + + // Seqpacket: a message that exactly fills the RX descriptor is handled in one call, + // as a "small" packet (not buffered), with EOM set. 
+ #[test] + fn test_seqpacket_recv_exact_buf_size_message() { + const BUF_SIZE: usize = 4096; + + let stream = SeqpacketTestStream::new(); + stream.push_message(&vec![0x42u8; BUF_SIZE]); + let (mut conn, mut rx_pkt, _ctx) = make_established_seqpacket(stream, None); + + conn.notify(EventSet::IN); + + let res = conn.recv_pkt(&mut rx_pkt).unwrap(); + + assert_eq!(rx_pkt.hdr.op(), uapi::VSOCK_OP_RW); + assert_eq!(res.bytes_read, u32::try_from(BUF_SIZE).unwrap()); + assert!(!res.should_retrigger); + assert_ne!(rx_pkt.hdr.flags() & VIRTIO_VSOCK_SEQ_EOM, 0); + assert!(!conn.has_pending_rx()); + } + + // Seqpacket: a message too large to fit in the intermediate connection buffer returns + // a MessageTooLong error and kills the connection (RST). + #[test] + fn test_seqpacket_recv_message_too_long() { + // Use a tiny intermediate buffer so a message slightly larger than the RX descriptor + // (4097 bytes > 4096 buf_size) exceeds it. + const SMALL_BUF: usize = 128; + const MSG_LEN: usize = 4097; // > buf_size (4096) so large-packet path is taken + + let stream = SeqpacketTestStream::new(); + stream.push_message(&vec![0u8; MSG_LEN]); + let (mut conn, mut rx_pkt, _ctx) = make_established_seqpacket(stream, Some(SMALL_BUF)); + + conn.notify(EventSet::IN); + + // recv_pkt should not propagate the error; instead it emits an RST packet. 
+ let res = conn.recv_pkt(&mut rx_pkt).unwrap(); + assert_eq!(rx_pkt.hdr.op(), uapi::VSOCK_OP_RST); + assert_eq!(res.bytes_read, 0); + } + #[test] fn test_peer_credit_misbehavior() { let mut ctx = CsmTestContext::new_established(); diff --git a/src/vmm/src/devices/virtio/vsock/device.rs b/src/vmm/src/devices/virtio/vsock/device.rs index 4f3ab743f2e..66f545df50d 100644 --- a/src/vmm/src/devices/virtio/vsock/device.rs +++ b/src/vmm/src/devices/virtio/vsock/device.rs @@ -32,7 +32,9 @@ use super::packet::{VSOCK_PKT_HDR_SIZE, VsockPacketRx, VsockPacketTx}; use super::{VsockBackend, defs}; use crate::devices::virtio::ActivateError; use crate::devices::virtio::device::{ActiveState, DeviceState, VirtioDevice, VirtioDeviceType}; -use crate::devices::virtio::generated::virtio_config::{VIRTIO_F_IN_ORDER, VIRTIO_F_VERSION_1}; +use crate::devices::virtio::generated::virtio_config::{ + VIRTIO_F_IN_ORDER, VIRTIO_F_VERSION_1, VIRTIO_VSOCK_F_SEQPACKET, +}; use crate::devices::virtio::queue::{InvalidAvailIdx, Queue as VirtQueue}; use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType}; use crate::devices::virtio::vsock::VsockError; @@ -52,8 +54,11 @@ pub(crate) const VIRTIO_VSOCK_EVENT_TRANSPORT_RESET: u32 = 0; /// - VIRTIO_F_VERSION_1: the device conforms to at least version 1.0 of the VirtIO spec. /// - VIRTIO_F_IN_ORDER: the device returns used buffers in the same order that the driver makes /// them available. -pub(crate) const AVAIL_FEATURES: u64 = - (1 << VIRTIO_F_VERSION_1 as u64) | (1 << VIRTIO_F_IN_ORDER as u64); +/// - VIRTIO_VSOCK_F_SEQPACKET: the device supports vsock connections backed by seqpacket +/// sockets. +pub(crate) const AVAIL_FEATURES: u64 = (1 << VIRTIO_F_VERSION_1 as u64) + | (1 << VIRTIO_F_IN_ORDER as u64) + | (1 << VIRTIO_VSOCK_F_SEQPACKET as u64); /// Structure representing the vsock device. 
#[derive(Debug)] @@ -162,47 +167,45 @@ where let queue = &mut self.queues[RXQ_INDEX]; let mut have_used = false; + let mut should_retrigger = false; while let Some(head) = queue.pop()? { let index = head.index; let used_len = match self.rx_packet.parse(mem, head) { - Ok(()) => { - if self.backend.recv_pkt(&mut self.rx_packet).is_ok() { - match self.rx_packet.commit_hdr() { - // This addition cannot overflow, because packet length - // is previously validated against `MAX_PKT_BUF_SIZE` - // bound as part of `commit_hdr()`. - Ok(()) => VSOCK_PKT_HDR_SIZE + self.rx_packet.hdr.len(), - Err(err) => { - warn!( - "vsock: Error writing packet header to guest memory: \ - {:?}.Discarding the package.", - err - ); - 0 - } - } - } else { - // We are using a consuming iterator over the virtio buffers, so, if we - // can't fill in this buffer, we'll need to undo the - // last iterator step. + Ok(()) => match self.backend.recv_pkt(&mut self.rx_packet) { + Err(_) => { queue.undo_pop(); break; } - } + Ok(read_res) => { + should_retrigger = read_res.should_retrigger; + self.rx_packet + .commit_hdr() + .map(|_| VSOCK_PKT_HDR_SIZE + read_res.bytes_read) + .unwrap_or_else(|err| { + warn!("vsock: Error writing packet header: {:?}. Discarding.", err); + 0 + }) + } + }, Err(err) => { warn!("vsock: RX queue error: {:?}. 
Discarding the package.", err); 0 } }; - have_used = true; queue.add_used(index, used_len).unwrap_or_else(|err| { error!("Failed to add available descriptor {}: {}", index, err) }); + + // we received more than the rx packet size + // trigger another loop iteration to consume from the temp buffer + if should_retrigger { + continue; + } } - queue.advance_used_ring_idx(); + queue.advance_used_ring_idx(); Ok(have_used) } diff --git a/src/vmm/src/devices/virtio/vsock/mod.rs b/src/vmm/src/devices/virtio/vsock/mod.rs index cc9f7746580..779bb6793d0 100644 --- a/src/vmm/src/devices/virtio/vsock/mod.rs +++ b/src/vmm/src/devices/virtio/vsock/mod.rs @@ -22,7 +22,7 @@ mod unix; use std::os::unix::io::AsRawFd; -use vm_memory::GuestMemoryError; +use vm_memory::{GuestMemoryError, VolatileMemoryError}; use vmm_sys_util::epoll::EventSet; pub use self::defs::VSOCK_DEV_ID; @@ -32,6 +32,8 @@ pub use self::unix::{VsockUnixBackend, VsockUnixBackendError}; use super::iov_deque::IovDequeError; use crate::devices::virtio::iovec::IoVecError; use crate::devices::virtio::persist::PersistError as VirtioStateError; +use crate::devices::virtio::vsock::unix::ReadResult; +use crate::vmm_config::vsock::VsockType; mod defs { use crate::devices::virtio::queue::FIRECRACKER_MAX_QUEUE_SIZE; @@ -84,8 +86,10 @@ mod defs { /// Vsock packet type. /// Defined in `/include/uapi/linux/virtio_vsock.h`. /// - /// Stream / connection-oriented packet (the only currently valid type). + /// Stream / connection-oriented packet. pub const VSOCK_TYPE_STREAM: u16 = 1; + /// Seqpacket based connection + pub const VSOCK_TYPE_SEQPACKET: u16 = 2; pub const VSOCK_HOST_CID: u64 = 2; } @@ -128,6 +132,10 @@ pub enum VsockError { IovDeque(IovDequeError), /// Tried to push to full IovDeque. IovDequeOverflow, + /// Message too big for the intermediate connection buffer. buffer length {0}, incoming size {1} + MessageTooLong(u32, u32), + /// Encountered an error while processing a volatile memory read/write. 
+ VolatileMemory(VolatileMemoryError) } impl From for VsockError { @@ -166,7 +174,7 @@ pub trait VsockEpollListener: AsRawFd { /// - `send_pkt(&pkt)` will fetch data from `pkt`, and place it into the channel. pub trait VsockChannel { /// Read/receive an incoming packet from the channel. - fn recv_pkt(&mut self, pkt: &mut VsockPacketRx) -> Result<(), VsockError>; + fn recv_pkt(&mut self, pkt: &mut VsockPacketRx) -> Result; /// Write/send a packet through the channel. fn send_pkt(&mut self, pkt: &VsockPacketTx) -> Result<(), VsockError>; diff --git a/src/vmm/src/devices/virtio/vsock/packet.rs b/src/vmm/src/devices/virtio/vsock/packet.rs index 7253f41ce76..24e74138c88 100644 --- a/src/vmm/src/devices/virtio/vsock/packet.rs +++ b/src/vmm/src/devices/virtio/vsock/packet.rs @@ -78,6 +78,9 @@ pub struct VsockPacketHeader { fwd_cnt: u32, } +const VIRTIO_VSOCK_SEQ_EOM: u32 = 1 << 0; +const VIRTIO_VSOCK_SEQ_EOR: u32 = 1 << 1; + impl VsockPacketHeader { pub fn src_cid(&self) -> u64 { u64::from_le(self.src_cid) @@ -133,6 +136,16 @@ impl VsockPacketHeader { self } + pub fn set_msg_eom(&mut self) -> &mut Self { + self.flags |= VIRTIO_VSOCK_SEQ_EOM; + self + } + + pub fn set_msg_eor(&mut self) -> &mut Self { + self.flags |= VIRTIO_VSOCK_SEQ_EOR; + self + } + pub fn op(&self) -> u16 { u16::from_le(self.op) } diff --git a/src/vmm/src/devices/virtio/vsock/persist.rs b/src/vmm/src/devices/virtio/vsock/persist.rs index 42909b58ddb..176408beac8 100644 --- a/src/vmm/src/devices/virtio/vsock/persist.rs +++ b/src/vmm/src/devices/virtio/vsock/persist.rs @@ -14,6 +14,7 @@ use crate::devices::virtio::persist::VirtioDeviceState; use crate::devices::virtio::queue::FIRECRACKER_MAX_QUEUE_SIZE; use crate::devices::virtio::transport::VirtioInterrupt; use crate::snapshot::Persist; +use crate::vmm_config::vsock::VsockType; use crate::vstate::memory::GuestMemoryMmap; /// The Vsock serializable state. 
@@ -40,6 +41,8 @@ pub struct VsockBackendState { pub uds_path: String, /// The last used host-side port. pub local_port_last: u32, + pub vsock_type: VsockType, + pub conn_buffer_size: Option, } /// A helper structure that holds the constructor arguments for VsockUnixBackend @@ -67,6 +70,8 @@ impl Persist<'_> for VsockUnixBackend { VsockBackendState { uds_path: self.host_sock_path.clone(), local_port_last: self.local_port_last, + vsock_type: self.vsock_type.clone(), + conn_buffer_size: self.conn_buffer_size, } } @@ -74,7 +79,12 @@ impl Persist<'_> for VsockUnixBackend { constructor_args: Self::ConstructorArgs, state: &Self::State, ) -> Result { - let mut backend = Self::new(constructor_args.cid, state.uds_path.clone())?; + let mut backend = Self::new( + constructor_args.cid, + state.uds_path.clone(), + state.vsock_type.clone(), + state.conn_buffer_size, + )?; backend.local_port_last = state.local_port_last; Ok(backend) } @@ -137,6 +147,8 @@ pub(crate) mod tests { VsockBackendState { uds_path: "test".to_owned(), local_port_last: 0xdeadbeef, + vsock_type: VsockType::Stream, + conn_buffer_size: None, } } diff --git a/src/vmm/src/devices/virtio/vsock/test_utils.rs b/src/vmm/src/devices/virtio/vsock/test_utils.rs index 3d4ab704975..04ead187dc5 100644 --- a/src/vmm/src/devices/virtio/vsock/test_utils.rs +++ b/src/vmm/src/devices/virtio/vsock/test_utils.rs @@ -17,6 +17,7 @@ use crate::devices::virtio::test_utils::{VirtQueue as GuestQ, default_interrupt} use crate::devices::virtio::transport::VirtioInterrupt; use crate::devices::virtio::vsock::device::{RXQ_INDEX, TXQ_INDEX}; use crate::devices::virtio::vsock::packet::VSOCK_PKT_HDR_SIZE; +use crate::devices::virtio::vsock::unix::ReadResult; use crate::devices::virtio::vsock::{ Vsock, VsockBackend, VsockChannel, VsockEpollListener, VsockError, }; @@ -65,7 +66,7 @@ impl Default for TestBackend { } impl VsockChannel for TestBackend { - fn recv_pkt(&mut self, pkt: &mut VsockPacketRx) -> Result<(), VsockError> { + fn 
recv_pkt(&mut self, pkt: &mut VsockPacketRx) -> Result { let cool_buf = [0xDu8, 0xE, 0xA, 0xD, 0xB, 0xE, 0xE, 0xF]; match self.rx_err.take() { None => { @@ -78,7 +79,7 @@ impl VsockChannel for TestBackend { .unwrap(); } self.rx_ok_cnt += 1; - Ok(()) + Ok(ReadResult::new(buf_size, false)) } Some(err) => Err(err), } diff --git a/src/vmm/src/devices/virtio/vsock/unix/mod.rs b/src/vmm/src/devices/virtio/vsock/unix/mod.rs index 25fef274fc6..8bc057f5b27 100644 --- a/src/vmm/src/devices/virtio/vsock/unix/mod.rs +++ b/src/vmm/src/devices/virtio/vsock/unix/mod.rs @@ -10,10 +10,16 @@ mod muxer; mod muxer_killq; mod muxer_rxq; +mod seqpacket; +use std::io::{self, Read, Write}; +use std::os::fd::AsRawFd; +use std::os::unix::net::UnixStream; pub use muxer::VsockMuxer as VsockUnixBackend; +use vm_memory::io::{ReadVolatile, WriteVolatile}; use crate::devices::virtio::vsock::csm::VsockConnectionBackend; +use crate::devices::virtio::vsock::unix::seqpacket::SeqpacketConn; mod defs { /// Maximum number of established connections that we can handle. @@ -47,6 +53,101 @@ pub enum VsockUnixBackendError { TooManyConnections, } -type MuxerConnection = super::csm::VsockConnection; +#[derive(Debug)] +pub enum ConnBackend { + Stream(UnixStream), + Seqpacket(SeqpacketConn), +} + +macro_rules! 
forward_to_inner { + ($self:ident, $method:ident $(, $args:expr )* ) => { + match $self { + ConnBackend::Stream(inner) => inner.$method($($args),*), + ConnBackend::Seqpacket(inner) => inner.$method($($args),*), + } + }; +} + +impl Read for ConnBackend { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + forward_to_inner!(self, read, buf) + } +} + +impl AsRawFd for ConnBackend { + fn as_raw_fd(&self) -> i32 { + forward_to_inner!(self, as_raw_fd) + } +} + +impl ReadVolatile for ConnBackend { + fn read_volatile( + &mut self, + buf: &mut vm_memory::VolatileSlice, + ) -> Result { + forward_to_inner!(self, read_volatile, buf) + } +} + +impl WriteVolatile for ConnBackend { + fn write_volatile( + &mut self, + buf: &vm_memory::VolatileSlice, + ) -> Result { + forward_to_inner!(self, write_volatile, buf) + } +} + +impl Write for ConnBackend { + fn write(&mut self, buf: &[u8]) -> io::Result { + forward_to_inner!(self, write, buf) + } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +pub trait IncomingLength { + fn incoming_len(&mut self) -> Result; +} + +impl IncomingLength for B { + fn incoming_len(&mut self) -> Result { + let fd = self.as_raw_fd(); + // the maximum message size 256 bytes anyways + let mut peek_buf = [0u8; 1]; + // SAFETY: `fd` is a valid file descriptor for the duration of this call, and `peek_buf` + // is a valid single-byte buffer. MSG_PEEK | MSG_TRUNC returns the message size without + // consuming it. 
+ let msg_size = unsafe { + libc::recv( + fd, + peek_buf.as_mut_ptr().cast(), + 1, + libc::MSG_PEEK | libc::MSG_TRUNC, + ) + }; + if msg_size < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(msg_size.cast_unsigned()) + } + } +} + +#[derive(Default, Debug)] +pub struct ReadResult { + pub bytes_read: u32, + pub should_retrigger: bool, +} + +impl ReadResult { + pub fn new(bytes_read: u32, should_retrigger: bool) -> Self { + ReadResult { + bytes_read, + should_retrigger, + } + } +} -impl VsockConnectionBackend for std::os::unix::net::UnixStream {} +impl VsockConnectionBackend for ConnBackend {} diff --git a/src/vmm/src/devices/virtio/vsock/unix/muxer.rs b/src/vmm/src/devices/virtio/vsock/unix/muxer.rs index 4e0c945112a..9725e0c2f83 100644 --- a/src/vmm/src/devices/virtio/vsock/unix/muxer.rs +++ b/src/vmm/src/devices/virtio/vsock/unix/muxer.rs @@ -33,7 +33,8 @@ use std::collections::{HashMap, HashSet}; use std::fmt::Debug; use std::io::Read; -use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::fd::FromRawFd; +use std::os::unix::io::{AsRawFd, IntoRawFd, OwnedFd, RawFd}; use std::os::unix::net::{UnixListener, UnixStream}; use vmm_sys_util::epoll::{ControlOperation, Epoll, EpollEvent, EventSet}; @@ -43,10 +44,15 @@ use super::super::defs::uapi; use super::super::{VsockBackend, VsockChannel, VsockEpollListener, VsockError}; use super::muxer_killq::MuxerKillQ; use super::muxer_rxq::MuxerRxQ; -use super::{MuxerConnection, VsockUnixBackendError, defs}; +use super::{VsockUnixBackendError, defs}; +use crate::devices::virtio::vsock::csm::VsockConnection; +use crate::devices::virtio::vsock::defs::uapi::{VSOCK_TYPE_SEQPACKET, VSOCK_TYPE_STREAM}; use crate::devices::virtio::vsock::metrics::METRICS; use crate::devices::virtio::vsock::packet::{VsockPacketRx, VsockPacketTx}; +use crate::devices::virtio::vsock::unix::seqpacket::{SeqpacketConn, SeqpacketListener, Socket}; +use crate::devices::virtio::vsock::unix::{ConnBackend, ReadResult}; use crate::logger::{IncMetric, debug, 
error, info, warn}; +use crate::vmm_config::vsock::VsockType; /// A unique identifier of a `MuxerConnection` object. Connections are stored in a hash map, /// keyed by a `ConnMapKey` object. @@ -76,7 +82,7 @@ enum EpollListener { HostSock, /// A listener interested in reading host `connect ` commands from a freshly /// connected host socket. - LocalStream(UnixStream), + LocalStream(ConnBackend), } /// The vsock connection multiplexer. @@ -85,7 +91,9 @@ pub struct VsockMuxer { /// Guest CID. cid: u64, /// A hash map used to store the active connections. - conn_map: HashMap, + conn_map: HashMap>, + /// the underlying host socket file descriptor type wrapper + host_sock: Box, /// A hash map used to store epoll event listeners / handlers. listener_map: HashMap, /// The RX queue. Items in this queue are consumed by `VsockMuxer::recv_pkt()`, and @@ -95,8 +103,6 @@ pub struct VsockMuxer { rxq: MuxerRxQ, /// A queue used for terminating connections that are taking too long to shut down. killq: MuxerKillQ, - /// The Unix socket, through which host-initiated connections are accepted. - host_sock: UnixListener, /// The file system path of the host-side Unix socket. This is used to figure out the path /// to Unix sockets listening on specific ports. I.e. `"_"`. pub(crate) host_sock_path: String, @@ -106,7 +112,6 @@ pub struct VsockMuxer { /// ports to host-initiated connections. local_port_set: HashSet, /// The last used host-side port. - /// /// Local ports are allocated in a round-robin fashion within the range [1 << 30, 1 << 31). /// There should be no inherent technical requirement for this specific range. But the range /// provides 1 billion available ports, making port collisions unlikely. In addition, the @@ -114,6 +119,10 @@ pub struct VsockMuxer { /// This appears to have been a design decision dating back to the initial introduction of the /// vsock implementation. 
pub(crate) local_port_last: u32, + /// The type of the socket (stream or seqpacket) + pub(crate) vsock_type: VsockType, + /// Length of the intermediate connection buffer + pub(crate) conn_buffer_size: Option, } impl VsockChannel for VsockMuxer { @@ -122,7 +131,7 @@ impl VsockChannel for VsockMuxer { /// Retuns: /// - `Ok(())`: `pkt` has been successfully filled in; or /// - `Err(VsockError::NoData)`: there was no available data with which to fill in the packet. - fn recv_pkt(&mut self, pkt: &mut VsockPacketRx) -> Result<(), VsockError> { + fn recv_pkt(&mut self, pkt: &mut VsockPacketRx) -> Result { // We'll look for instructions on how to build the RX packet in the RX queue. If the // queue is empty, that doesn't necessarily mean we don't have any pending RX, since // the queue might be out-of-sync. If that's the case, we'll attempt to sync it first, @@ -145,12 +154,15 @@ impl VsockChannel for VsockMuxer { .set_src_port(local_port) .set_dst_port(peer_port) .set_len(0) - .set_type(uapi::VSOCK_TYPE_STREAM) .set_flags(0) .set_buf_alloc(0) .set_fwd_cnt(0); + match self.vsock_type { + VsockType::Seqpacket => pkt.hdr.set_type(VSOCK_TYPE_SEQPACKET), + VsockType::Stream => pkt.hdr.set_type(VSOCK_TYPE_STREAM), + }; self.rxq.pop().unwrap(); - return Ok(()); + return Ok(ReadResult::default()); } // We'll defer building the packet to this connection, since it has something @@ -181,7 +193,18 @@ impl VsockChannel for VsockMuxer { } debug!("vsock muxer: RX pkt: {:?}", pkt.hdr); - return Ok(()); + match res { + Ok(read_res) => { + // the read was buffered into an intermediate vector and this + // means there is still data to process but no fd event will + // kick off. manually push a a PendingRx queue entry + if read_res.should_retrigger { + self.rxq.push(rx); + } + return Ok(read_res); + } + Err(e) => return Err(e), + } } } @@ -208,9 +231,11 @@ impl VsockChannel for VsockMuxer { pkt.hdr ); - // If this packet has an unsupported type (!=stream), we must send back an RST. 
+ // If this packet has an unsupported type (!=stream or seqpacket), we must send back an RST. // - if pkt.hdr.type_() != uapi::VSOCK_TYPE_STREAM { + if pkt.hdr.type_() != uapi::VSOCK_TYPE_STREAM + && pkt.hdr.type_() != uapi::VSOCK_TYPE_SEQPACKET + { self.enq_rst(pkt.hdr.dst_port(), pkt.hdr.src_port()); return Ok(()); } @@ -309,12 +334,28 @@ impl VsockBackend for VsockMuxer {} impl VsockMuxer { /// Muxer constructor. - pub fn new(cid: u64, host_sock_path: String) -> Result { + pub fn new( + cid: u64, + host_sock_path: String, + vsock_type: VsockType, + conn_buffer_size: Option, + ) -> Result { // Open/bind on the host Unix socket, so we can accept host-initiated // connections. - let host_sock = UnixListener::bind(&host_sock_path) - .and_then(|sock| sock.set_nonblocking(true).map(|_| sock)) - .map_err(VsockUnixBackendError::UnixBind)?; + let host_sock: Box = match vsock_type { + VsockType::Seqpacket => { + // we don't need to set non blocking here because by default we pass the flags for a non blocking socket + let sock = SeqpacketListener::bind(&host_sock_path) + .map_err(VsockUnixBackendError::UnixBind)?; + Box::new(sock) + } + VsockType::Stream => { + let sock = UnixListener::bind(&host_sock_path) + .and_then(|sock| sock.set_nonblocking(true).map(|_| sock)) + .map_err(VsockUnixBackendError::UnixBind)?; + Box::new(sock) + } + }; let mut muxer = Self { cid, @@ -327,6 +368,8 @@ impl VsockMuxer { killq: MuxerKillQ::new(), local_port_last: (1u32 << 30) - 1, local_port_set: HashSet::with_capacity(defs::MAX_CONNECTIONS), + vsock_type, + conn_buffer_size, }; // Listen on the host initiated socket, for incoming connections. @@ -339,6 +382,11 @@ impl VsockMuxer { &self.host_sock_path } + /// Return the type of the underlying socket. (stream or seqpacket) + pub fn vsock_type(&self) -> &VsockType { + &self.vsock_type + } + /// Handle/dispatch an epoll event to its listener. 
fn handle_event(&mut self, fd: RawFd, event_set: EventSet) { debug!( @@ -372,12 +420,6 @@ impl VsockMuxer { self.host_sock .accept() .map_err(VsockUnixBackendError::UnixAccept) - .and_then(|(stream, _)| { - stream - .set_nonblocking(true) - .map(|_| stream) - .map_err(VsockUnixBackendError::UnixAccept) - }) .and_then(|stream| { // Before forwarding this connection to a listening AF_VSOCK socket on // the guest side, we need to know the destination port. We'll read @@ -394,7 +436,9 @@ impl VsockMuxer { // "connect" command that we're expecting. Some(EpollListener::LocalStream(_)) => { if let Some(EpollListener::LocalStream(mut stream)) = self.remove_listener(fd) { - Self::read_local_stream_port(&mut stream) + // SAFETY: Safe because the fd is valid and we own it (removed from + // listener_map). + Self::read_local_stream_port(&mut stream, &self.vsock_type) .map(|peer_port| (self.allocate_local_port(), peer_port)) .and_then(|(local_port, peer_port)| { self.add_connection( @@ -402,18 +446,20 @@ impl VsockMuxer { local_port, peer_port, }, - MuxerConnection::new_local_init( + VsockConnection::new_local_init( stream, uapi::VSOCK_HOST_CID, self.cid, local_port, peer_port, + self.vsock_type.clone(), + self.conn_buffer_size, ), ) }) .unwrap_or_else(|err| { info!("vsock: error adding local-init connection: {:?}", err); - }) + }); } } @@ -428,28 +474,43 @@ impl VsockMuxer { } /// Parse a host "connect" command, and extract the destination vsock port. - fn read_local_stream_port(stream: &mut UnixStream) -> Result { + fn read_local_stream_port( + stream: &mut dyn Read, + vsock_type: &VsockType, + ) -> Result { let mut buf = [0u8; 32]; - // This is the minimum number of bytes that we should be able to read, when parsing a - // valid connection request. I.e. `b"connect 0\n".len()`. - const MIN_READ_LEN: usize = 10; - - // Bring in the minimum number of bytes that we should be able to read. 
- stream - .read_exact(&mut buf[..MIN_READ_LEN]) - .map_err(VsockUnixBackendError::UnixRead)?; - - // Now, finish reading the destination port number, by bringing in one byte at a time, - // until we reach an EOL terminator (or our buffer space runs out). Yeah, not - // particularly proud of this approach, but it will have to do for now. - let mut blen = MIN_READ_LEN; - while buf[blen - 1] != b'\n' && blen < buf.len() { - stream - .read_exact(&mut buf[blen..=blen]) - .map_err(VsockUnixBackendError::UnixRead)?; - blen += 1; - } + let blen = match vsock_type { + VsockType::Seqpacket => { + // Seqpacket delivers the entire message atomically, so a single read gets it + // all. Using read_exact would silently truncate messages longer than + // MIN_READ_LEN bytes (e.g. "connect 525\n"), discarding the remainder. + stream + .read(&mut buf) + .map_err(VsockUnixBackendError::UnixRead)? + } + VsockType::Stream => { + // This is the minimum number of bytes that we should be able to read, when + // parsing a valid connection request. I.e. `b"connect 0\n".len()`. + const MIN_READ_LEN: usize = 10; + + // Bring in the minimum number of bytes that we should be able to read. + stream + .read_exact(&mut buf[..MIN_READ_LEN]) + .map_err(VsockUnixBackendError::UnixRead)?; + + // Now, finish reading the destination port number, by bringing in one byte at + // a time, until we reach an EOL terminator (or our buffer space runs out). + let mut blen = MIN_READ_LEN; + while buf[blen - 1] != b'\n' && blen < buf.len() { + stream + .read_exact(&mut buf[blen..=blen]) + .map_err(VsockUnixBackendError::UnixRead)?; + blen += 1; + } + blen + } + }; let mut word_iter = std::str::from_utf8(&buf[..blen]) .map_err(|_| VsockUnixBackendError::InvalidPortRequest)? 
@@ -481,7 +542,7 @@ impl VsockMuxer { fn add_connection( &mut self, key: ConnMapKey, - conn: MuxerConnection, + conn: VsockConnection, ) -> Result<(), VsockUnixBackendError> { // We might need to make room for this new connection, so let's sweep the kill queue // first. It's fine to do this here because: @@ -618,27 +679,55 @@ impl VsockMuxer { /// RST packet will be scheduled for delivery to the guest. fn handle_peer_request_pkt(&mut self, pkt: &VsockPacketTx) { let port_path = format!("{}_{}", self.host_sock_path, pkt.hdr.dst_port()); - - UnixStream::connect(port_path) - .and_then(|stream| stream.set_nonblocking(true).map(|_| stream)) - .map_err(VsockUnixBackendError::UnixConnect) - .and_then(|stream| { - self.add_connection( - ConnMapKey { - local_port: pkt.hdr.dst_port(), - peer_port: pkt.hdr.src_port(), - }, - MuxerConnection::new_peer_init( - stream, - uapi::VSOCK_HOST_CID, - self.cid, - pkt.hdr.dst_port(), - pkt.hdr.src_port(), - pkt.hdr.buf_alloc(), - ), - ) - }) - .unwrap_or_else(|_| self.enq_rst(pkt.hdr.dst_port(), pkt.hdr.src_port())); + match self.vsock_type { + VsockType::Stream => { + UnixStream::connect(port_path) + .and_then(|stream| stream.set_nonblocking(true).map(|_| stream)) + .map_err(VsockUnixBackendError::UnixConnect) + .and_then(|stream| { + self.add_connection( + ConnMapKey { + local_port: pkt.hdr.dst_port(), + peer_port: pkt.hdr.src_port(), + }, + VsockConnection::::new_peer_init( + ConnBackend::Stream(stream), + uapi::VSOCK_HOST_CID, + self.cid, + pkt.hdr.dst_port(), + pkt.hdr.src_port(), + pkt.hdr.buf_alloc(), + VsockType::Stream, + None, + ), + ) + }) + .unwrap_or_else(|_| self.enq_rst(pkt.hdr.dst_port(), pkt.hdr.src_port())); + } + VsockType::Seqpacket => { + SeqpacketConn::connect(&port_path) + .map_err(VsockUnixBackendError::UnixConnect) + .and_then(|stream| { + self.add_connection( + ConnMapKey { + local_port: pkt.hdr.dst_port(), + peer_port: pkt.hdr.src_port(), + }, + VsockConnection::::new_peer_init( + 
ConnBackend::Seqpacket(stream), + uapi::VSOCK_HOST_CID, + self.cid, + pkt.hdr.dst_port(), + pkt.hdr.src_port(), + pkt.hdr.buf_alloc(), + VsockType::Seqpacket, + self.conn_buffer_size, + ), + ) + }) + .unwrap_or_else(|_| self.enq_rst(pkt.hdr.dst_port(), pkt.hdr.src_port())); + } + } } /// Perform an action that might mutate a connection's state. @@ -650,7 +739,7 @@ impl VsockMuxer { /// - kill the connection if an unrecoverable error occurs. fn apply_conn_mutation(&mut self, key: ConnMapKey, mut_fn: F) where - F: FnOnce(&mut MuxerConnection), + F: FnOnce(&mut VsockConnection), { if let Some(conn) = self.conn_map.get_mut(&key) { let had_rx = conn.has_pending_rx(); @@ -855,7 +944,7 @@ mod tests { ) .unwrap(); - let muxer = VsockMuxer::new(PEER_CID, get_file(name)).unwrap(); + let muxer = VsockMuxer::new(PEER_CID, get_file(name), VsockType::Stream, None).unwrap(); Self { _vsock_test_ctx: vsock_test_ctx, rx_pkt, diff --git a/src/vmm/src/devices/virtio/vsock/unix/muxer_killq.rs b/src/vmm/src/devices/virtio/vsock/unix/muxer_killq.rs index 17cc193d120..a607dcf7030 100644 --- a/src/vmm/src/devices/virtio/vsock/unix/muxer_killq.rs +++ b/src/vmm/src/devices/virtio/vsock/unix/muxer_killq.rs @@ -27,8 +27,10 @@ use std::collections::{HashMap, VecDeque}; use std::time::Instant; +use super::defs; use super::muxer::ConnMapKey; -use super::{MuxerConnection, defs}; +use crate::devices::virtio::vsock::csm::VsockConnection; +use crate::devices::virtio::vsock::unix::ConnBackend; /// A kill queue item, holding the connection key and the scheduled time for termination. #[derive(Debug, Clone, Copy)] @@ -66,7 +68,7 @@ impl MuxerKillQ { /// set to expire at some point in the future. /// Note: if more than `Self::SIZE` connections are found, the queue will be created in an /// out-of-sync state, and will be discarded after it is emptied. 
- pub fn from_conn_map(conn_map: &HashMap) -> Self { + pub fn from_conn_map(conn_map: &HashMap>) -> Self { let mut q_buf: Vec = Vec::with_capacity(Self::SIZE); let mut synced = true; for (key, conn) in conn_map.iter() { diff --git a/src/vmm/src/devices/virtio/vsock/unix/muxer_rxq.rs b/src/vmm/src/devices/virtio/vsock/unix/muxer_rxq.rs index 1b888dfa453..8bc989baf04 100644 --- a/src/vmm/src/devices/virtio/vsock/unix/muxer_rxq.rs +++ b/src/vmm/src/devices/virtio/vsock/unix/muxer_rxq.rs @@ -18,8 +18,10 @@ use std::collections::{HashMap, VecDeque}; use super::super::VsockChannel; +use super::defs; use super::muxer::{ConnMapKey, MuxerRx}; -use super::{MuxerConnection, defs}; +use crate::devices::virtio::vsock::csm::VsockConnection; +use crate::devices::virtio::vsock::unix::ConnBackend; /// The muxer RX queue. #[derive(Debug)] @@ -45,7 +47,7 @@ impl MuxerRxQ { /// Note: the resulting queue may still be desynchronized, if there are too many connections /// that have pending RX data. In that case, the muxer will first drain this queue, and /// then try again to build a synchronized one. - pub fn from_conn_map(conn_map: &HashMap) -> Self { + pub fn from_conn_map(conn_map: &HashMap>) -> Self { let mut q = VecDeque::new(); let mut synced = true; diff --git a/src/vmm/src/devices/virtio/vsock/unix/seqpacket.rs b/src/vmm/src/devices/virtio/vsock/unix/seqpacket.rs new file mode 100644 index 00000000000..78c7f508c0d --- /dev/null +++ b/src/vmm/src/devices/virtio/vsock/unix/seqpacket.rs @@ -0,0 +1,282 @@ +// Copyright 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 +#![allow(clippy::cast_possible_truncation)] + +use std::io; +use std::io::{Error, ErrorKind, Read, Write}; +use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd}; +use std::os::unix::net::UnixListener; + +use std::fmt::Debug; +use vm_memory::{ReadVolatile, VolatileMemoryError, WriteVolatile}; + +use crate::devices::virtio::vsock::unix::ConnBackend; + +#[derive(Debug)] +pub struct SeqpacketConn(std::os::fd::OwnedFd); + +impl SeqpacketConn { + pub fn connect(path: &str) -> Result { + let (addr, addr_len) = build_addr(path)?; + + // SAFETY: Valid flags and socket type. + let fd = + unsafe { libc::socket(libc::AF_UNIX, libc::SOCK_SEQPACKET | libc::SOCK_CLOEXEC, 0) }; + if fd == -1 { + return Err(io::Error::last_os_error()); + } + + // Set non-blocking via FIONBIO ioctl (already allowed by seccomp filter) + let mut nonblocking: libc::c_int = 1; + // SAFETY: `fd` is valid (checked above); `nonblocking` is a valid int pointer. + let ret = unsafe { libc::ioctl(fd, libc::FIONBIO, &mut nonblocking) }; + if ret < 0 { + // SAFETY: `fd` is a valid open fd; closing it to avoid a leak. + unsafe { libc::close(fd) }; + return Err(io::Error::last_os_error()); + } + + // SAFETY: Valid file descriptor and errors checked. + unsafe { + if libc::connect( + fd, + (&addr as *const libc::sockaddr_un).cast::(), + addr_len, + ) == -1 + { + let err = io::Error::last_os_error(); + libc::close(fd); + return Err(err); + } + }; + + // SAFETY: Valid file descriptor and errors checked. + Ok(unsafe { SeqpacketConn(OwnedFd::from_raw_fd(fd)) }) + } +} + +impl AsRawFd for SeqpacketConn { + fn as_raw_fd(&self) -> i32 { + self.0.as_raw_fd() + } +} + +impl Read for SeqpacketConn { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let ptr = buf.as_mut_ptr().cast::(); + // SAFETY: The file descriptor is valid and open. The buffer pointer is valid for writing + // `buf.len()` bytes. 
+ let received = + unsafe { libc::recv(self.0.as_raw_fd(), ptr, buf.len(), libc::MSG_NOSIGNAL) }; + if received < 0 { + return Err(io::Error::last_os_error()); + } + + Ok(received.try_into().unwrap()) + } +} + +impl Write for SeqpacketConn { + fn write(&mut self, buf: &[u8]) -> io::Result { + let ptr = buf.as_ptr().cast::(); + let flags = libc::MSG_NOSIGNAL; + // SAFETY: The file descriptor is valid and open. The buffer pointer is valid for reading + // `buf.len()` bytes. + let sent = unsafe { libc::send(self.0.as_raw_fd(), ptr, buf.len(), flags) }; + if sent < 0 { + return Err(io::Error::last_os_error()); + } + + Ok(sent.try_into().unwrap()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl ReadVolatile for SeqpacketConn { + fn read_volatile( + &mut self, + buf: &mut vm_memory::VolatileSlice, + ) -> Result { + let fd = self.0.as_raw_fd(); + let guard = buf.ptr_guard_mut(); + + let dst = guard.as_ptr().cast::(); + + // SAFETY: Rust's I/O safety invariants ensure that BorrowedFd contains a valid file + // descriptor`. The memory pointed to by `dst` is valid for writes of length + // `buf.len() by the invariants upheld by the constructor of `VolatileSlice`. + let bytes_read = unsafe { libc::recv(fd, dst, buf.len(), 0) }; + + if bytes_read < 0 { + // We don't know if a partial read might have happened, so mark everything as dirty + buf.bitmap().mark_dirty(0, buf.len()); + Err(VolatileMemoryError::IOError(std::io::Error::last_os_error())) + } else { + let bytes_read = bytes_read.try_into().unwrap(); + buf.bitmap().mark_dirty(0, bytes_read); + Ok(bytes_read) + } + } +} + +impl WriteVolatile for SeqpacketConn { + fn write_volatile( + &mut self, + buf: &vm_memory::VolatileSlice, + ) -> Result { + let fd = self.0.as_raw_fd(); + let guard = buf.ptr_guard(); + + let src = guard.as_ptr().cast::(); + + // SAFETY: Rust's I/O safety invariants ensure that BorrowedFd contains a valid file + // descriptor`. 
The memory pointed to by `src` is valid for reads of length + // `buf.len() by the invariants upheld by the constructor of `VolatileSlice`. + let bytes_written = unsafe { libc::send(fd, src, buf.len(), libc::MSG_NOSIGNAL) }; + + if bytes_written < 0 { + Err(VolatileMemoryError::IOError(std::io::Error::last_os_error())) + } else { + Ok(bytes_written.try_into().unwrap()) + } + } +} + +#[derive(Debug)] +pub struct SeqpacketListener(OwnedFd); + +impl SeqpacketListener { + pub fn bind(path: &str) -> Result { + // SAFETY: Valid socket() parameters, error checked + let fd = unsafe { + libc::socket( + libc::AF_UNIX, + libc::SOCK_SEQPACKET | libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK, + 0, + ) + }; + + if fd == -1 { + return Err(io::Error::last_os_error()); + } + + // SAFETY: Transferring unique ownership of valid fd to OwnedFd + let ownedfd = unsafe { OwnedFd::from_raw_fd(fd) }; + + let (addr, addr_len) = build_addr(path)?; + + // SAFETY: Valid fd and addr pointer, closes fd on error + unsafe { + if libc::bind( + fd, + (&addr as *const libc::sockaddr_un).cast::(), + addr_len, + ) == -1 + { + let err = io::Error::last_os_error(); + return Err(err); + } + }; + + // SAFETY: Valid bound socket, closes fd on error + unsafe { + if libc::listen(fd, libc::SOMAXCONN) == -1 { + let err = io::Error::last_os_error(); + return Err(err); + } + }; + + Ok(SeqpacketListener(ownedfd)) + } +} + +impl AsRawFd for SeqpacketListener { + fn as_raw_fd(&self) -> i32 { + self.0.as_raw_fd() + } +} + +pub trait Socket: AsRawFd + Debug + Send { + fn accept(&self) -> Result; +} + +impl Socket for SeqpacketListener { + fn accept(&self) -> Result { + let flags = libc::SOCK_CLOEXEC; + let mut addr: libc::sockaddr_un = uninitialized_address(); + let mut addr_len: libc::socklen_t = std::mem::size_of_val(&addr) as libc::socklen_t; + + addr.sun_family = libc::AF_UNIX as libc::sa_family_t; + // SAFETY: Valid fd, errors checked. 
+ let fd = unsafe { + libc::accept4( + self.0.as_raw_fd(), + (&mut addr as *mut libc::sockaddr_un).cast::(), + &mut addr_len, + flags, + ) + }; + if fd < 0 { + return Err(io::Error::last_os_error()); + } + + // Set non-blocking via FIONBIO ioctl (already allowed by seccomp filter) + let mut nonblocking: libc::c_int = 1; + // SAFETY: `fd` is valid (checked above); `nonblocking` is a valid int pointer. + let ret = unsafe { libc::ioctl(fd, libc::FIONBIO, &mut nonblocking) }; + if ret < 0 { + // SAFETY: `fd` is a valid open fd; closing it to avoid a leak. + unsafe { libc::close(fd) }; + return Err(io::Error::last_os_error()); + } + + // SAFETY: Transferring unique ownership of valid fd to OwnedFd + unsafe { + Ok(ConnBackend::Seqpacket(SeqpacketConn(OwnedFd::from_raw_fd( + fd, + )))) + } + } +} + +impl Socket for UnixListener { + fn accept(&self) -> Result { + let (conn, _) = self.accept()?; + conn.set_nonblocking(true)?; + Ok(ConnBackend::Stream(conn)) + } +} + +fn build_addr(path: &str) -> Result<(libc::sockaddr_un, u32), io::Error> { + let mut addr: libc::sockaddr_un = uninitialized_address(); + addr.sun_family = libc::AF_UNIX as _; + let max_addr = std::mem::size_of_val(&addr.sun_path); + if path.len() > std::mem::size_of_val(&addr.sun_path) { + return Err(Error::new( + ErrorKind::InvalidInput, + format!( + "the path has length higher than maximum allowed: {}, got: {}", + path.len(), + max_addr + ), + )); + }; + + // SAFETY: Bounded copy, non-overlapping pointers + unsafe { + std::ptr::copy_nonoverlapping( + path.as_ptr().cast::(), + addr.sun_path.as_mut_ptr(), + path.len().min(addr.sun_path.len()), + ); + }; + Ok((addr, std::mem::size_of::() as u32)) +} + +fn uninitialized_address() -> libc::sockaddr_un { + // SAFETY: sockaddr_un has no invalid bit patterns + unsafe { std::mem::zeroed() } +} diff --git a/src/vmm/src/resources.rs b/src/vmm/src/resources.rs index 0c9659e3376..25677fdf6c2 100644 --- a/src/vmm/src/resources.rs +++ b/src/vmm/src/resources.rs @@ -377,7 
+377,7 @@ impl VmResources { /// Sets a vsock device to be attached when the VM starts. pub fn set_vsock_device(&mut self, config: VsockDeviceConfig) -> Result<(), VsockConfigError> { - self.vsock.insert(config) + self.vsock.insert(&config) } /// Builds an entropy device to be attached when the VM starts. diff --git a/src/vmm/src/rpc_interface.rs b/src/vmm/src/rpc_interface.rs index 891324ff14a..d56fd73c010 100644 --- a/src/vmm/src/rpc_interface.rs +++ b/src/vmm/src/rpc_interface.rs @@ -998,6 +998,7 @@ mod tests { use crate::mmds::data_store::MmdsVersion; use crate::seccomp::BpfThreadMap; use crate::vmm_config::snapshot::{MemBackendConfig, MemBackendType}; + use crate::vmm_config::vsock::VsockType; fn default_preboot<'a>( vm_resources: &'a mut VmResources, @@ -1279,6 +1280,8 @@ mod tests { vsock_id: Some(String::new()), guest_cid: 0, uds_path: String::new(), + vsock_type: VsockType::Stream, + conn_buffer_size: None, }, ))); check_unsupported(runtime_request(VmmAction::SetBalloonDevice( @@ -1289,6 +1292,8 @@ mod tests { vsock_id: Some(String::new()), guest_cid: 0, uds_path: String::new(), + vsock_type: VsockType::Stream, + conn_buffer_size: None, }, ))); check_unsupported(runtime_request(VmmAction::SetMmdsConfiguration( diff --git a/src/vmm/src/vmm_config/vsock.rs b/src/vmm/src/vmm_config/vsock.rs index fac773b90e1..6d2e08ff7f9 100644 --- a/src/vmm/src/vmm_config/vsock.rs +++ b/src/vmm/src/vmm_config/vsock.rs @@ -4,10 +4,14 @@ use std::convert::TryFrom; use std::sync::{Arc, Mutex}; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; use crate::devices::virtio::vsock::{Vsock, VsockError, VsockUnixBackend, VsockUnixBackendError}; +// A connection buffer needs to be equivalent to at least 1 page of memory +const MIN_CONN_BUF: usize = 4 * 1024; +const MAX_CONN_BUF: usize = 256 * 1024; + type MutexVsockUnix = Arc>>; /// Errors associated with `NetworkInterfaceConfig`. 
@@ -19,6 +23,33 @@ pub enum VsockConfigError { CreateVsockDevice(VsockError), } +/// from vsock related requests. +#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum VsockType { + /// A stream type socket. No message boundary preservation. + #[default] + Stream, + /// A seqpacket type socket. Message boundaries are denoted + /// by a MSG_EOM flag in the vsock header. + Seqpacket, +} + +fn deserialize_conn_buffer_size<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let v = Option::::deserialize(deserializer)?; + if let Some(n) = v + && !(MIN_CONN_BUF..=MAX_CONN_BUF).contains(&n) + { + return Err(serde::de::Error::custom(format!( + "conn_buffer_size is invalid (max {}), (min {})", + MAX_CONN_BUF, MIN_CONN_BUF + ))); + } + Ok(v) +} /// This struct represents the strongly typed equivalent of the json body /// from vsock related requests. #[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] @@ -32,12 +63,22 @@ pub struct VsockDeviceConfig { pub guest_cid: u32, /// Path to local unix socket. 
pub uds_path: String, + #[serde(default)] + /// the type of the underlying socket + pub vsock_type: VsockType, + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(deserialize_with = "deserialize_conn_buffer_size")] + /// the size of the intermediate connection buffer + pub conn_buffer_size: Option, } #[derive(Debug)] struct VsockAndUnixPath { vsock: MutexVsockUnix, uds_path: String, + vsock_type: VsockType, + conn_buffer_size: Option, } impl From<&VsockAndUnixPath> for VsockDeviceConfig { @@ -47,6 +88,8 @@ impl From<&VsockAndUnixPath> for VsockDeviceConfig { vsock_id: None, guest_cid: u32::try_from(vsock_lock.cid()).unwrap(), uds_path: vsock.uds_path.clone(), + vsock_type: vsock.vsock_type.clone(), + conn_buffer_size: vsock.conn_buffer_size, } } } @@ -57,6 +100,8 @@ impl From<&Vsock> for VsockDeviceConfig { vsock_id: None, // deprecated guest_cid: u32::try_from(vsock.cid()).unwrap(), uds_path: vsock.backend().host_sock_path().to_owned(), + vsock_type: vsock.backend().vsock_type.clone(), + conn_buffer_size: vsock.backend().conn_buffer_size, } } } @@ -75,20 +120,18 @@ impl VsockBuilder { /// Inserts an existing vsock device. pub fn set_device(&mut self, device: Arc>>) { + let device_inner = device.lock().expect("Poisoned lock"); self.inner = Some(VsockAndUnixPath { - uds_path: device - .lock() - .expect("Poisoned lock") - .backend() - .host_sock_path() - .to_owned(), + uds_path: device_inner.backend().host_sock_path().to_owned(), vsock: device.clone(), + vsock_type: device_inner.backend().vsock_type().clone(), + conn_buffer_size: device_inner.backend().conn_buffer_size, }); } /// Inserts a Unix backend Vsock in the store. /// If an entry already exists, it will overwrite it. - pub fn insert(&mut self, cfg: VsockDeviceConfig) -> Result<(), VsockConfigError> { + pub fn insert(&mut self, cfg: &VsockDeviceConfig) -> Result<(), VsockConfigError> { // Make sure to drop the old one and remove the socket before creating a new one. 
if let Some(existing) = self.inner.take() { std::fs::remove_file(existing.uds_path).map_err(VsockUnixBackendError::UnixBind)?; @@ -96,6 +139,8 @@ impl VsockBuilder { self.inner = Some(VsockAndUnixPath { uds_path: cfg.uds_path.clone(), vsock: Arc::new(Mutex::new(Self::create_unixsock_vsock(cfg)?)), + vsock_type: cfg.vsock_type.clone(), + conn_buffer_size: cfg.conn_buffer_size, }); Ok(()) } @@ -107,9 +152,14 @@ impl VsockBuilder { /// Creates a Vsock device from a VsockDeviceConfig. pub fn create_unixsock_vsock( - cfg: VsockDeviceConfig, + cfg: &VsockDeviceConfig, ) -> Result, VsockConfigError> { - let backend = VsockUnixBackend::new(u64::from(cfg.guest_cid), cfg.uds_path)?; + let backend = VsockUnixBackend::new( + u64::from(cfg.guest_cid), + cfg.uds_path.clone(), + cfg.vsock_type.clone(), + cfg.conn_buffer_size, + )?; Vsock::new(u64::from(cfg.guest_cid), backend).map_err(VsockConfigError::CreateVsockDevice) } @@ -133,6 +183,8 @@ pub(crate) mod tests { vsock_id: None, guest_cid: 3, uds_path: tmp_sock_file.as_path().to_str().unwrap().to_string(), + vsock_type: VsockType::default(), + conn_buffer_size: None, } } @@ -141,7 +193,7 @@ pub(crate) mod tests { let mut tmp_sock_file = TempFile::new().unwrap(); tmp_sock_file.remove().unwrap(); let vsock_config = default_config(&tmp_sock_file); - VsockBuilder::create_unixsock_vsock(vsock_config).unwrap(); + VsockBuilder::create_unixsock_vsock(&vsock_config).unwrap(); } #[test] @@ -151,13 +203,13 @@ pub(crate) mod tests { tmp_sock_file.remove().unwrap(); let mut vsock_config = default_config(&tmp_sock_file); - store.insert(vsock_config.clone()).unwrap(); + store.insert(&vsock_config.clone()).unwrap(); let vsock = store.get().unwrap(); assert_eq!(vsock.lock().unwrap().id(), VSOCK_DEV_ID); let new_cid = vsock_config.guest_cid + 1; vsock_config.guest_cid = new_cid; - store.insert(vsock_config).unwrap(); + store.insert(&vsock_config).unwrap(); let vsock = store.get().unwrap(); assert_eq!(vsock.lock().unwrap().cid(), 
u64::from(new_cid)); } @@ -168,7 +220,7 @@ pub(crate) mod tests { let mut tmp_sock_file = TempFile::new().unwrap(); tmp_sock_file.remove().unwrap(); let vsock_config = default_config(&tmp_sock_file); - vsock_builder.insert(vsock_config.clone()).unwrap(); + vsock_builder.insert(&vsock_config.clone()).unwrap(); let config = vsock_builder.config(); assert!(config.is_some()); @@ -182,8 +234,13 @@ pub(crate) mod tests { tmp_sock_file.remove().unwrap(); let vsock = Vsock::new( 0, - VsockUnixBackend::new(1, tmp_sock_file.as_path().to_str().unwrap().to_string()) - .unwrap(), + VsockUnixBackend::new( + 1, + tmp_sock_file.as_path().to_str().unwrap().to_string(), + VsockType::default(), + None, + ) + .unwrap(), ) .unwrap(); diff --git a/src/vmm/tests/integration_tests.rs b/src/vmm/tests/integration_tests.rs index 6a8a6992ba4..5c750613d12 100644 --- a/src/vmm/tests/integration_tests.rs +++ b/src/vmm/tests/integration_tests.rs @@ -28,7 +28,7 @@ use vmm::vmm_config::net::NetworkInterfaceConfig; use vmm::vmm_config::snapshot::{ CreateSnapshotParams, LoadSnapshotParams, MemBackendConfig, MemBackendType, SnapshotType, }; -use vmm::vmm_config::vsock::VsockDeviceConfig; +use vmm::vmm_config::vsock::{VsockDeviceConfig, VsockType}; use vmm::{DumpCpuConfigError, EventManager, FcExitCode, Vmm}; use vmm_sys_util::tempfile::TempFile; @@ -439,6 +439,8 @@ fn test_preboot_load_snap_disallowed_after_boot_resources() { vsock_id: Some(String::new()), guest_cid: 0, uds_path: String::new(), + vsock_type: VsockType::Stream, + conn_buffer_size: None, }); verify_load_snap_disallowed_after_boot_resources(req, "SetVsockDevice"); diff --git a/tests/conftest.py b/tests/conftest.py index 6b49e898ddc..2a4e1ce67ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -244,6 +244,18 @@ def bin_sysgenid_path(test_fc_session_root_path): yield sysgenid_helper_bin_path +@pytest.fixture(scope="session") +def bin_vsock_seqpacket_listener_path(test_fc_session_root_path): + """Build a simple vsock seqpacket 
server application.""" + vsock_seq_srv_bin_path = os.path.join(test_fc_session_root_path, "vsock_seq_server") + build_tools.gcc_compile( + "host_tools/vsock_seq_server.c", + vsock_seq_srv_bin_path, + extra_flags="-lpthread -O3", + ) + yield vsock_seq_srv_bin_path + + @pytest.fixture(scope="session") def bin_vmclock_path(test_fc_session_root_path): """Build a simple util for test VMclock device""" diff --git a/tests/framework/utils_vsock.py b/tests/framework/utils_vsock.py index 9561c1c26f2..edc8157cc0b 100644 --- a/tests/framework/utils_vsock.py +++ b/tests/framework/utils_vsock.py @@ -31,14 +31,17 @@ class HostEchoWorker(Thread): contents of `blob_path`. """ - def __init__(self, uds_path, blob_path): + def __init__(self, uds_path, blob_path, vsock_type=SOCK_STREAM): """.""" super().__init__() self.uds_path = uds_path self.blob_path = blob_path self.hash = None + self.vsock_type = vsock_type self.error = None - self.sock = _vsock_connect_to_guest(self.uds_path, ECHO_SERVER_PORT) + self.sock = _vsock_connect_to_guest( + self.uds_path, ECHO_SERVER_PORT, self.vsock_type + ) def run(self): """Thread code payload. @@ -65,7 +68,6 @@ def _run(self): buf = blob_file.read(BUF_SIZE) if not buf: break - sent = self.sock.send(buf) while sent < len(buf): sent += self.sock.send(buf[sent:]) @@ -110,7 +112,21 @@ def start_guest_echo_server(vm): return os.path.join(vm.jailer.chroot_path(), VSOCK_UDS_PATH) -def check_host_connections(uds_path, blob_path, blob_hash): +def start_seqpacket_echo_server(vm): + """Start a vsock seqpacket echo server in the microVM. + + Returns a UDS path to connect to the server. + """ + cmd = f"nohup /tmp/vsock_seq_server serve {ECHO_SERVER_PORT} af_vsock >/dev/null 2>&1 &" + vm.ssh.check_output(cmd) + + # Give the server time to initialise + time.sleep(1) + + return os.path.join(vm.jailer.chroot_path(), VSOCK_UDS_PATH) + + +def check_host_connections(uds_path, blob_path, blob_hash, vsock_type=SOCK_STREAM): """Test host-initiated connections. 
This will spawn `TEST_CONNECTION_COUNT` `HostEchoWorker` threads. @@ -121,7 +137,7 @@ def check_host_connections(uds_path, blob_path, blob_hash): workers = [] for _ in range(TEST_CONNECTION_COUNT): - worker = HostEchoWorker(uds_path, blob_path) + worker = HostEchoWorker(uds_path, blob_path, vsock_type) workers.append(worker) worker.start() @@ -132,6 +148,76 @@ def check_host_connections(uds_path, blob_path, blob_hash): assert wrk.hash == blob_hash +def check_guest_connections_seqpacket( + vm, server_port_path, server_bin_path, blob_path, blob_hash +): + """Test guest-initiated connections. + + This will start an echo server on the host (in its own thread), then + start `TEST_CONNECTION_COUNT` workers inside the guest VM, all + communicating with the echo server. + """ + port = server_port_path.split("_")[-1] + if Path(server_port_path).exists(): + Path( + server_port_path + ).unlink() # the vsock server program doesn't have reuseaddr + + echo_server = Popen([server_bin_path, "serve", port, "af_unix", server_port_path]) + + try: + # Give the server program bit of time to create the socket + for attempt in Retrying( + wait=wait_fixed(0.2), + stop=stop_after_attempt(3), + reraise=True, + ): + with attempt: + assert Path(server_port_path).exists() + + # Link the listening Unix socket into the VM's jail, so that + # Firecracker can connect to it. + vm.create_jailed_resource(server_port_path) + + # Increase maximum process count for the ssh service. + # Avoids: "bash: fork: retry: Resource temporarily unavailable" + # Needed to execute the bash script that tests for concurrent + # vsock guest initiated connections. + vm.ssh.check_output( + "echo 1024 > /sys/fs/cgroup/system.slice/ssh.service/pids.max" + ) + + # Build the guest worker sub-command. + # `vsock_helper` will read the blob file from STDIN and send the echo + # server response to STDOUT. This response is then hashed, and the + # hash is compared against `blob_hash` (computed on the host). 
This + # comparison sets the exit status of the worker command. + worker_cmd = "hash=$(" + worker_cmd += "cat {}".format(blob_path) + worker_cmd += " | /tmp/vsock_helper echo 2 {} seqpacket".format( + ECHO_SERVER_PORT + ) + worker_cmd += " | md5sum | cut -f1 -d\\ " + worker_cmd += ")" + worker_cmd += ' && [[ "$hash" = "{}" ]]'.format(blob_hash) + + # Run `TEST_CONNECTION_COUNT` concurrent workers, using the above + # worker sub-command. + # If any worker fails, this command will fail. If all worker sub-commands + # succeed, this will also succeed. + cmd = 'workers="";' + cmd += "for i in $(seq 1 {}); do".format(TEST_CONNECTION_COUNT) + cmd += " ({})& ".format(worker_cmd) + cmd += ' workers="$workers $!";' + cmd += "done;" + cmd += "for w in $workers; do wait $w || (wait; exit 1); done" + + vm.ssh.check_output(cmd) + finally: + echo_server.terminate() + echo_server.wait() + + def check_guest_connections(vm, server_port_path, blob_path, blob_hash): """Test guest-initiated connections. @@ -173,7 +259,7 @@ def check_guest_connections(vm, server_port_path, blob_path, blob_hash): # comparison sets the exit status of the worker command. 
worker_cmd = "hash=$(" worker_cmd += "cat {}".format(blob_path) - worker_cmd += " | /tmp/vsock_helper echo 2 {}".format(ECHO_SERVER_PORT) + worker_cmd += " | /tmp/vsock_helper echo 2 {} stream".format(ECHO_SERVER_PORT) worker_cmd += " | md5sum | cut -f1 -d\\ " worker_cmd += ")" worker_cmd += ' && [[ "$hash" = "{}" ]]'.format(blob_hash) @@ -202,9 +288,9 @@ def make_host_port_path(uds_path, port): return "{}_{}".format(uds_path, port) -def _vsock_connect_to_guest(uds_path, port): +def _vsock_connect_to_guest(uds_path, port, vsock_type=SOCK_STREAM): """Return a Unix socket, connected to the guest vsock port `port`.""" - sock = socket(AF_UNIX, SOCK_STREAM) + sock = socket(AF_UNIX, vsock_type) sock.connect(uds_path) buf = bytearray("CONNECT {}\n".format(port).encode("utf-8")) @@ -216,14 +302,19 @@ def _vsock_connect_to_guest(uds_path, port): return sock -def _copy_vsock_data_to_guest(ssh_connection, blob_path, vm_blob_path, vsock_helper): +def _copy_vsock_data_to_guest( + ssh_connection, blob_path, vm_blob_path, vsock_helper=None, vsock_seq_server=None +): # Copy the data file and a vsock helper to the guest. cmd = "mkdir -p /tmp/vsock" ecode, _, _ = ssh_connection.run(cmd) assert ecode == 0, "Failed to set up tmpfs drive on the guest." 
+ if vsock_helper: + ssh_connection.scp_put(vsock_helper, "/tmp/vsock_helper") + if vsock_seq_server: + ssh_connection.scp_put(vsock_seq_server, "/tmp/vsock_seq_server") - ssh_connection.scp_put(vsock_helper, "/tmp/vsock_helper") ssh_connection.scp_put(blob_path, vm_blob_path) diff --git a/tests/host_tools/vsock_helper.c b/tests/host_tools/vsock_helper.c index 9368ffd79d6..a2518207832 100644 --- a/tests/host_tools/vsock_helper.c +++ b/tests/host_tools/vsock_helper.c @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -23,13 +24,15 @@ #define BUF_SIZE (16 * 1024) #define SERVER_ACCEPT_BACKLOG 128 +volatile int connection_socket; int print_usage() { - fprintf(stderr, "Usage: ./vsock-helper echo \n"); + fprintf(stderr, "Usage: ./vsock-helper echo [stream|seqpacket]\n"); fprintf(stderr, "\n"); fprintf(stderr, " echo connect to an echo server, listening on CID:port.\n"); fprintf(stderr, " STDIN will be piped through to the echo server, and\n"); fprintf(stderr, " data coming from the server will pe sent to STDOUT.\n"); + fprintf(stderr, " stream|seqpacket socket type (default: stream)\n"); fprintf(stderr, "\n"); return -1; } @@ -53,9 +56,9 @@ int xfer(int src_fd, int dst_fd) { } -int run_echo(uint32_t cid, uint32_t port) { +int run_echo(uint32_t cid, uint32_t port, int sock_type) { - int sock = socket(AF_VSOCK, SOCK_STREAM, 0); + int sock = socket(AF_VSOCK, sock_type, 0); if (sock < 0) { perror("socket()"); return -1; @@ -71,6 +74,8 @@ int run_echo(uint32_t cid, uint32_t port) { return -1; } + connection_socket = sock; + for (;;) { int ping_cnt = xfer(STDIN_FILENO, sock); if (!ping_cnt) break; @@ -87,15 +92,21 @@ int run_echo(uint32_t cid, uint32_t port) { return close(sock); } +void stop_server_loop(int sig) { + close(connection_socket); +} + int main(int argc, char **argv) { + signal(SIGTERM, stop_server_loop); + signal(SIGINT, stop_server_loop); if (argc < 3) { return print_usage(); } if (strcmp(argv[1], "echo") == 0) { - if (argc != 
4) { + if (argc < 4 || argc > 5) { return print_usage(); } uint32_t cid = atoi(argv[2]); @@ -103,7 +114,19 @@ int main(int argc, char **argv) { if (!cid || !port) { return print_usage(); } - return run_echo(cid, port); + + int sock_type = SOCK_STREAM; + if (argc == 5) { + if (strcmp(argv[4], "seqpacket") == 0) { + sock_type = SOCK_SEQPACKET; + } else if (strcmp(argv[4], "stream") == 0) { + sock_type = SOCK_STREAM; + } else { + return print_usage(); + } + } + + return run_echo(cid, port, sock_type); } return print_usage(); diff --git a/tests/host_tools/vsock_seq_server.c b/tests/host_tools/vsock_seq_server.c new file mode 100644 index 00000000000..003f0abace7 --- /dev/null +++ b/tests/host_tools/vsock_seq_server.c @@ -0,0 +1,196 @@ +// Copyright 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define BUF_SIZE 16384 + +volatile sig_atomic_t running = 1; +volatile int listener_sockfd; + +void log_info(const char *fmt, ...) { + va_list args; + va_start(args, fmt); + fprintf(stdout, "[INFO] "); + vfprintf(stdout, fmt, args); + fprintf(stdout, "\n"); + va_end(args); + fflush(stdout); +} + +void log_error(const char *fmt, ...) 
{ + va_list args; + va_start(args, fmt); + fprintf(stderr, "[ERROR] "); + vfprintf(stderr, fmt, args); + fprintf(stderr, "\n"); + va_end(args); +} + +int print_usage() { + fprintf(stderr, "Usage: ./vsock_seq_server serve [af_vsock|af_unix] [path]\n"); + fprintf(stderr, "\n"); + fprintf(stderr, " serve start a SEQPACKET echo server on :port.\n"); + fprintf(stderr, " Data received from the client is echoed back.\n"); + fprintf(stderr, " af_vsock|af_unix socket family to use (default: af_vsock)\n"); + fprintf(stderr, " path socket path for af_unix (optional; omit for no address)\n"); + fprintf(stderr, "\n"); + return -1; +} + +void *handle_conn(void *connfd_ptr) { + int connfd = *(int *)(connfd_ptr); + free(connfd_ptr); + char buf[BUF_SIZE]; + ssize_t n; + + // echo back whatever you received into the connection again + while ((n = recv(connfd, buf, sizeof(buf), 0)) > 0) { + log_info("received %zd bytes", n); + + if (send(connfd, buf, n, 0) < 0) { + log_error("send: %s", strerror(errno)); + break; + } + } + + if (n == 0) { + log_info("connection closed by peer"); + } + else if (n < 0) { + log_error("recv: %s (errno=%d)", strerror(errno), errno); + } + + close(connfd); + return NULL; +} + +int run_seq_server(int port, int family, const char *path) +{ + int sockfd, connfd; + + sockfd = socket(family, SOCK_SEQPACKET, 0); + if (sockfd < 0) { + log_error("socket: %s", strerror(errno)); + exit(1); + } + + listener_sockfd = sockfd; + + if (family == AF_VSOCK) { + struct sockaddr_vm addr; + memset(&addr, 0, sizeof(addr)); + addr.svm_family = AF_VSOCK; + addr.svm_port = port; + addr.svm_cid = VMADDR_CID_ANY; + + if (bind(sockfd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + log_error("bind: %s", strerror(errno)); + close(sockfd); + exit(1); + } + } else if (family == AF_UNIX && path != NULL) { + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, path, sizeof(addr.sun_path) - 1); + + if (bind(sockfd, (struct 
sockaddr *)&addr, sizeof(addr)) < 0) { + log_error("bind: %s", strerror(errno)); + close(sockfd); + exit(1); + } + } + + if (listen(sockfd, 5) < 0) { + log_error("listen: %s", strerror(errno)); + close(sockfd); + exit(1); + } + + log_info("SEQPACKET server listening on port %d (family=%s%s%s)", + port, + family == AF_VSOCK ? "af_vsock" : "af_unix", + path ? " path=" : "", + path ? path : ""); + + while (running) { + connfd = accept(sockfd, NULL, NULL); + if (connfd < 0) { + if (errno == EINTR) break; // accept interrupted by signal + log_error("accept: %s", strerror(errno)); + exit(1); + } + + log_info("connection accepted (fd=%d)", connfd); + + int *connfd_ptr = malloc(sizeof(int)); + *connfd_ptr = connfd; + + pthread_t tid; + pthread_create(&tid, NULL, handle_conn, connfd_ptr); + pthread_detach(tid); + }; + + close(sockfd); + return 0; +} + +void stop_server_loop(int sig) { + running = 0; + close(listener_sockfd); +} + +int main(int argc, char **argv) { + signal(SIGTERM, stop_server_loop); + signal(SIGINT, stop_server_loop); + + if (argc < 2) { + return print_usage(); + } + + if (strcmp(argv[1], "serve") == 0) { + if (argc < 3) { + return print_usage(); + } + + int port = atoi(argv[2]); + if (!port) { + return print_usage(); + } + + int family = AF_VSOCK; + const char *path = NULL; + + if (argc >= 4) { + if (strcmp(argv[3], "af_unix") == 0) { + family = AF_UNIX; + } else if (strcmp(argv[3], "af_vsock") == 0) { + family = AF_VSOCK; + } else { + return print_usage(); + } + } + + if (argc >= 5) { + path = argv[4]; + } + + return run_seq_server(port, family, path); + } + + return print_usage(); +} diff --git a/tests/integration_tests/functional/test_api.py b/tests/integration_tests/functional/test_api.py index cd3b487e10e..b18aa34445e 100644 --- a/tests/integration_tests/functional/test_api.py +++ b/tests/integration_tests/functional/test_api.py @@ -965,18 +965,30 @@ def test_api_vsock(uvm_nano): """ vm = uvm_nano # Create a vsock device. 
- vm.api.vsock.put(guest_cid=15, uds_path="vsock.sock") + vm.api.vsock.put(guest_cid=15, uds_path="vsock.sock", vsock_type="stream") + + # Updating an existing vsock is currently fine. Even changing its type. + vm.api.vsock.put( + guest_cid=166, + uds_path="vsock.sock", + vsock_type="seqpacket", + conn_buffer_size=4096, + ) - # Updating an existing vsock is currently fine. - vm.api.vsock.put(guest_cid=166, uds_path="vsock.sock") + # Revert it back to stream. + vm.api.vsock.put(guest_cid=166, uds_path="vsock.sock", vsock_type="stream") # Check PUT request. Although vsock_id is deprecated, it must still work. - response = vm.api.vsock.put(vsock_id="vsock1", guest_cid=15, uds_path="vsock.sock") + response = vm.api.vsock.put( + vsock_id="vsock1", guest_cid=15, uds_path="vsock.sock", vsock_type="stream" + ) assert response.headers["deprecation"] # Updating an existing vsock is currently fine even with deprecated # `vsock_id`. - response = vm.api.vsock.put(vsock_id="vsock1", guest_cid=166, uds_path="vsock.sock") + response = vm.api.vsock.put( + vsock_id="vsock1", guest_cid=166, uds_path="vsock.sock", vsock_type="stream" + ) assert response.headers["deprecation"] # No other vsock action is allowed after booting the VM. @@ -984,7 +996,11 @@ def test_api_vsock(uvm_nano): # Updating an existing vsock should not be fine at this point. with pytest.raises(RuntimeError): - vm.api.vsock.put(guest_cid=17, uds_path="vsock.sock") + vm.api.vsock.put(guest_cid=17, uds_path="vsock.sock", vsock_type="stream") + + # Changing the type of a vsock device should error at this point. + with pytest.raises(RuntimeError): + vm.api.vsock.put(guest_cid=17, uds_path="vsock.sock", vsock_type="seqpacket") def test_api_entropy(uvm_plain): @@ -1342,8 +1358,12 @@ def test_get_full_config_after_restoring_snapshot(microvm_factory, uvm_nano): } # Add a vsock device. 
- uvm_nano.api.vsock.put(guest_cid=15, uds_path="vsock.sock") - setup_cfg["vsock"] = {"guest_cid": 15, "uds_path": "vsock.sock"} + uvm_nano.api.vsock.put(guest_cid=15, uds_path="vsock.sock", vsock_type="stream") + setup_cfg["vsock"] = { + "guest_cid": 15, + "uds_path": "vsock.sock", + "vsock_type": "stream", + } setup_cfg["memory-hotplug"] = { "total_size_mib": 1024, @@ -1478,9 +1498,14 @@ def test_get_full_config(uvm_plain): } # Add a vsock device. - response = test_microvm.api.vsock.put(guest_cid=15, uds_path="vsock.sock") - expected_cfg["vsock"] = {"guest_cid": 15, "uds_path": "vsock.sock"} - + response = test_microvm.api.vsock.put( + guest_cid=15, uds_path="vsock.sock", vsock_type="stream" + ) + expected_cfg["vsock"] = { + "guest_cid": 15, + "uds_path": "vsock.sock", + "vsock_type": "stream", + } # Add hot-pluggable memory. expected_cfg["memory-hotplug"] = { "total_size_mib": 1024, @@ -1619,7 +1644,7 @@ def test_negative_snapshot_load_api(microvm_factory): ) # API request without `mem_backend` or `mem_file_path` should fail. 
- err_msg = "missing field: either `mem_backend` or " "`mem_file_path` is required" + err_msg = "missing field: either `mem_backend` or `mem_file_path` is required" with pytest.raises(RuntimeError, match=err_msg): vm.api.snapshot_load.put(snapshot_path="foo") diff --git a/tests/integration_tests/functional/test_snapshot_basic.py b/tests/integration_tests/functional/test_snapshot_basic.py index 9a425c4cec5..032c2fc757f 100644 --- a/tests/integration_tests/functional/test_snapshot_basic.py +++ b/tests/integration_tests/functional/test_snapshot_basic.py @@ -140,7 +140,9 @@ def test_cycled_snapshot_restore( ) vm.set_cpu_template(cpu_template_any) vm.add_net_iface() - vm.api.vsock.put(vsock_id="vsock0", guest_cid=3, uds_path=VSOCK_UDS_PATH) + vm.api.vsock.put( + vsock_id="vsock0", guest_cid=3, uds_path=VSOCK_UDS_PATH, vsock_type="stream" + ) vm.start() vm_blob_path = "/tmp/vsock/test.blob" @@ -597,7 +599,9 @@ def test_snapshot_rename_vsock( """ vm = uvm_nano - vm.api.vsock.put(vsock_id="vsock0", guest_cid=3, uds_path="/v.sock1") + vm.api.vsock.put( + vsock_id="vsock0", guest_cid=3, uds_path="/v.sock1", vsock_type="stream" + ) vm.add_net_iface() vm.start() diff --git a/tests/integration_tests/functional/test_snapshot_phase1.py b/tests/integration_tests/functional/test_snapshot_phase1.py index 89bed92ee8d..063d87e731a 100644 --- a/tests/integration_tests/functional/test_snapshot_phase1.py +++ b/tests/integration_tests/functional/test_snapshot_phase1.py @@ -49,7 +49,9 @@ def test_snapshot_phase1( for i in range(4): vm.add_net_iface() # Add a vsock device - vm.api.vsock.put(vsock_id="vsock0", guest_cid=3, uds_path="/v.sock") + vm.api.vsock.put( + vsock_id="vsock0", guest_cid=3, uds_path="/v.sock", vsock_type="stream" + ) # Add MMDS configure_mmds(vm, ["eth3"], version="V2") # Add a memory balloon. 
diff --git a/tests/integration_tests/functional/test_vsock.py b/tests/integration_tests/functional/test_vsock.py index 64d9e1199f0..0b152830d26 100644 --- a/tests/integration_tests/functional/test_vsock.py +++ b/tests/integration_tests/functional/test_vsock.py @@ -13,11 +13,14 @@ guest-initiated connections. """ +import os import os.path import subprocess import time from pathlib import Path +from socket import SOCK_SEQPACKET, SOCK_STREAM from socket import timeout as SocketTimeout +from threading import Thread import pytest @@ -26,12 +29,15 @@ VSOCK_UDS_PATH, HostEchoWorker, _copy_vsock_data_to_guest, + _vsock_connect_to_guest, check_guest_connections, + check_guest_connections_seqpacket, check_host_connections, check_vsock_device, make_blob, make_host_port_path, start_guest_echo_server, + start_seqpacket_echo_server, ) from host_tools.fcmetrics import validate_fc_metrics @@ -51,7 +57,12 @@ def test_vsock(uvm_plain_any, bin_vsock_path, test_fc_session_root_path): vm.basic_config() vm.add_net_iface() - vm.api.vsock.put(vsock_id="vsock0", guest_cid=3, uds_path=f"/{VSOCK_UDS_PATH}") + vm.api.vsock.put( + vsock_id="vsock0", + guest_cid=3, + uds_path=f"/{VSOCK_UDS_PATH}", + vsock_type="stream", + ) vm.start() check_vsock_device(vm, bin_vsock_path, test_fc_session_root_path, vm.ssh) @@ -59,7 +70,7 @@ def test_vsock(uvm_plain_any, bin_vsock_path, test_fc_session_root_path): validate_fc_metrics(metrics) -def negative_test_host_connections(vm, blob_path, blob_hash): +def negative_test_host_connections(vm, blob_path, blob_hash, vsock_type): """Negative test for host-initiated connections. 
This will start a daemonized echo server on the guest VM, and then spawn @@ -71,7 +82,7 @@ def negative_test_host_connections(vm, blob_path, blob_hash): workers = [] for _ in range(NEGATIVE_TEST_CONNECTION_COUNT): - worker = HostEchoWorker(uds_path, blob_path) + worker = HostEchoWorker(uds_path, blob_path, vsock_type) workers.append(worker) worker.start() @@ -112,7 +123,12 @@ def test_vsock_epipe(uvm_plain_any, bin_vsock_path, test_fc_session_root_path): vm.spawn() vm.basic_config() vm.add_net_iface() - vm.api.vsock.put(vsock_id="vsock0", guest_cid=3, uds_path=f"/{VSOCK_UDS_PATH}") + vm.api.vsock.put( + vsock_id="vsock0", + guest_cid=3, + uds_path=f"/{VSOCK_UDS_PATH}", + vsock_type="stream", + ) vm.start() # Generate the random data blob file, 20MB @@ -125,7 +141,7 @@ def test_vsock_epipe(uvm_plain_any, bin_vsock_path, test_fc_session_root_path): # Negative test for host-initiated connections that # are closed with in flight data. - negative_test_host_connections(vm, blob_path, blob_hash) + negative_test_host_connections(vm, blob_path, blob_hash, SOCK_STREAM) metrics = vm.flush_metrics() validate_fc_metrics(metrics) @@ -152,7 +168,12 @@ def test_vsock_transport_reset_h2g( test_vm.spawn() test_vm.basic_config(vcpu_count=2, mem_size_mib=256) test_vm.add_net_iface() - test_vm.api.vsock.put(vsock_id="vsock0", guest_cid=3, uds_path=f"/{VSOCK_UDS_PATH}") + test_vm.api.vsock.put( + vsock_id="vsock0", + guest_cid=3, + uds_path=f"/{VSOCK_UDS_PATH}", + vsock_type="stream", + ) test_vm.start() # Generate the random data blob file. 
@@ -194,6 +215,7 @@ def test_vsock_transport_reset_h2g( assert ( response == b"" ), f"Connection not closed: response received '{response.decode('utf-8')}'" + except (SocketTimeout, ConnectionResetError, BrokenPipeError): pass @@ -225,7 +247,12 @@ def test_vsock_transport_reset_g2h(uvm_plain_any, microvm_factory): test_vm.spawn() test_vm.basic_config(vcpu_count=2, mem_size_mib=256) test_vm.add_net_iface() - test_vm.api.vsock.put(vsock_id="vsock0", guest_cid=3, uds_path=f"/{VSOCK_UDS_PATH}") + test_vm.api.vsock.put( + vsock_id="vsock0", + guest_cid=3, + uds_path=f"/{VSOCK_UDS_PATH}", + vsock_type="stream", + ) test_vm.start() # Create snapshot and terminate a VM. @@ -286,6 +313,123 @@ def test_vsock_transport_reset_g2h(uvm_plain_any, microvm_factory): new_vm.kill() +def test_vsock_seqpacket_h2g( + uvm_plain_6_1, bin_vsock_seqpacket_listener_path, test_fc_session_root_path +): + """Test host-to-guest vsock seqpacket connections.""" + vm = uvm_plain_6_1 + vm.spawn() + vm.basic_config() + vm.add_net_iface() + vm.api.vsock.put( + vsock_id="vsock0", + guest_cid=3, + uds_path=f"/{VSOCK_UDS_PATH}", + vsock_type="seqpacket", + conn_buffer_size=16 * 1024, + ) + vm.start() + + blob_path, blob_hash = make_blob(test_fc_session_root_path, 16 * 1024) + vm_blob_path = "/tmp/vsock/test.blob" + + _copy_vsock_data_to_guest( + vm.ssh, + blob_path, + vm_blob_path, + vsock_seq_server=bin_vsock_seqpacket_listener_path, + ) + path = start_seqpacket_echo_server(vm) + + check_host_connections(path, blob_path, blob_hash, SOCK_SEQPACKET) + metrics = vm.flush_metrics() + validate_fc_metrics(metrics) + + +def test_vsock_seqpacket_g2h( + uvm_plain_6_1, + bin_vsock_seqpacket_listener_path, + bin_vsock_path, + test_fc_session_root_path, +): + """Test guest-to-host vsock seqpacket connections.""" + vm = uvm_plain_6_1 + vm.spawn() + vm.basic_config() + vm.add_net_iface() + vm.api.vsock.put( + vsock_id="vsock0", + guest_cid=3, + uds_path=f"/{VSOCK_UDS_PATH}", + vsock_type="seqpacket", + 
conn_buffer_size=16 * 1024, + ) + vm.start() + + blob_path, blob_hash = make_blob(test_fc_session_root_path, 16 * 1024) + vm_blob_path = "/tmp/vsock/test.blob" + + _copy_vsock_data_to_guest(vm.ssh, blob_path, vm_blob_path, bin_vsock_path) + + path = os.path.join(vm.path, make_host_port_path(VSOCK_UDS_PATH, ECHO_SERVER_PORT)) + check_guest_connections_seqpacket( + vm, path, bin_vsock_seqpacket_listener_path, vm_blob_path, blob_hash + ) + + metrics = vm.flush_metrics() + validate_fc_metrics(metrics) + + +def test_vsock_seqpacket_h2g_overflow( + uvm_plain_6_1, bin_vsock_seqpacket_listener_path, test_fc_session_root_path +): + """Test that sending a message larger than conn_buffer_size errors.""" + conn_buffer_size = 16 * 1024 + + vm = uvm_plain_6_1 + vm.spawn() + vm.basic_config() + vm.add_net_iface() + vm.api.vsock.put( + vsock_id="vsock0", + guest_cid=3, + uds_path=f"/{VSOCK_UDS_PATH}", + vsock_type="seqpacket", + conn_buffer_size=conn_buffer_size, + ) + vm.start() + + blob_path, _ = make_blob(test_fc_session_root_path, 1024) + vm_blob_path = "/tmp/vsock/test.blob" + _copy_vsock_data_to_guest( + vm.ssh, + blob_path, + vm_blob_path, + vsock_seq_server=bin_vsock_seqpacket_listener_path, + ) + path = start_seqpacket_echo_server(vm) + + worker_error = None + + def worker(): + nonlocal worker_error + try: + sock = _vsock_connect_to_guest(path, ECHO_SERVER_PORT, SOCK_SEQPACKET) + oversized_msg = os.urandom(conn_buffer_size + 1) + sock.send(oversized_msg) + sock.recv(conn_buffer_size + 1) + except OSError as err: + worker_error = err + + t = Thread(target=worker) + t.start() + t.join() + + assert ( + worker_error is not None + ), "Expected an error when sending message larger than conn_buffer_size" + + def test_vsock_after_override( uvm_plain_any, microvm_factory, bin_vsock_path, test_fc_session_root_path ): diff --git a/tests/integration_tests/performance/test_snapshot.py b/tests/integration_tests/performance/test_snapshot.py index 6de1b05b111..ba29816fc4f 100644 --- 
a/tests/integration_tests/performance/test_snapshot.py +++ b/tests/integration_tests/performance/test_snapshot.py @@ -74,7 +74,9 @@ def boot_vm(self, microvm_factory, guest_kernel, rootfs, pci_enabled) -> Microvm vm.api.balloon.put( amount_mib=0, deflate_on_oom=True, stats_polling_interval_s=1 ) - vm.api.vsock.put(vsock_id="vsock0", guest_cid=3, uds_path="/v.sock") + vm.api.vsock.put( + vsock_id="vsock0", guest_cid=3, uds_path="/v.sock", vsock_type="stream" + ) vm.start() diff --git a/tests/integration_tests/performance/test_vsock.py b/tests/integration_tests/performance/test_vsock.py index fa4c3a5abb5..a7f932d0cfd 100644 --- a/tests/integration_tests/performance/test_vsock.py +++ b/tests/integration_tests/performance/test_vsock.py @@ -92,7 +92,12 @@ def test_vsock_throughput( vm.basic_config(vcpu_count=vcpus, mem_size_mib=mem_size_mib) vm.add_net_iface() # Create a vsock device - vm.api.vsock.put(vsock_id="vsock0", guest_cid=3, uds_path="/" + VSOCK_UDS_PATH) + vm.api.vsock.put( + vsock_id="vsock0", + guest_cid=3, + uds_path="/" + VSOCK_UDS_PATH, + vsock_type="stream", + ) vm.start() metrics.set_dimensions(