From 31f6a10594a9ddce9566ac2553277b48b6454252 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Sat, 7 Mar 2026 09:33:59 +0100 Subject: [PATCH] make VirtAddr generic over address validity As proposed by Joe and Philipp, this patch adds a generic parameter for VirtAddr. This generic parameter can be used to enforce validity for addresses with 4 level paging or 5 level paging. This patch is still a work in progress because I wanted to make that this is what you had in mind before I spend more time on it. I slightly modified the VirtValidity interface to because check_virt_addr by itself cannot be used to implement truncation, but truncation can be used to implement check_virt_addr. That's just a minor implementation detail though, we can discuss this later. TODO: - Add a feature to gate enabling the const_trait_impl feature. - Fix up the remaining code. I didn't adjust all of the code to take a generic parameter for the validity yet. As of right now, there are a ton of compilation errors. - Figure out how to make the mapper impls work with VirtValidity. VirtValidity by itself doesn't tell the mappers how many levels there are. I don't expect this to be difficult to solve. - Default the validity generic parameter to Virt48. Currently, I intentionally left the default out, so that the compiler would tell me about all the code that needed to become aware of it. Rust doesn't allow defaulted generic parameters after non-defaulted ones, so I currently put the validity parameter before the page size parameter for Page. We'll probably want to swap that once we set the default. - Test that the Add/Sub/Step impls work and fix them if they don't. Up until now I mostly focused on fixing compilation errors and not so much on the code actually working. - Add type aliases? User would probably benefit from VirtAddr48, VirtAddr57, Page48, and Page57. - Do the same thing for PhysAddr. - Rebase this on next. This will very likely be a breaking change. --- src/addr.rs | 96 ++++++++++---- src/instructions/mod.rs | 4 +- src/instructions/segmentation.rs | 5 +- src/instructions/tables.rs | 13 +- src/instructions/tlb.rs | 32 ++--- src/lib.rs | 1 + src/registers/control.rs | 8 +- src/registers/model_specific.rs | 17 +-- src/registers/segmentation.rs | 6 +- src/structures/idt.rs | 31 ++--- src/structures/mod.rs | 6 +- .../paging/mapper/mapped_page_table.rs | 13 +- src/structures/paging/mapper/mod.rs | 17 ++- src/structures/paging/page.rs | 123 +++++++++++------- src/structures/tss.rs | 19 +-- 15 files changed, 245 insertions(+), 146 deletions(-) diff --git a/src/addr.rs b/src/addr.rs index 232e96a16..2038b1497 100644 --- a/src/addr.rs +++ b/src/addr.rs @@ -2,8 +2,10 @@ use core::convert::TryFrom; use core::fmt; +use core::hash::Hash; #[cfg(feature = "step_trait")] use core::iter::Step; +use core::marker::PhantomData; use core::ops::{Add, AddAssign, Sub, SubAssign}; #[cfg(feature = "memory_encryption")] use core::sync::atomic::Ordering; @@ -30,7 +32,33 @@ const ADDRESS_SPACE_SIZE: u64 = 0x1_0000_0000_0000; /// are called “canonical”. This type guarantees that it always represents a canonical address. #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] -pub struct VirtAddr(u64); +pub struct VirtAddr(u64, PhantomData); + +pub const trait VirtValidity: Clone + Copy + Eq + Ord + Hash { + fn truncate(addr: u64) -> u64; +} + +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Virt48 {} + +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Virt57 {} + +impl const VirtValidity for Virt48 { + fn truncate(addr: u64) -> u64 { + // By doing the right shift as a signed operation (on a i64), it will + // sign extend the value, repeating the leftmost bit. + ((addr << 16) as i64 >> 16) as u64 + } +} + +impl const VirtValidity for Virt57 { + fn truncate(addr: u64) -> u64 { + // By doing the right shift as a signed operation (on a i64), it will + // sign extend the value, repeating the leftmost bit. + ((addr << 7) as i64 >> 7) as u64 + } +} /// A 64-bit physical memory address. /// @@ -63,7 +91,7 @@ impl core::fmt::Debug for VirtAddrNotValid { } } -impl VirtAddr { +impl VirtAddr { /// Creates a new canonical virtual address. /// /// The provided address should already be canonical. If you want to check @@ -74,7 +102,10 @@ impl VirtAddr { /// This function panics if the bits in the range 48 to 64 are invalid /// (i.e. are not a proper sign extension of bit 47). #[inline] - pub const fn new(addr: u64) -> VirtAddr { + pub const fn new(addr: u64) -> Self + where + V: [const] VirtValidity, + { // TODO: Replace with .ok().expect(msg) when that works on stable. match Self::try_new(addr) { Ok(v) => v, @@ -89,7 +120,10 @@ impl VirtAddr { /// if bits 48 to 64 are a correct sign /// extension (i.e. copies of bit 47). #[inline] - pub const fn try_new(addr: u64) -> Result { + pub const fn try_new(addr: u64) -> Result + where + V: [const] VirtValidity, + { let v = Self::new_truncate(addr); if v.0 == addr { Ok(v) @@ -104,10 +138,13 @@ impl VirtAddr { /// canonical, overwriting bits 48 to 64. If you want to check whether an /// address is canonical, use [`new`](Self::new) or [`try_new`](Self::try_new). #[inline] - pub const fn new_truncate(addr: u64) -> VirtAddr { + pub const fn new_truncate(addr: u64) -> Self + where + V: [const] VirtValidity, + { // By doing the right shift as a signed operation (on a i64), it will // sign extend the value, repeating the leftmost bit. - VirtAddr(((addr << 16) as i64 >> 16) as u64) + Self(V::truncate(addr), PhantomData) } /// Creates a new virtual address, without any checks. @@ -116,14 +153,17 @@ impl VirtAddr { /// /// You must make sure bits 48..64 are equal to bit 47. This is not checked. #[inline] - pub const unsafe fn new_unsafe(addr: u64) -> VirtAddr { - VirtAddr(addr) + pub const unsafe fn new_unsafe(addr: u64) -> Self { + Self(addr, PhantomData) } /// Creates a virtual address that points to `0`. #[inline] - pub const fn zero() -> VirtAddr { - VirtAddr(0) + pub const fn zero() -> Self + where + V: [const] VirtValidity, + { + VirtAddr::new(0) } /// Converts the address to an `u64`. @@ -190,7 +230,10 @@ impl VirtAddr { /// /// See the `align_down` function for more information. #[inline] - pub(crate) const fn align_down_u64(self, align: u64) -> Self { + pub(crate) const fn align_down_u64(self, align: u64) -> Self + where + V: [const] VirtValidity, + { VirtAddr::new_truncate(align_down(self.0, align)) } @@ -205,7 +248,10 @@ impl VirtAddr { /// Checks whether the virtual address has the demanded alignment. #[inline] - pub(crate) const fn is_aligned_u64(self, align: u64) -> bool { + pub(crate) const fn is_aligned_u64(self, align: u64) -> bool + where + V: [const] VirtValidity, + { self.align_down_u64(align).as_u64() == self.as_u64() } @@ -325,7 +371,7 @@ impl VirtAddr { } } -impl fmt::Debug for VirtAddr { +impl fmt::Debug for VirtAddr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_tuple("VirtAddr") .field(&format_args!("{:#x}", self.0)) @@ -333,42 +379,42 @@ impl fmt::Debug for VirtAddr { } } -impl fmt::Binary for VirtAddr { +impl fmt::Binary for VirtAddr { #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Binary::fmt(&self.0, f) } } -impl fmt::LowerHex for VirtAddr { +impl fmt::LowerHex for VirtAddr { #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::LowerHex::fmt(&self.0, f) } } -impl fmt::Octal for VirtAddr { +impl fmt::Octal for VirtAddr { #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Octal::fmt(&self.0, f) } } -impl fmt::UpperHex for VirtAddr { +impl fmt::UpperHex for VirtAddr { #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::UpperHex::fmt(&self.0, f) } } -impl fmt::Pointer for VirtAddr { +impl fmt::Pointer for VirtAddr { #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Pointer::fmt(&(self.0 as *const ()), f) } } -impl Add for VirtAddr { +impl Add for VirtAddr { type Output = Self; #[cfg_attr(not(feature = "step_trait"), allow(rustdoc::broken_intra_doc_links))] @@ -393,7 +439,7 @@ impl Add for VirtAddr { } } -impl AddAssign for VirtAddr { +impl AddAssign for VirtAddr { #[cfg_attr(not(feature = "step_trait"), allow(rustdoc::broken_intra_doc_links))] /// Add an offset to a virtual address. /// @@ -411,7 +457,7 @@ impl AddAssign for VirtAddr { } } -impl Sub for VirtAddr { +impl Sub for VirtAddr { type Output = Self; #[cfg_attr(not(feature = "step_trait"), allow(rustdoc::broken_intra_doc_links))] @@ -436,7 +482,7 @@ impl Sub for VirtAddr { } } -impl SubAssign for VirtAddr { +impl SubAssign for VirtAddr { #[cfg_attr(not(feature = "step_trait"), allow(rustdoc::broken_intra_doc_links))] /// Subtract an offset from a virtual address. /// @@ -454,7 +500,7 @@ impl SubAssign for VirtAddr { } } -impl Sub for VirtAddr { +impl Sub> for VirtAddr { type Output = u64; /// Returns the difference between two addresses. @@ -463,7 +509,7 @@ impl Sub for VirtAddr { /// /// This function will panic on overflow. #[inline] - fn sub(self, rhs: VirtAddr) -> Self::Output { + fn sub(self, rhs: VirtAddr) -> Self::Output { self.as_u64() .checked_sub(rhs.as_u64()) .expect("attempt to subtract with overflow") @@ -471,7 +517,7 @@ impl Sub for VirtAddr { } #[cfg(feature = "step_trait")] -impl Step for VirtAddr { +impl Step for VirtAddr { #[inline] fn steps_between(start: &Self, end: &Self) -> (usize, Option) { Self::steps_between_impl(start, end) diff --git a/src/instructions/mod.rs b/src/instructions/mod.rs index 91751ede3..fcd52e9a5 100644 --- a/src/instructions/mod.rs +++ b/src/instructions/mod.rs @@ -12,6 +12,8 @@ pub mod tlb; use core::arch::asm; +use crate::addr::VirtValidity; + /// Halts the CPU until the next interrupt arrives. #[inline] pub fn hlt() { @@ -48,7 +50,7 @@ pub fn bochs_breakpoint() { /// Gets the current instruction pointer. Note that this is only approximate as it requires a few /// instructions to execute. #[inline(always)] -pub fn read_rip() -> crate::VirtAddr { +pub fn read_rip() -> crate::VirtAddr { let rip: u64; unsafe { asm!("lea {}, [rip]", out(reg) rip, options(nostack, nomem, preserves_flags)); diff --git a/src/instructions/segmentation.rs b/src/instructions/segmentation.rs index c8fbfbabd..e83c119a8 100644 --- a/src/instructions/segmentation.rs +++ b/src/instructions/segmentation.rs @@ -2,6 +2,7 @@ pub use crate::registers::segmentation::{Segment, Segment64, CS, DS, ES, FS, GS, SS}; use crate::{ + addr::VirtValidity, registers::model_specific::{FsBase, GsBase, Msr}, structures::gdt::SegmentSelector, VirtAddr, @@ -41,7 +42,7 @@ macro_rules! segment64_impl { impl Segment64 for $type { const BASE: Msr = <$base>::MSR; #[inline] - fn read_base() -> VirtAddr { + fn read_base() -> VirtAddr { unsafe { let val: u64; asm!(concat!("rd", $name, "base {}"), out(reg) val, options(nomem, nostack, preserves_flags)); @@ -50,7 +51,7 @@ macro_rules! segment64_impl { } #[inline] - unsafe fn write_base(base: VirtAddr) { + unsafe fn write_base(base: VirtAddr) { unsafe{ asm!(concat!("wr", $name, "base {}"), in(reg) base.as_u64(), options(nostack, preserves_flags)); } diff --git a/src/instructions/tables.rs b/src/instructions/tables.rs index 611d61175..15594d256 100644 --- a/src/instructions/tables.rs +++ b/src/instructions/tables.rs @@ -1,5 +1,6 @@ //! Functions to load GDT, IDT, and TSS structures. +use crate::addr::VirtValidity; use crate::structures::gdt::SegmentSelector; use crate::VirtAddr; use core::arch::asm; @@ -18,7 +19,7 @@ pub use crate::structures::DescriptorTablePointer; /// `DescriptorTablePointer` points to a valid GDT and that loading this /// GDT is safe. #[inline] -pub unsafe fn lgdt(gdt: &DescriptorTablePointer) { +pub unsafe fn lgdt(gdt: &DescriptorTablePointer) { unsafe { asm!("lgdt [{}]", in(reg) gdt, options(readonly, nostack, preserves_flags)); } @@ -36,7 +37,7 @@ pub unsafe fn lgdt(gdt: &DescriptorTablePointer) { /// `DescriptorTablePointer` points to a valid IDT and that loading this /// IDT is safe. #[inline] -pub unsafe fn lidt(idt: &DescriptorTablePointer) { +pub unsafe fn lidt(idt: &DescriptorTablePointer) { unsafe { asm!("lidt [{}]", in(reg) idt, options(readonly, nostack, preserves_flags)); } @@ -44,8 +45,8 @@ pub unsafe fn lidt(idt: &DescriptorTablePointer) { /// Get the address of the current GDT. #[inline] -pub fn sgdt() -> DescriptorTablePointer { - let mut gdt: DescriptorTablePointer = DescriptorTablePointer { +pub fn sgdt() -> DescriptorTablePointer { + let mut gdt = DescriptorTablePointer { limit: 0, base: VirtAddr::new(0), }; @@ -57,8 +58,8 @@ pub fn sgdt() -> DescriptorTablePointer { /// Get the address of the current IDT. #[inline] -pub fn sidt() -> DescriptorTablePointer { - let mut idt: DescriptorTablePointer = DescriptorTablePointer { +pub fn sidt() -> DescriptorTablePointer { + let mut idt = DescriptorTablePointer { limit: 0, base: VirtAddr::new(0), }; diff --git a/src/instructions/tlb.rs b/src/instructions/tlb.rs index d96cc8955..73ede4023 100644 --- a/src/instructions/tlb.rs +++ b/src/instructions/tlb.rs @@ -3,6 +3,7 @@ use bit_field::BitField; use crate::{ + addr::VirtValidity, instructions::segmentation::{Segment, CS}, structures::paging::{ page::{NotGiantPageSize, PageRange}, @@ -14,7 +15,7 @@ use core::{arch::asm, cmp, convert::TryFrom, fmt}; /// Invalidate the given address in the TLB using the `invlpg` instruction. #[inline] -pub fn flush(addr: VirtAddr) { +pub fn flush(addr: VirtAddr) { unsafe { asm!("invlpg [{}]", in(reg) addr.as_u64(), options(nostack, preserves_flags)); } @@ -30,9 +31,9 @@ pub fn flush_all() { /// The Invalidate PCID Command to execute. #[derive(Debug)] -pub enum InvPcidCommand { +pub enum InvPcidCommand { /// The logical processor invalidates mappings—except global translations—for the linear address and PCID specified. - Address(VirtAddr, Pcid), + Address(VirtAddr, Pcid), /// The logical processor invalidates all mappings—except global translations—associated with the PCID. Single(Pcid), @@ -47,7 +48,7 @@ pub enum InvPcidCommand { // TODO: Remove this in the next breaking release. #[deprecated = "please use `InvPcidCommand` instead"] #[doc(hidden)] -pub type InvPicdCommand = InvPcidCommand; +pub type InvPicdCommand = InvPcidCommand; /// The INVPCID descriptor comprises 128 bits and consists of a PCID and a linear address. /// For INVPCID type 0, the processor uses the full 64 bits of the linear address even outside 64-bit mode; the linear address is not used for other INVPCID types. @@ -98,7 +99,7 @@ impl fmt::Display for PcidTooBig { /// /// This function is unsafe as it requires CPUID.(EAX=07H, ECX=0H):EBX.INVPCID to be 1. #[inline] -pub unsafe fn flush_pcid(command: InvPcidCommand) { +pub unsafe fn flush_pcid(command: InvPcidCommand) { let mut desc = InvpcidDescriptor { pcid: 0, address: 0, @@ -200,7 +201,7 @@ impl Invlpgb { } /// Create a `InvlpgbFlushBuilder`. - pub fn build(&self) -> InvlpgbFlushBuilder<'_> { + pub fn build(&self) -> InvlpgbFlushBuilder<'_, V> { InvlpgbFlushBuilder { invlpgb: self, page_range: None, @@ -225,12 +226,12 @@ impl Invlpgb { /// A builder struct to construct the parameters for the `invlpgb` instruction. #[derive(Debug, Clone)] #[must_use] -pub struct InvlpgbFlushBuilder<'a, S = Size4KiB> +pub struct InvlpgbFlushBuilder<'a, V: VirtValidity, S = Size4KiB> where S: NotGiantPageSize, { invlpgb: &'a Invlpgb, - page_range: Option>, + page_range: Option>, pcid: Option, asid: Option, include_global: bool, @@ -238,15 +239,16 @@ where include_nested_translations: bool, } -impl<'a, S> InvlpgbFlushBuilder<'a, S> +impl<'a, S, V> InvlpgbFlushBuilder<'a, V, S> where S: NotGiantPageSize, + V: VirtValidity, { /// Flush a range of pages. /// /// If the range doesn't fit within `invlpgb_count_max`, `invlpgb` is /// executed multiple times. - pub fn pages(self, page_range: PageRange) -> InvlpgbFlushBuilder<'a, T> + pub fn pages(self, page_range: PageRange) -> InvlpgbFlushBuilder<'a, V, T> where T: NotGiantPageSize, { @@ -317,11 +319,11 @@ where if let Some(mut pages) = self.page_range { while !pages.is_empty() { // Calculate out how many pages we still need to flush. - let count = Page::::steps_between_impl(&pages.start, &pages.end).0; + let count = Page::::steps_between_impl(&pages.start, &pages.end).0; // Make sure that we never jump the gap in the address space when flushing. let second_half_start = - Page::::containing_address(VirtAddr::new(0xffff_8000_0000_0000)); + Page::::containing_address(VirtAddr::new(0xffff_8000_0000_0000)); let count = if pages.start < second_half_start { let count_to_second_half = Page::steps_between_impl(&pages.start, &second_half_start).0; @@ -355,7 +357,7 @@ where } } else { unsafe { - flush_broadcast::( + flush_broadcast::( None, self.pcid, self.asid, @@ -389,8 +391,8 @@ impl fmt::Display for AsidOutOfRangeError { /// See `INVLPGB` in AMD64 Architecture Programmer's Manual Volume 3 #[inline] -unsafe fn flush_broadcast( - va_and_count: Option<(Page, u16)>, +unsafe fn flush_broadcast( + va_and_count: Option<(Page, u16)>, pcid: Option, asid: Option, include_global: bool, diff --git a/src/lib.rs b/src/lib.rs index 8153e6fea..499bf0a75 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ #![warn(missing_docs)] #![deny(missing_debug_implementations)] #![deny(unsafe_op_in_unsafe_fn)] +#![feature(const_trait_impl)] pub use crate::addr::{align_down, align_up, PhysAddr, VirtAddr}; diff --git a/src/registers/control.rs b/src/registers/control.rs index 896fd8f9e..421439d47 100644 --- a/src/registers/control.rs +++ b/src/registers/control.rs @@ -230,8 +230,10 @@ impl PriorityClass { mod x86_64 { use super::*; use crate::{ - addr::VirtAddrNotValid, instructions::tlb::Pcid, structures::paging::PhysFrame, PhysAddr, - VirtAddr, + addr::{VirtAddrNotValid, VirtValidity}, + instructions::tlb::Pcid, + structures::paging::PhysFrame, + PhysAddr, VirtAddr, }; use core::arch::asm; @@ -317,7 +319,7 @@ mod x86_64 { /// This method returns a [`VirtAddrNotValid`] error if the CR2 register contains a /// non-canonical address. Call [`Cr2::read_raw`] to handle such cases. #[inline] - pub fn read() -> Result { + pub fn read() -> Result, VirtAddrNotValid> { VirtAddr::try_new(Self::read_raw()) } diff --git a/src/registers/model_specific.rs b/src/registers/model_specific.rs index a694615dd..17ffdc24e 100644 --- a/src/registers/model_specific.rs +++ b/src/registers/model_specific.rs @@ -194,6 +194,7 @@ bitflags! { mod x86_64 { use super::*; use crate::addr::VirtAddr; + use crate::addr::VirtValidity; use crate::registers::rflags::RFlags; use crate::structures::gdt::SegmentSelector; use crate::structures::paging::Page; @@ -330,7 +331,7 @@ mod x86_64 { /// If [`CR4.FSGSBASE`][Cr4Flags::FSGSBASE] is set, the more efficient /// [`FS::read_base`] can be used instead. #[inline] - pub fn read() -> VirtAddr { + pub fn read() -> VirtAddr { VirtAddr::new(unsafe { Self::MSR.read() }) } @@ -339,7 +340,7 @@ mod x86_64 { /// If [`CR4.FSGSBASE`][Cr4Flags::FSGSBASE] is set, the more efficient /// [`FS::write_base`] can be used instead. #[inline] - pub fn write(address: VirtAddr) { + pub fn write(address: VirtAddr) { let mut msr = Self::MSR; unsafe { msr.write(address.as_u64()) }; } @@ -351,7 +352,7 @@ mod x86_64 { /// If [`CR4.FSGSBASE`][Cr4Flags::FSGSBASE] is set, the more efficient /// [`GS::read_base`] can be used instead. #[inline] - pub fn read() -> VirtAddr { + pub fn read() -> VirtAddr { VirtAddr::new(unsafe { Self::MSR.read() }) } @@ -360,7 +361,7 @@ mod x86_64 { /// If [`CR4.FSGSBASE`][Cr4Flags::FSGSBASE] is set, the more efficient /// [`GS::write_base`] can be used instead. #[inline] - pub fn write(address: VirtAddr) { + pub fn write(address: VirtAddr) { let mut msr = Self::MSR; unsafe { msr.write(address.as_u64()) }; } @@ -369,13 +370,13 @@ mod x86_64 { impl KernelGsBase { /// Read the current KernelGsBase register. #[inline] - pub fn read() -> VirtAddr { + pub fn read() -> VirtAddr { VirtAddr::new(unsafe { Self::MSR.read() }) } /// Write a given virtual address to the KernelGsBase register. #[inline] - pub fn write(address: VirtAddr) { + pub fn write(address: VirtAddr) { let mut msr = Self::MSR; unsafe { msr.write(address.as_u64()) }; } @@ -518,14 +519,14 @@ mod x86_64 { /// Read the current LStar register. /// This holds the target RIP of a syscall. #[inline] - pub fn read() -> VirtAddr { + pub fn read() -> VirtAddr { VirtAddr::new(unsafe { Self::MSR.read() }) } /// Write a given virtual address to the LStar register. /// This holds the target RIP of a syscall. #[inline] - pub fn write(address: VirtAddr) { + pub fn write(address: VirtAddr) { let mut msr = Self::MSR; unsafe { msr.write(address.as_u64()) }; } diff --git a/src/registers/segmentation.rs b/src/registers/segmentation.rs index 5d5954bc8..4d65220e4 100644 --- a/src/registers/segmentation.rs +++ b/src/registers/segmentation.rs @@ -1,7 +1,7 @@ //! Abstractions for segment registers. use super::model_specific::Msr; -use crate::{PrivilegeLevel, VirtAddr}; +use crate::{addr::VirtValidity, PrivilegeLevel, VirtAddr}; use bit_field::BitField; use core::fmt; // imports for intra doc links @@ -45,7 +45,7 @@ pub trait Segment64: Segment { /// ## Exceptions /// /// If [`CR4.FSGSBASE`][Cr4Flags::FSGSBASE] is not set, this instruction will throw a `#UD`. - fn read_base() -> VirtAddr; + fn read_base() -> VirtAddr; /// Writes the segment base address /// /// ## Exceptions @@ -56,7 +56,7 @@ pub trait Segment64: Segment { /// /// The caller must ensure that this write operation has no unsafe side /// effects, as the segment base address might be in use. - unsafe fn write_base(base: VirtAddr); + unsafe fn write_base(base: VirtAddr); } /// Specifies which element to load into a segment from diff --git a/src/structures/idt.rs b/src/structures/idt.rs index f15cedcce..1cf4227fe 100644 --- a/src/structures/idt.rs +++ b/src/structures/idt.rs @@ -20,6 +20,7 @@ //! //! These types are defined for the compatibility with the Nightly Rust build. +use crate::addr::VirtValidity; use crate::registers::rflags::RFlags; use crate::{PrivilegeLevel, VirtAddr}; use bit_field::BitField; @@ -1015,16 +1016,16 @@ impl EntryOptions { /// occurs, which can cause undefined behavior (see the [`as_mut`](InterruptStackFrame::as_mut) /// method for more information). #[repr(transparent)] -pub struct InterruptStackFrame(InterruptStackFrameValue); +pub struct InterruptStackFrame(InterruptStackFrameValue); -impl InterruptStackFrame { +impl InterruptStackFrame { /// Creates a new interrupt stack frame with the given values. #[inline] pub fn new( - instruction_pointer: VirtAddr, + instruction_pointer: VirtAddr, code_segment: SegmentSelector, cpu_flags: RFlags, - stack_pointer: VirtAddr, + stack_pointer: VirtAddr, stack_segment: SegmentSelector, ) -> Self { Self(InterruptStackFrameValue::new( @@ -1051,13 +1052,13 @@ impl InterruptStackFrame { /// Also, it is not fully clear yet whether modifications of the interrupt stack frame are /// officially supported by LLVM's x86 interrupt calling convention. #[inline] - pub unsafe fn as_mut(&mut self) -> Volatile<&mut InterruptStackFrameValue> { + pub unsafe fn as_mut(&mut self) -> Volatile<&mut InterruptStackFrameValue> { Volatile::new(&mut self.0) } } -impl Deref for InterruptStackFrame { - type Target = InterruptStackFrameValue; +impl Deref for InterruptStackFrame { + type Target = InterruptStackFrameValue; #[inline] fn deref(&self) -> &Self::Target { @@ -1065,7 +1066,7 @@ impl Deref for InterruptStackFrame { } } -impl fmt::Debug for InterruptStackFrame { +impl fmt::Debug for InterruptStackFrame { #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.0.fmt(f) @@ -1075,33 +1076,33 @@ impl fmt::Debug for InterruptStackFrame { /// Represents the interrupt stack frame pushed by the CPU on interrupt or exception entry. #[derive(Clone, Copy)] #[repr(C)] -pub struct InterruptStackFrameValue { +pub struct InterruptStackFrameValue { /// This value points to the instruction that should be executed when the interrupt /// handler returns. For most interrupts, this value points to the instruction immediately /// following the last executed instruction. However, for some exceptions (e.g., page faults), /// this value points to the faulting instruction, so that the instruction is restarted on /// return. See the documentation of the [`InterruptDescriptorTable`] fields for more details. - pub instruction_pointer: VirtAddr, + pub instruction_pointer: VirtAddr, /// The code segment selector at the time of the interrupt. pub code_segment: SegmentSelector, _reserved1: [u8; 6], /// The flags register before the interrupt handler was invoked. pub cpu_flags: RFlags, /// The stack pointer at the time of the interrupt. - pub stack_pointer: VirtAddr, + pub stack_pointer: VirtAddr, /// The stack segment descriptor at the time of the interrupt (often zero in 64-bit mode). pub stack_segment: SegmentSelector, _reserved2: [u8; 6], } -impl InterruptStackFrameValue { +impl InterruptStackFrameValue { /// Creates a new interrupt stack frame with the given values. #[inline] pub fn new( - instruction_pointer: VirtAddr, + instruction_pointer: VirtAddr, code_segment: SegmentSelector, cpu_flags: RFlags, - stack_pointer: VirtAddr, + stack_pointer: VirtAddr, stack_segment: SegmentSelector, ) -> Self { Self { @@ -1148,7 +1149,7 @@ impl InterruptStackFrameValue { } } -impl fmt::Debug for InterruptStackFrameValue { +impl fmt::Debug for InterruptStackFrameValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let mut s = f.debug_struct("InterruptStackFrame"); s.field("instruction_pointer", &self.instruction_pointer); diff --git a/src/structures/mod.rs b/src/structures/mod.rs index 084bbafab..6d6dbbfed 100644 --- a/src/structures/mod.rs +++ b/src/structures/mod.rs @@ -1,6 +1,6 @@ //! Representations of various x86 specific structures and descriptor tables. -use crate::VirtAddr; +use crate::{addr::VirtValidity, VirtAddr}; pub mod gdt; @@ -16,11 +16,11 @@ pub mod tss; /// This is in a format suitable for giving to 'lgdt' or 'lidt'. #[derive(Debug, Clone, Copy)] #[repr(C, packed(2))] -pub struct DescriptorTablePointer { +pub struct DescriptorTablePointer { /// Size of the DT in bytes - 1. pub limit: u16, /// Pointer to the memory region containing the DT. - pub base: VirtAddr, + pub base: VirtAddr, } #[cfg(test)] diff --git a/src/structures/paging/mapper/mapped_page_table.rs b/src/structures/paging/mapper/mapped_page_table.rs index 5f673c55c..14f834d7c 100644 --- a/src/structures/paging/mapper/mapped_page_table.rs +++ b/src/structures/paging/mapper/mapped_page_table.rs @@ -1,7 +1,12 @@ -use crate::structures::paging::{ - mapper::*, - page::AddressNotAligned, - page_table::{FrameError, PageTable, PageTableEntry, PageTableLevel}, +use core::marker::PhantomData; + +use crate::{ + addr::VirtValidity, + structures::paging::{ + mapper::*, + page::AddressNotAligned, + page_table::{FrameError, PageTable, PageTableEntry, PageTableLevel}, + }, }; /// A Mapper implementation that relies on a PhysAddr to VirtAddr conversion function. diff --git a/src/structures/paging/mapper/mod.rs b/src/structures/paging/mapper/mod.rs index d0f217167..eb54c0fda 100644 --- a/src/structures/paging/mapper/mod.rs +++ b/src/structures/paging/mapper/mod.rs @@ -6,11 +6,14 @@ pub use self::offset_page_table::OffsetPageTable; #[cfg(all(feature = "instructions", target_arch = "x86_64"))] pub use self::recursive_page_table::{InvalidPageTable, RecursivePageTable}; -use crate::structures::paging::{ - frame_alloc::{FrameAllocator, FrameDeallocator}, - page::PageRangeInclusive, - page_table::PageTableFlags, - Page, PageSize, PhysFrame, Size1GiB, Size2MiB, Size4KiB, +use crate::{ + addr::VirtValidity, + structures::paging::{ + frame_alloc::{FrameAllocator, FrameDeallocator}, + page::PageRangeInclusive, + page_table::PageTableFlags, + Page, PageSize, PhysFrame, Size1GiB, Size2MiB, Size4KiB, + }, }; use crate::{PhysAddr, VirtAddr}; @@ -25,7 +28,7 @@ pub trait MapperAllSizes: Mapper + Mapper + Mapper impl MapperAllSizes for T where T: Mapper + Mapper + Mapper {} /// Provides methods for translating virtual addresses. -pub trait Translate { +pub trait Translate { /// Return the frame that the given virtual address is mapped to and the offset within that /// frame. /// @@ -33,7 +36,7 @@ pub trait Translate { /// frame is returned. Otherwise an error value is returned. /// /// This function works with huge pages of all sizes. - fn translate(&self, addr: VirtAddr) -> TranslateResult; + fn translate(&self, addr: VirtAddr) -> TranslateResult; /// Translates the given virtual address to the physical address that it maps to. /// diff --git a/src/structures/paging/page.rs b/src/structures/paging/page.rs index 9bfe22cfb..6bf894217 100644 --- a/src/structures/paging/page.rs +++ b/src/structures/paging/page.rs @@ -1,5 +1,6 @@ //! Abstractions for default-sized and huge virtual memory pages. +use crate::addr::VirtValidity; use crate::sealed::Sealed; use crate::structures::paging::page_table::PageTableLevel; use crate::structures::paging::PageTableIndex; @@ -64,12 +65,12 @@ impl Sealed for super::Size1GiB {} /// A virtual memory page. #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(C)] -pub struct Page { - start_address: VirtAddr, +pub struct Page { + start_address: VirtAddr, size: PhantomData, } -impl Page { +impl Page { /// The page size in bytes. pub const SIZE: u64 = S::SIZE; @@ -78,7 +79,10 @@ impl Page { /// Returns an error if the address is not correctly aligned (i.e. is not a valid page start). #[inline] #[rustversion::attr(since(1.61), const)] - pub fn from_start_address(address: VirtAddr) -> Result { + pub fn from_start_address(address: VirtAddr) -> Result + where + V: [const] VirtValidity, + { if !address.is_aligned_u64(S::SIZE) { return Err(AddressNotAligned); } @@ -92,7 +96,10 @@ impl Page { /// The address must be correctly aligned. #[inline] #[rustversion::attr(since(1.61), const)] - pub unsafe fn from_start_address_unchecked(start_address: VirtAddr) -> Self { + pub unsafe fn from_start_address_unchecked(start_address: VirtAddr) -> Self + where + V: [const] VirtValidity, + { Page { start_address, size: PhantomData, @@ -102,7 +109,10 @@ impl Page { /// Returns the page that contains the given virtual address. #[inline] #[rustversion::attr(since(1.61), const)] - pub fn containing_address(address: VirtAddr) -> Self { + pub fn containing_address(address: VirtAddr) -> Self + where + V: [const] VirtValidity, + { Page { start_address: address.align_down_u64(S::SIZE), size: PhantomData, @@ -112,7 +122,10 @@ impl Page { /// Returns the start address of the page. #[inline] #[rustversion::attr(since(1.61), const)] - pub fn start_address(self) -> VirtAddr { + pub fn start_address(self) -> VirtAddr + where + V: [const] VirtValidity, + { self.start_address } @@ -126,35 +139,44 @@ impl Page { /// Returns the level 4 page table index of this page. #[inline] #[rustversion::attr(since(1.61), const)] - pub fn p4_index(self) -> PageTableIndex { + pub fn p4_index(self) -> PageTableIndex + where + V: [const] VirtValidity, + { self.start_address().p4_index() } /// Returns the level 3 page table index of this page. #[inline] #[rustversion::attr(since(1.61), const)] - pub fn p3_index(self) -> PageTableIndex { + pub fn p3_index(self) -> PageTableIndex + where + V: [const] VirtValidity, + { self.start_address().p3_index() } /// Returns the table index of this page at the specified level. #[inline] #[rustversion::attr(since(1.61), const)] - pub fn page_table_index(self, level: PageTableLevel) -> PageTableIndex { + pub fn page_table_index(self, level: PageTableLevel) -> PageTableIndex + where + V: [const] VirtValidity, + { self.start_address().page_table_index(level) } /// Returns a range of pages, exclusive `end`. #[inline] #[rustversion::attr(since(1.61), const)] - pub fn range(start: Self, end: Self) -> PageRange { + pub fn range(start: Self, end: Self) -> PageRange { PageRange { start, end } } /// Returns a range of pages, inclusive `end`. #[inline] #[rustversion::attr(since(1.61), const)] - pub fn range_inclusive(start: Self, end: Self) -> PageRangeInclusive { + pub fn range_inclusive(start: Self, end: Self) -> PageRangeInclusive { PageRangeInclusive { start, end } } @@ -188,23 +210,26 @@ impl Page { } } -impl Page { +impl Page { /// Returns the level 2 page table index of this page. #[inline] #[rustversion::attr(since(1.61), const)] - pub fn p2_index(self) -> PageTableIndex { + pub fn p2_index(self) -> PageTableIndex + where + V: [const] VirtValidity, + { self.start_address().p2_index() } } -impl Page { +impl Page { /// Returns the 1GiB memory page with the specified page table indices. #[inline] #[rustversion::attr(since(1.61), const)] - pub fn from_page_table_indices_1gib( - p4_index: PageTableIndex, - p3_index: PageTableIndex, - ) -> Self { + pub fn from_page_table_indices_1gib(p4_index: PageTableIndex, p3_index: PageTableIndex) -> Self + where + V: [const] VirtValidity, + { let mut addr = 0; addr |= p4_index.into_u64() << 39; addr |= p3_index.into_u64() << 30; @@ -212,7 +237,7 @@ impl Page { } } -impl Page { +impl Page { /// Returns the 2MiB memory page with the specified page table indices. #[inline] #[rustversion::attr(since(1.61), const)] @@ -220,7 +245,10 @@ impl Page { p4_index: PageTableIndex, p3_index: PageTableIndex, p2_index: PageTableIndex, - ) -> Self { + ) -> Self + where + V: [const] VirtValidity, + { let mut addr = 0; addr |= p4_index.into_u64() << 39; addr |= p3_index.into_u64() << 30; @@ -229,7 +257,7 @@ impl Page { } } -impl Page { +impl Page { /// Returns the 4KiB memory page with the specified page table indices. #[inline] #[rustversion::attr(since(1.61), const)] @@ -238,7 +266,10 @@ impl Page { p3_index: PageTableIndex, p2_index: PageTableIndex, p1_index: PageTableIndex, - ) -> Self { + ) -> Self + where + V: [const] VirtValidity, + { let mut addr = 0; addr |= p4_index.into_u64() << 39; addr |= p3_index.into_u64() << 30; @@ -254,7 +285,7 @@ impl Page { } } -impl fmt::Debug for Page { +impl fmt::Debug for Page { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_fmt(format_args!( "Page[{}]({:#x})", @@ -264,7 +295,7 @@ impl fmt::Debug for Page { } } -impl Add for Page { +impl Add for Page { type Output = Self; #[inline] fn add(self, rhs: u64) -> Self::Output { @@ -272,14 +303,14 @@ impl Add for Page { } } -impl AddAssign for Page { +impl AddAssign for Page { #[inline] fn add_assign(&mut self, rhs: u64) { *self = *self + rhs; } } -impl Sub for Page { +impl Sub for Page { type Output = Self; #[inline] fn sub(self, rhs: u64) -> Self::Output { @@ -287,14 +318,14 @@ impl Sub for Page { } } -impl SubAssign for Page { +impl SubAssign for Page { #[inline] fn sub_assign(&mut self, rhs: u64) { *self = *self - rhs; } } -impl Sub for Page { +impl Sub for Page { type Output = u64; #[inline] fn sub(self, rhs: Self) -> Self::Output { @@ -303,7 +334,7 @@ impl Sub for Page { } #[cfg(feature = "step_trait")] -impl Step for Page { +impl Step for Page { fn steps_between(start: &Self, end: &Self) -> (usize, Option) { Self::steps_between_impl(start, end) } @@ -327,14 +358,14 @@ impl Step for Page { /// A range of pages with exclusive upper bound. #[derive(Clone, Copy, PartialEq, Eq, Hash)] #[repr(C)] -pub struct PageRange { +pub struct PageRange { /// The start of the range, inclusive. - pub start: Page, + pub start: Page, /// The end of the range, exclusive. - pub end: Page, + pub end: Page, } -impl PageRange { +impl PageRange { /// Returns whether this range contains no pages. #[inline] pub fn is_empty(&self) -> bool { @@ -358,8 +389,8 @@ impl PageRange { } } -impl Iterator for PageRange { - type Item = Page; +impl Iterator for PageRange { + type Item = Page; #[inline] fn next(&mut self) -> Option { @@ -373,10 +404,10 @@ impl Iterator for PageRange { } } -impl PageRange { +impl PageRange { /// Converts the range of 2MiB pages to a range of 4KiB pages. #[inline] - pub fn as_4kib_page_range(self) -> PageRange { + pub fn as_4kib_page_range(self) -> PageRange { PageRange { start: Page::containing_address(self.start.start_address()), end: Page::containing_address(self.end.start_address()), @@ -384,7 +415,7 @@ impl PageRange { } } -impl fmt::Debug for PageRange { +impl fmt::Debug for PageRange { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("PageRange") .field("start", &self.start) @@ -396,14 +427,14 @@ impl fmt::Debug for PageRange { /// A range of pages with inclusive upper bound. #[derive(Clone, Copy, PartialEq, Eq, Hash)] #[repr(C)] -pub struct PageRangeInclusive { +pub struct PageRangeInclusive { /// The start of the range, inclusive. - pub start: Page, + pub start: Page, /// The end of the range, inclusive. - pub end: Page, + pub end: Page, } -impl PageRangeInclusive { +impl PageRangeInclusive { /// Returns whether this range contains no pages. #[inline] pub fn is_empty(&self) -> bool { @@ -427,8 +458,8 @@ impl PageRangeInclusive { } } -impl Iterator for PageRangeInclusive { - type Item = Page; +impl Iterator for PageRangeInclusive { + type Item = Page; #[inline] fn next(&mut self) -> Option { @@ -451,7 +482,7 @@ impl Iterator for PageRangeInclusive { } } -impl fmt::Debug for PageRangeInclusive { +impl fmt::Debug for PageRangeInclusive { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("PageRangeInclusive") .field("start", &self.start) diff --git a/src/structures/tss.rs b/src/structures/tss.rs index f0174f3f4..1607b8b10 100644 --- a/src/structures/tss.rs +++ b/src/structures/tss.rs @@ -1,6 +1,6 @@ //! Provides a type for the task state segment structure. -use crate::VirtAddr; +use crate::{addr::VirtValidity, VirtAddr}; use core::{ fmt::{self, Display}, mem::size_of, @@ -11,15 +11,15 @@ use core::{ /// but is used for stack switching when an interrupt or exception occurs. #[derive(Debug, Clone, Copy)] #[repr(C, packed(4))] -pub struct TaskStateSegment { +pub struct TaskStateSegment { reserved_1: u32, /// The full 64-bit canonical forms of the stack pointers (RSP) for privilege levels 0-2. /// The stack pointers used when a privilege level change occurs from a lower privilege level to a higher one. - pub privilege_stack_table: [VirtAddr; 3], + pub privilege_stack_table: [VirtAddr; 3], reserved_2: u64, /// The full 64-bit canonical forms of the interrupt stack table (IST) pointers. /// The stack pointers used when an entry in the Interrupt Descriptor Table has an IST value other than 0. - pub interrupt_stack_table: [VirtAddr; 7], + pub interrupt_stack_table: [VirtAddr; 7], reserved_3: u64, reserved_4: u16, /// The 16-bit offset to the I/O permission bit map from the 64-bit TSS base. It must not @@ -27,7 +27,7 @@ pub struct TaskStateSegment { pub iomap_base: u16, } -impl TaskStateSegment { +impl TaskStateSegment { /// Creates a new TSS with zeroed privilege and interrupt stack table and an /// empty I/O-Permission Bitmap. /// @@ -35,11 +35,14 @@ impl TaskStateSegment { /// `size_of::() - 1`, this means that `iomap_base` is /// initialized to `size_of::()`. #[inline] - pub const fn new() -> TaskStateSegment { + pub const fn new() -> Self + where + V: [const] VirtValidity, + { TaskStateSegment { privilege_stack_table: [VirtAddr::zero(); 3], interrupt_stack_table: [VirtAddr::zero(); 7], - iomap_base: size_of::() as u16, + iomap_base: size_of::() as u16, reserved_1: 0, reserved_2: 0, reserved_3: 0, @@ -48,7 +51,7 @@ impl TaskStateSegment { } } -impl Default for TaskStateSegment { +impl Default for TaskStateSegment { #[inline] fn default() -> Self { Self::new()