diff --git a/Cargo.lock b/Cargo.lock index 02eb1d37d6..dfd1a0d082 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2553,8 +2553,6 @@ dependencies = [ [[package]] name = "x86_64" version = "0.15.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7841fa0098ceb15c567d93d3fae292c49e10a7662b4936d5f6a9728594555ba" dependencies = [ "bit_field", "bitflags 2.10.0", diff --git a/Cargo.toml b/Cargo.toml index 03e3ba8d6f..1a36762774 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -213,6 +213,7 @@ exclude = [ # FIXME: remove once merged: https://github.com/rcore-os/trapframe-rs/pull/16 trapframe = { git = "https://github.com/hermit-os/trapframe-rs", branch = "global_asm" } safe-mmio = { git = "https://github.com/hermit-os/safe-mmio", branch = "be" } +x86_64 ={ path = "../../x86_64" } [profile.profiling] inherits = "release" diff --git a/src/arch/x86_64/kernel/core_local.rs b/src/arch/x86_64/kernel/core_local.rs index 1aa29965e8..39f2aaeed1 100644 --- a/src/arch/x86_64/kernel/core_local.rs +++ b/src/arch/x86_64/kernel/core_local.rs @@ -78,7 +78,7 @@ impl CoreLocal { }; this.this = ptr::from_ref(this); - GsBase::write(VirtAddr::from_ptr(this)); + unsafe { GsBase::write(VirtAddr::from_ptr(this)); } } #[inline] diff --git a/src/arch/x86_64/kernel/processor.rs b/src/arch/x86_64/kernel/processor.rs index 07d4014802..b7047c361f 100644 --- a/src/arch/x86_64/kernel/processor.rs +++ b/src/arch/x86_64/kernel/processor.rs @@ -1154,7 +1154,7 @@ pub fn writefs(fs: usize) { FS::write_base(base); } } else { - FsBase::write(base); + unsafe { FsBase::write(base); } } } @@ -1166,7 +1166,7 @@ pub fn writegs(gs: usize) { GS::write_base(base); } } else { - GsBase::write(base); + unsafe { GsBase::write(base); } } } diff --git a/src/arch/x86_64/mm/paging.rs b/src/arch/x86_64/mm/paging.rs index 96d1e50469..ab6d4e9ccf 100644 --- a/src/arch/x86_64/mm/paging.rs +++ b/src/arch/x86_64/mm/paging.rs @@ -112,7 +112,7 @@ pub unsafe fn identity_mapped_page_table() -> OffsetPageTable<'static> { ptr::with_exposed_provenance_mut::(level_4_table_addr.try_into().unwrap()); unsafe { let level_4_table = level_4_table_ptr.as_mut().unwrap(); - OffsetPageTable::new(level_4_table, x86_64::addr::VirtAddr::new(0x0)) + OffsetPageTable::from_phys_offset(level_4_table, x86_64::addr::VirtAddr::new(0x0)) } } @@ -179,7 +179,7 @@ pub fn map( for (page, frame) in pages.zip(frames) { // TODO: Require explicit unmaps let unmap = mapper.unmap(page); - if let Ok((_frame, flush)) = unmap { + if let Ok((_frame, flags, flush)) = unmap { unmapped = true; flush.flush(); debug!("Had to unmap page {page:?} before mapping."); @@ -265,7 +265,7 @@ where for page in range { let unmap_result = unsafe { identity_mapped_page_table() }.unmap(page); match unmap_result { - Ok((_frame, flush)) => flush.flush(), + Ok((_frame, flags, flush)) => flush.flush(), // FIXME: Some sentinel pages around stacks are supposed to be unmapped. // We should handle this case there instead of here. Err(UnmapError::PageNotMapped) => { @@ -365,642 +365,6 @@ fn make_p4_writable() { pub unsafe fn log_page_tables() { use log::Level; - use self::mapped_page_range_display::OffsetPageTableExt; - - if !log_enabled!(Level::Trace) { - return; - } - let page_table = unsafe { identity_mapped_page_table() }; - trace!("Page tables:\n{}", page_table.display()); -} - -pub mod mapped_page_range_display { - use core::fmt::{self, Write}; - - use x86_64::structures::paging::mapper::PageTableFrameMapping; - use x86_64::structures::paging::{MappedPageTable, OffsetPageTable, PageSize}; - - use super::mapped_page_table_iter::{ - self, MappedPageRangeInclusive, MappedPageRangeInclusiveItem, - MappedPageTableRangeInclusiveIter, - }; - use super::offset_page_table::PhysOffset; - - #[expect(dead_code)] - pub trait MappedPageTableExt { - fn display(&self) -> MappedPageTableDisplay<'_, &P>; - } - - impl MappedPageTableExt

for MappedPageTable<'_, P> { - fn display(&self) -> MappedPageTableDisplay<'_, &P> { - MappedPageTableDisplay { - inner: mapped_page_table_iter::mapped_page_table_range_iter(self), - } - } - } - - pub trait OffsetPageTableExt { - fn display(&self) -> MappedPageTableDisplay<'_, PhysOffset>; - } - - impl OffsetPageTableExt for OffsetPageTable<'_> { - fn display(&self) -> MappedPageTableDisplay<'_, PhysOffset> { - MappedPageTableDisplay { - inner: mapped_page_table_iter::offset_page_table_range_iter(self), - } - } - } - - pub struct MappedPageTableDisplay<'a, P: PageTableFrameMapping + Clone> { - inner: MappedPageTableRangeInclusiveIter<'a, P>, - } - - impl fmt::Display for MappedPageTableDisplay<'_, P> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut has_fields = false; - - for mapped_page_range in self.inner.clone() { - if has_fields { - f.write_char('\n')?; - } - write!(f, "{}", mapped_page_range.display())?; - - has_fields = true; - } - - Ok(()) - } - } - - pub trait MappedPageRangeInclusiveItemExt { - fn display(&self) -> MappedPageRangeInclusiveItemDisplay<'_>; - } - - impl MappedPageRangeInclusiveItemExt for MappedPageRangeInclusiveItem { - fn display(&self) -> MappedPageRangeInclusiveItemDisplay<'_> { - MappedPageRangeInclusiveItemDisplay { inner: self } - } - } - - pub struct MappedPageRangeInclusiveItemDisplay<'a> { - inner: &'a MappedPageRangeInclusiveItem, - } - - impl fmt::Display for MappedPageRangeInclusiveItemDisplay<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.inner { - MappedPageRangeInclusiveItem::Size4KiB(range) => range.display().fmt(f), - MappedPageRangeInclusiveItem::Size2MiB(range) => range.display().fmt(f), - MappedPageRangeInclusiveItem::Size1GiB(range) => range.display().fmt(f), - } - } - } - - pub trait MappedPageRangeInclusiveExt { - fn display(&self) -> MappedPageRangeInclusiveDisplay<'_, S>; - } - - impl MappedPageRangeInclusiveExt for MappedPageRangeInclusive { - fn display(&self) -> MappedPageRangeInclusiveDisplay<'_, S> { - MappedPageRangeInclusiveDisplay { inner: self } - } - } - - pub struct MappedPageRangeInclusiveDisplay<'a, S: PageSize> { - inner: &'a MappedPageRangeInclusive, - } - - impl fmt::Display for MappedPageRangeInclusiveDisplay<'_, S> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let size = S::DEBUG_STR; - let len = self.inner.page_range.len(); - let page_start = self.inner.page_range.start.start_address(); - let page_end = self.inner.page_range.end.start_address(); - let frame_start = self.inner.frame_range.start.start_address(); - let frame_end = self.inner.frame_range.end.start_address(); - let flags = self.inner.flags; - let format_phys = if page_start.as_u64() == frame_start.as_u64() { - assert_eq!(page_end.as_u64(), frame_end.as_u64()); - format_args!("{:>39}", "identity mapped") - } else { - format_args!("{frame_start:18p}..={frame_end:18p}") - }; - write!( - f, - "size: {size}, len: {len:5}, virt: {page_start:18p}..={page_end:18p}, phys: {format_phys}, flags: {flags:?}" - ) - } - } -} - -pub mod mapped_page_table_iter { - //! TODO: try to upstream this to [`x86_64`]. - - use core::fmt; - use core::ops::{Add, AddAssign, Sub, SubAssign}; - - use x86_64::structures::paging::frame::PhysFrameRangeInclusive; - use x86_64::structures::paging::mapper::PageTableFrameMapping; - use x86_64::structures::paging::page::{AddressNotAligned, PageRangeInclusive}; - use x86_64::structures::paging::{ - MappedPageTable, OffsetPageTable, Page, PageSize, PageTable, PageTableFlags, - PageTableIndex, PhysFrame, Size1GiB, Size2MiB, Size4KiB, - }; - - use super::offset_page_table::PhysOffset; - use super::walker::{PageTableWalkError, PageTableWalker}; - - #[derive(Debug)] - pub struct MappedPageRangeInclusive { - pub page_range: PageRangeInclusive, - pub frame_range: PhysFrameRangeInclusive, - pub flags: PageTableFlags, - } - - impl TryFrom<(MappedPage, MappedPage)> for MappedPageRangeInclusive { - type Error = TryFromMappedPageError; - - fn try_from((start, end): (MappedPage, MappedPage)) -> Result { - if start.flags != end.flags { - return Err(TryFromMappedPageError); - } - - Ok(Self { - page_range: PageRangeInclusive { - start: start.page, - end: end.page, - }, - frame_range: PhysFrameRangeInclusive { - start: start.frame, - end: end.frame, - }, - flags: start.flags, - }) - } - } - - #[derive(Debug)] - pub enum MappedPageRangeInclusiveItem { - Size4KiB(MappedPageRangeInclusive), - Size2MiB(MappedPageRangeInclusive), - Size1GiB(MappedPageRangeInclusive), - } - - impl TryFrom<(MappedPageItem, MappedPageItem)> for MappedPageRangeInclusiveItem { - type Error = TryFromMappedPageError; - - fn try_from((start, end): (MappedPageItem, MappedPageItem)) -> Result { - match (start, end) { - (MappedPageItem::Size4KiB(start), MappedPageItem::Size4KiB(end)) => { - let range = MappedPageRangeInclusive::try_from((start, end))?; - Ok(Self::Size4KiB(range)) - } - (MappedPageItem::Size2MiB(start), MappedPageItem::Size2MiB(end)) => { - let range = MappedPageRangeInclusive::try_from((start, end))?; - Ok(Self::Size2MiB(range)) - } - (MappedPageItem::Size1GiB(start), MappedPageItem::Size1GiB(end)) => { - let range = MappedPageRangeInclusive::try_from((start, end))?; - Ok(Self::Size1GiB(range)) - } - (_, _) => Err(TryFromMappedPageError), - } - } - } - - #[derive(PartialEq, Eq, Clone, Debug)] - pub struct TryFromMappedPageError; - - impl fmt::Display for TryFromMappedPageError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("provided mapped pages were not compatible") - } - } - - #[derive(Clone)] - pub struct MappedPageTableRangeInclusiveIter<'a, P: PageTableFrameMapping> { - inner: MappedPageTableIter<'a, P>, - start: Option, - end: Option, - } - - #[expect(dead_code)] - pub fn mapped_page_table_range_iter<'a, P: PageTableFrameMapping>( - page_table: &'a MappedPageTable<'a, P>, - ) -> MappedPageTableRangeInclusiveIter<'a, &'a P> { - MappedPageTableRangeInclusiveIter { - inner: mapped_page_table_iter(page_table), - start: None, - end: None, - } - } - - pub fn offset_page_table_range_iter<'a>( - page_table: &'a OffsetPageTable<'a>, - ) -> MappedPageTableRangeInclusiveIter<'a, PhysOffset> { - MappedPageTableRangeInclusiveIter { - inner: offset_page_table_iter(page_table), - start: None, - end: None, - } - } - - impl<'a, P: PageTableFrameMapping> Iterator for MappedPageTableRangeInclusiveIter<'a, P> { - type Item = MappedPageRangeInclusiveItem; - - fn next(&mut self) -> Option { - if self.start.is_none() { - self.start = self.inner.next(); - self.end = self.start; - } - - let Some(start) = &mut self.start else { - return None; - }; - let end = self.end.as_mut().unwrap(); - - for mapped_page in self.inner.by_ref() { - if mapped_page == *end + 1 { - *end = mapped_page; - continue; - } - - let range = MappedPageRangeInclusiveItem::try_from((*start, *end)).unwrap(); - *start = mapped_page; - *end = mapped_page; - return Some(range); - } - - let range = MappedPageRangeInclusiveItem::try_from((*start, *end)).unwrap(); - self.start = None; - self.end = None; - Some(range) - } - } - - #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug)] - pub struct MappedPage { - pub page: Page, - pub frame: PhysFrame, - pub flags: PageTableFlags, - } - - impl Add for MappedPage { - type Output = Self; - - fn add(self, rhs: u64) -> Self::Output { - Self { - page: self.page + rhs, - frame: self.frame + rhs, - flags: self.flags, - } - } - } - - impl Sub for MappedPage { - type Output = Self; - - fn sub(self, rhs: u64) -> Self::Output { - Self { - page: self.page - rhs, - frame: self.frame - rhs, - flags: self.flags, - } - } - } - - #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug)] - pub enum MappedPageItem { - Size4KiB(MappedPage), - Size2MiB(MappedPage), - Size1GiB(MappedPage), - } - - impl Add for MappedPageItem { - type Output = Self; - - fn add(self, rhs: u64) -> Self::Output { - match self { - Self::Size4KiB(mapped_page) => Self::Size4KiB(mapped_page + rhs), - Self::Size2MiB(mapped_page) => Self::Size2MiB(mapped_page + rhs), - Self::Size1GiB(mapped_page) => Self::Size1GiB(mapped_page + rhs), - } - } - } - - impl AddAssign for MappedPageItem { - fn add_assign(&mut self, rhs: u64) { - *self = *self + rhs; - } - } - - impl Sub for MappedPageItem { - type Output = Self; - - fn sub(self, rhs: u64) -> Self::Output { - match self { - Self::Size4KiB(mapped_page) => Self::Size4KiB(mapped_page - rhs), - Self::Size2MiB(mapped_page) => Self::Size2MiB(mapped_page - rhs), - Self::Size1GiB(mapped_page) => Self::Size1GiB(mapped_page - rhs), - } - } - } - - impl SubAssign for MappedPageItem { - fn sub_assign(&mut self, rhs: u64) { - *self = *self - rhs; - } - } - - #[derive(Clone)] - pub struct MappedPageTableIter<'a, P: PageTableFrameMapping> { - page_table_walker: PageTableWalker

, - level_4_table: &'a PageTable, - p4_index: u16, - p3_index: u16, - p2_index: u16, - p1_index: u16, - } - - pub fn mapped_page_table_iter<'a, P: PageTableFrameMapping>( - page_table: &'a MappedPageTable<'a, P>, - ) -> MappedPageTableIter<'a, &'a P> { - MappedPageTableIter { - page_table_walker: unsafe { - PageTableWalker::new(page_table.page_table_frame_mapping()) - }, - level_4_table: page_table.level_4_table(), - p4_index: 0, - p3_index: 0, - p2_index: 0, - p1_index: 0, - } - } - - pub fn offset_page_table_iter<'a>( - page_table: &'a OffsetPageTable<'a>, - ) -> MappedPageTableIter<'a, PhysOffset> { - MappedPageTableIter { - page_table_walker: unsafe { - PageTableWalker::new(PhysOffset { - offset: page_table.phys_offset(), - }) - }, - level_4_table: page_table.level_4_table(), - p4_index: 0, - p3_index: 0, - p2_index: 0, - p1_index: 0, - } - } - - impl<'a, P: PageTableFrameMapping> MappedPageTableIter<'a, P> { - fn p4_index(&self) -> Option { - if self.p4_index >= 512 { - return None; - } - - Some(PageTableIndex::new(self.p4_index)) - } - - fn p3_index(&self) -> Option { - if self.p3_index >= 512 { - return None; - } - - Some(PageTableIndex::new(self.p3_index)) - } - - fn p2_index(&self) -> Option { - if self.p2_index >= 512 { - return None; - } - - Some(PageTableIndex::new(self.p2_index)) - } - - fn p1_index(&self) -> Option { - if self.p1_index >= 512 { - return None; - } - - Some(PageTableIndex::new(self.p1_index)) - } - - fn increment_p4_index(&mut self) -> Option<()> { - if self.p4_index >= 511 { - self.p4_index += 1; - return None; - } - - self.p4_index += 1; - self.p3_index = 0; - self.p2_index = 0; - self.p1_index = 0; - Some(()) - } - - fn increment_p3_index(&mut self) -> Option<()> { - if self.p3_index == 511 { - self.increment_p4_index()?; - return None; - } - - self.p3_index += 1; - self.p2_index = 0; - self.p1_index = 0; - Some(()) - } - - fn increment_p2_index(&mut self) -> Option<()> { - if self.p2_index == 511 { - self.increment_p3_index()?; - return None; - } - - self.p2_index += 1; - self.p1_index = 0; - Some(()) - } - - fn increment_p1_index(&mut self) -> Option<()> { - if self.p1_index == 511 { - self.increment_p2_index()?; - return None; - } - - self.p1_index += 1; - Some(()) - } - - fn next_forward(&mut self) -> Option { - let p4 = self.level_4_table; - - let p3 = loop { - match self.page_table_walker.next_table(&p4[self.p4_index()?]) { - Ok(page_table) => break page_table, - Err(PageTableWalkError::NotMapped) => self.increment_p4_index()?, - Err(PageTableWalkError::MappedToHugePage) => { - panic!("level 4 entry has huge page bit set") - } - } - }; - - let p2 = loop { - match self.page_table_walker.next_table(&p3[self.p3_index()?]) { - Ok(page_table) => break page_table, - Err(PageTableWalkError::NotMapped) => self.increment_p3_index()?, - Err(PageTableWalkError::MappedToHugePage) => { - let page = - Page::from_page_table_indices_1gib(self.p4_index()?, self.p3_index()?); - let entry = &p3[self.p3_index()?]; - let frame = PhysFrame::containing_address(entry.addr()); - let flags = entry.flags(); - let mapped_page = - MappedPageItem::Size1GiB(MappedPage { page, frame, flags }); - - self.increment_p3_index(); - return Some(mapped_page); - } - } - }; - - let p1 = loop { - match self.page_table_walker.next_table(&p2[self.p2_index()?]) { - Ok(page_table) => break page_table, - Err(PageTableWalkError::NotMapped) => self.increment_p2_index()?, - Err(PageTableWalkError::MappedToHugePage) => { - let page = Page::from_page_table_indices_2mib( - self.p4_index()?, - self.p3_index()?, - self.p2_index()?, - ); - let entry = &p2[self.p2_index()?]; - let frame = PhysFrame::containing_address(entry.addr()); - let flags = entry.flags(); - let mapped_page = - MappedPageItem::Size2MiB(MappedPage { page, frame, flags }); - - self.increment_p2_index(); - return Some(mapped_page); - } - } - }; - - loop { - let p1_entry = &p1[self.p1_index()?]; - - if p1_entry.is_unused() { - self.increment_p1_index()?; - continue; - } - - let frame = match PhysFrame::from_start_address(p1_entry.addr()) { - Ok(frame) => frame, - Err(AddressNotAligned) => { - warn!("Invalid frame address: {:p}", p1_entry.addr()); - self.increment_p1_index()?; - continue; - } - }; - - let page = Page::from_page_table_indices( - self.p4_index()?, - self.p3_index()?, - self.p2_index()?, - self.p1_index()?, - ); - let flags = p1_entry.flags(); - let mapped_page = MappedPageItem::Size4KiB(MappedPage { page, frame, flags }); - - self.increment_p1_index(); - return Some(mapped_page); - } - } - } - - impl<'a, P: PageTableFrameMapping> Iterator for MappedPageTableIter<'a, P> { - type Item = MappedPageItem; - - fn next(&mut self) -> Option { - self.next_forward().or_else(|| self.next_forward()) - } - } -} - -mod walker { - //! Taken from [`x86_64`] - - use x86_64::structures::paging::PageTable; - use x86_64::structures::paging::mapper::PageTableFrameMapping; - use x86_64::structures::paging::page_table::{FrameError, PageTableEntry}; - - #[derive(Clone, Debug)] - pub(super) struct PageTableWalker { - page_table_frame_mapping: P, - } - - impl PageTableWalker

{ - #[inline] - pub unsafe fn new(page_table_frame_mapping: P) -> Self { - Self { - page_table_frame_mapping, - } - } - - /// Internal helper function to get a reference to the page table of the next level. - /// - /// Returns `PageTableWalkError::NotMapped` if the entry is unused. Returns - /// `PageTableWalkError::MappedToHugePage` if the `HUGE_PAGE` flag is set - /// in the passed entry. - #[inline] - pub(super) fn next_table<'b>( - &self, - entry: &'b PageTableEntry, - ) -> Result<&'b PageTable, PageTableWalkError> { - let page_table_ptr = self - .page_table_frame_mapping - .frame_to_pointer(entry.frame()?); - let page_table: &PageTable = unsafe { &*page_table_ptr }; - - Ok(page_table) - } - } - - #[derive(Debug)] - pub(super) enum PageTableWalkError { - NotMapped, - MappedToHugePage, - } - - impl From for PageTableWalkError { - #[inline] - fn from(err: FrameError) -> Self { - match err { - FrameError::HugeFrame => PageTableWalkError::MappedToHugePage, - FrameError::FrameNotPresent => PageTableWalkError::NotMapped, - } - } - } -} - -mod offset_page_table { - //! Taken from [`x86_64`] - - use x86_64::VirtAddr; - use x86_64::structures::paging::mapper::PageTableFrameMapping; - use x86_64::structures::paging::{PageTable, PhysFrame}; - - #[derive(Clone, Debug)] - pub struct PhysOffset { - pub offset: VirtAddr, - } - - unsafe impl PageTableFrameMapping for PhysOffset { - fn frame_to_pointer(&self, frame: PhysFrame) -> *mut PageTable { - let virt = self.offset + frame.start_address().as_u64(); - virt.as_mut_ptr() - } - } + info!("Page tables:\n{}", page_table.display()); } diff --git a/src/drivers/console/mod.rs b/src/drivers/console/mod.rs index 1d8d7d8805..8ebe85410b 100644 --- a/src/drivers/console/mod.rs +++ b/src/drivers/console/mod.rs @@ -307,7 +307,8 @@ impl VirtioConsoleDriver { let negotiated_features = self .com_cfg .control_registers() - .negotiate_features(minimal_features); + .negotiate_features(minimal_features) + .unwrap(); if !negotiated_features.contains(minimal_features) { error!("Device features set, does not satisfy minimal features needed. Aborting!"); diff --git a/src/drivers/fs/mod.rs b/src/drivers/fs/mod.rs index 54995f3af5..f63a112b5d 100644 --- a/src/drivers/fs/mod.rs +++ b/src/drivers/fs/mod.rs @@ -99,7 +99,8 @@ impl VirtioFsDriver { let negotiated_features = self .com_cfg .control_registers() - .negotiate_features(minimal_features); + .negotiate_features(minimal_features) + .unwrap(); if !negotiated_features.contains(minimal_features) { error!("Device features set, does not satisfy minimal features needed. Aborting!"); diff --git a/src/drivers/net/virtio/mod.rs b/src/drivers/net/virtio/mod.rs index 3b508e8281..203c3affe3 100644 --- a/src/drivers/net/virtio/mod.rs +++ b/src/drivers/net/virtio/mod.rs @@ -675,7 +675,8 @@ impl VirtioNetDriver { let negotiated_features = self .com_cfg .control_registers() - .negotiate_features(features); + .negotiate_features(features) + .unwrap(); if !negotiated_features.contains(minimal_features) { error!("Device features set, does not satisfy minimal features needed. Aborting!"); diff --git a/src/drivers/virtio/mod.rs b/src/drivers/virtio/mod.rs index af4d855851..a1555c6184 100644 --- a/src/drivers/virtio/mod.rs +++ b/src/drivers/virtio/mod.rs @@ -15,9 +15,12 @@ pub mod transport; pub mod virtqueue; -use core::fmt; +use core::{fmt, mem}; -use virtio::FeatureBits; +use virtio::{DeviceStatus, FeatureBits}; + +use crate::errno::Errno; +use crate::io; trait VirtioIdExt { fn as_feature(&self) -> Option<&str>; @@ -44,9 +47,11 @@ mod control_registers_access { use volatile::VolatilePtr; use volatile::access::ReadWrite; - pub trait ControlRegistersAccess<'a>: Sized + Copy { + pub trait ControlRegistersAccess: Sized + Copy { fn read_device_feature_word(self, i: u32) -> le32; fn write_driver_feature_word(self, i: u32, word: le32); + fn read_device_status(self) -> virtio::DeviceStatus; + fn write_device_status(self, device_status: virtio::DeviceStatus); fn read_device_features(self) -> virtio::F { let features = array::from_fn(|i| { @@ -69,10 +74,15 @@ mod control_registers_access { self.write_driver_feature_word(i, word); } } + + fn add_device_status(self, device_status: virtio::DeviceStatus) { + let device_status = self.read_device_status() | device_status; + self.write_device_status(device_status); + } } #[cfg(feature = "pci")] - impl<'a> ControlRegistersAccess<'a> for VolatilePtr<'a, virtio::pci::CommonCfg, ReadWrite> { + impl<'a> ControlRegistersAccess for VolatilePtr<'a, virtio::pci::CommonCfg, ReadWrite> { fn read_device_feature_word(self, i: u32) -> le32 { use virtio::pci::CommonCfgVolatileFieldAccess; @@ -86,10 +96,22 @@ mod control_registers_access { self.driver_feature_select().write(i.into()); self.driver_feature().write(word); } + + fn read_device_status(self) -> virtio::DeviceStatus { + use virtio::pci::CommonCfgVolatileFieldAccess; + + self.device_status().read() + } + + fn write_device_status(self, device_status: virtio::DeviceStatus) { + use virtio::pci::CommonCfgVolatileFieldAccess; + + self.device_status().write(device_status); + } } #[cfg(not(feature = "pci"))] - impl<'a> ControlRegistersAccess<'a> for VolatilePtr<'a, virtio::mmio::DeviceRegisters, ReadWrite> { + impl<'a> ControlRegistersAccess for VolatilePtr<'a, virtio::mmio::DeviceRegisters, ReadWrite> { fn read_device_feature_word(self, i: u32) -> le32 { use virtio::mmio::DeviceRegistersVolatileFieldAccess; @@ -116,50 +138,162 @@ mod control_registers_access { self.driver_features_sel().write(i.into()); self.driver_features().write(word); } + + fn read_device_status(self) -> virtio::DeviceStatus { + use virtio::pci::CommonCfgVolatileFieldAccess; + + self.device_status().read() + } + + fn write_device_status(self, device_status: virtio::DeviceStatus) { + use virtio::pci::CommonCfgVolatileFieldAccess; + + self.device_status().write(device_status); + } } } -pub trait ControlRegisters<'a>: self::control_registers_access::ControlRegistersAccess<'a> { - fn negotiate_features(self, driver_features: DF) -> DF +pub trait ControlRegisters: self::control_registers_access::ControlRegistersAccess { + fn negotiate_features(self, driver_features: DF) -> io::Result where DF: FeatureBits + From + AsRef + AsMut + fmt::Debug + Copy, virtio::F: From + AsRef + AsMut; + + fn init_device( + self, + driver: D, + id: virtio::Id, + setup: impl FnOnce() -> io::Result<()>, + ) -> io::Result<()> + where + D: VirtioDriver, + D::F: + FeatureBits + From + AsRef + AsMut + fmt::Debug + Copy, + virtio::F: From + AsRef + AsMut; } -impl<'a, T> ControlRegisters<'a> for T +impl ControlRegisters for T where - T: self::control_registers_access::ControlRegistersAccess<'a>, + T: self::control_registers_access::ControlRegistersAccess, { - fn negotiate_features(self, driver_features: DF) -> DF + fn negotiate_features(self, driver_features: DF) -> io::Result where DF: FeatureBits + From + AsRef + AsMut + fmt::Debug + Copy, virtio::F: From + AsRef + AsMut, { let device_features = DF::from(self.read_device_features()); info!("device_features = {device_features:?}"); - debug_assert!( - device_features.requirements_satisfied(), - "The device offers a feature which requires another feature which was not offered." - ); + if !device_features.requirements_satisfied() { + error!( + "The device offers a feature which requires another feature which was not offered." + ); + return Err(Errno::Inval); + } info!("driver_features = {driver_features:?}"); - debug_assert!( - driver_features.requirements_satisfied(), - "The driver offers a feature which requires another feature which was not offered.", - ); + if !driver_features.requirements_satisfied() { + error!( + "The driver offers a feature which requires another feature which was not offered." + ); + return Err(Errno::Inval); + } let common_features = device_features.intersection(driver_features); info!("common_features = {common_features:?}"); - // This should be logically unreachable. - debug_assert!( - common_features.requirements_satisfied(), - "We negotiated a feature which requires another feature which was not negotiated." - ); + if !common_features.requirements_satisfied() { + // This should be logically unreachable. + error!( + "We negotiated a feature which requires another feature which was not negotiated." + ); + return Err(Errno::Inval); + } self.write_driver_features(common_features.into()); - common_features + Ok(common_features) } + + fn init_device( + self, + driver: D, + id: virtio::Id, + setup: impl FnOnce() -> io::Result<()>, + ) -> io::Result<()> + where + D: VirtioDriver, + D::F: + FeatureBits + From + AsRef + AsMut + fmt::Debug + Copy, + virtio::F: From + AsRef + AsMut, + { + struct FailOnDrop(T) + where + T: self::control_registers_access::ControlRegistersAccess; + + impl Drop for FailOnDrop + where + T: self::control_registers_access::ControlRegistersAccess, + { + fn drop(&mut self) { + self.0.write_device_status(DeviceStatus::FAILED); + } + } + + let fail_on_drop = FailOnDrop(self); + + // Reset the device. + self.write_device_status(DeviceStatus::empty()); + + // Tell the device that we have noticed it. + self.add_device_status(DeviceStatus::ACKNOWLEDGE); + + if id != virtio::Id::Net { + if let Some(feature) = id.as_feature() { + error!("Virtio driver {id:?} is currently not active."); + error!("To use the device, recompile the kernel with the {feature} feature."); + } else { + error!("Virtio device {id:?} is not supported!"); + } + return Err(Errno::Nodev); + } + + // Tell the device that we know how to drive it. + self.add_device_status(DeviceStatus::DRIVER); + + let negotiated_features = self.negotiate_features(D::FEATURES); + + // Tell the device to check the features. + self.add_device_status(DeviceStatus::FEATURES_OK); + + // Check whether the device supports our subset of features. + if !self + .read_device_status() + .contains(DeviceStatus::FEATURES_OK) + { + error!("The device does not support our subset of features."); + return Err(Errno::Nodev); + } + + setup()?; + + self.add_device_status(DeviceStatus::DRIVER_OK); + + mem::forget(fail_on_drop); + Ok(()) + } +} + +trait VirtioDriver { + type F; + + const FEATURES: Self::F; +} + +struct Net; + +impl VirtioDriver for Net { + type F = virtio::net::F; + + const FEATURES: Self::F = virtio::net::F::VERSION_1; } pub mod error { diff --git a/src/drivers/virtio/transport/mmio.rs b/src/drivers/virtio/transport/mmio.rs index c92d80de9c..a06584e61f 100644 --- a/src/drivers/virtio/transport/mmio.rs +++ b/src/drivers/virtio/transport/mmio.rs @@ -105,7 +105,7 @@ impl ComCfg { ComCfg { com_cfg: raw } } - pub fn control_registers(&mut self) -> impl ControlRegisters<'_> { + pub fn control_registers(&mut self) -> impl ControlRegisters { self.com_cfg.as_mut_ptr() } diff --git a/src/drivers/virtio/transport/pci.rs b/src/drivers/virtio/transport/pci.rs index a91e81fe1e..c1178b9f6c 100644 --- a/src/drivers/virtio/transport/pci.rs +++ b/src/drivers/virtio/transport/pci.rs @@ -269,7 +269,7 @@ impl VqCfgHandler<'_> { // Public Interface of ComCfg impl ComCfg { - pub fn control_registers(&mut self) -> impl ControlRegisters<'_> { + pub fn control_registers(&mut self) -> impl ControlRegisters { self.com_cfg.as_mut_ptr() } diff --git a/src/drivers/vsock/mod.rs b/src/drivers/vsock/mod.rs index 71c39c894c..6c4268f11a 100644 --- a/src/drivers/vsock/mod.rs +++ b/src/drivers/vsock/mod.rs @@ -325,7 +325,8 @@ impl VirtioVsockDriver { let negotiated_features = self .com_cfg .control_registers() - .negotiate_features(minimal_features); + .negotiate_features(minimal_features) + .unwrap(); if !negotiated_features.contains(minimal_features) { error!("Device features set, does not satisfy minimal features needed. Aborting!");