diff --git a/src/devices/vmcall_raw/src/transport/vmcall.rs b/src/devices/vmcall_raw/src/transport/vmcall.rs index 8c81c1fa..7835e6da 100644 --- a/src/devices/vmcall_raw/src/transport/vmcall.rs +++ b/src/devices/vmcall_raw/src/transport/vmcall.rs @@ -82,6 +82,53 @@ pub fn vmcall_raw_transport_can_recv() -> Result { Ok(false) } +/// Poll the VMM completion status for a pending VMCALL operation. +/// +/// Shared logic for both send and receive: checks the interrupt flag, +/// parses data_status from the shared buffer header, and immediately +/// consumes the flag so work only happens on a real interrupt. +/// +/// Returns `Poll::Pending` if the operation is still in flight, +/// `Poll::Ready(Ok((payload, data_length)))` on success, or +/// `Poll::Ready(Err(...))` on VMM cancellation or missing context. +/// +/// NOTE: `data_length` is owned by MigTD when status=0; on the send path +/// the post-completion value is not meaningful and must be discarded. +fn poll_vmcall_completion<'a>( + mig_request_id: u64, + data_buffer: &'a mut [u8], +) -> Poll> { + if let Some(flag) = VMCALL_MIG_CONTEXT_FLAGS.lock().get(&mig_request_id) { + if flag.load(Ordering::SeqCst) { + flag.store(false, Ordering::SeqCst); + } else { + return Poll::Pending; + } + } else { + return Poll::Ready(Err(VmcallRawError::Illegal)); + } + + let (payload, data_status, data_length) = process_buffer(data_buffer); + let status_bytes = data_status.to_le_bytes(); + + if status_bytes[0] == TDX_VMCALL_STATUS_NOT_COMPLETED + && status_bytes[1] == TDX_VMCALL_STATUS_VMM_CANCEL + { + log::warn!( + "VMM canceled migration session (migration_request_id={}, data_status=0x{:x})\n", + mig_request_id, + data_status + ); + return Poll::Ready(Err(VmcallRawError::VmmCanceled)); + } + + if status_bytes[0] != TDX_VMCALL_VMM_SUCCESS { + return Poll::Pending; + } + + Poll::Ready(Ok((payload, data_length))) +} + async fn vmcall_service_migtd_send( mig_request_id: u64, data_buffer: &mut [u8], @@ -89,36 +136,14 @@ async fn vmcall_service_migtd_send( tdx::tdvmcall_migtd_send(mig_request_id, data_buffer, VMCALL_VECTOR) .map_err(|_e| VmcallRawError::TdVmcallErr)?; - poll_fn(|_cx| -> Poll> { - if let Some(flag) = VMCALL_MIG_CONTEXT_FLAGS.lock().get(&mig_request_id) { - if flag.load(Ordering::SeqCst) { - flag.store(false, Ordering::SeqCst); - } else { - return Poll::Pending; - } - } else { - let _ = Poll::Ready(Err::, _>(VmcallRawError::Illegal)); - } - - let (_send_buf, data_status, data_length) = process_buffer(data_buffer); - let data_status_bytes = data_status.to_le_bytes(); - - if data_status_bytes[0] != TDX_VMCALL_VMM_SUCCESS { - if data_status_bytes[0] == TDX_VMCALL_STATUS_NOT_COMPLETED - && data_status_bytes[1] == TDX_VMCALL_STATUS_VMM_CANCEL - { - log::warn!( - "VMM canceled migration session (migration_request_id={}, data_status=0x{:x})\n", - mig_request_id, - data_status - ); - return Poll::Ready(Err(VmcallRawError::VmmCanceled)); - } - return Poll::Pending; - } - - Poll::Ready(Ok(data_length as usize)) - }) + // Discard `data_length`: not a meaningful VMM-reported value on send. + poll_fn( + |_cx| match poll_vmcall_completion(mig_request_id, data_buffer) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Ready(Ok((_payload, _data_length))) => Poll::Ready(Ok(0usize)), + }, + ) .await } @@ -137,45 +162,31 @@ async fn vmcall_service_migtd_receive( stream: &VmcallRaw, data_buffer: &mut [u8], ) -> Result, VmcallRawError> { - tdx::tdvmcall_migtd_receive(stream.addr.transport_context(), data_buffer, VMCALL_VECTOR) + let mig_request_id = stream.addr.transport_context(); + tdx::tdvmcall_migtd_receive(mig_request_id, data_buffer, VMCALL_VECTOR) .map_err(|_e| VmcallRawError::TdVmcallErr)?; - poll_fn(|_cx| -> Poll, VmcallRawError>> { - let mig_request_id = stream.addr.transport_context(); - if let Some(flag) = VMCALL_MIG_CONTEXT_FLAGS.lock().get(&mig_request_id) { - if flag.load(Ordering::SeqCst) { - flag.store(false, Ordering::SeqCst); - } else { - return Poll::Pending; - } - } else { - let _ = Poll::Ready(Err::, _>(VmcallRawError::Illegal)); - } - - let (response_buf, data_status, data_length) = process_buffer(data_buffer); - let data_status_bytes = data_status.to_le_bytes(); - - if data_status_bytes[0] != TDX_VMCALL_VMM_SUCCESS { - if data_status_bytes[0] == TDX_VMCALL_STATUS_NOT_COMPLETED - && data_status_bytes[1] == TDX_VMCALL_STATUS_VMM_CANCEL - { - log::warn!( - "VMM canceled migration session (migration_request_id={}, data_status=0x{:x})\n", - mig_request_id, - data_status - ); - return Poll::Ready(Err(VmcallRawError::VmmCanceled)); + poll_fn( + |_cx| match poll_vmcall_completion(mig_request_id, data_buffer) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Ready(Ok((payload, data_length))) => { + // Reject zero-length success as malformed: defense against a + // hostile VMM that would otherwise stall the caller. + if data_length == 0 { + log::error!( + "vmcall_service_migtd_receive: VMM reported success with zero-length payload (mig_request_id={})\n", + mig_request_id + ); + return Poll::Ready(Err(VmcallRawError::Malformed)); + } + match payload.get(..data_length as usize) { + Some(slice) => Poll::Ready(Ok(slice.to_vec())), + None => Poll::Ready(Err(VmcallRawError::TdVmcallErr)), + } } - return Poll::Pending; - } - - let data = match response_buf.get(..data_length as usize) { - Some(slice) => slice.to_vec(), - None => return Poll::Ready(Err(VmcallRawError::TdVmcallErr)), - }; - - Poll::Ready(Ok(data)) - }) + }, + ) .await }