Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 71 additions & 25 deletions src/addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

use core::convert::TryFrom;
use core::fmt;
use core::hash::Hash;
#[cfg(feature = "step_trait")]
use core::iter::Step;
use core::marker::PhantomData;
use core::ops::{Add, AddAssign, Sub, SubAssign};
#[cfg(feature = "memory_encryption")]
use core::sync::atomic::Ordering;
Expand All @@ -30,7 +32,33 @@ const ADDRESS_SPACE_SIZE: u64 = 0x1_0000_0000_0000;
/// are called “canonical”. This type guarantees that it always represents a canonical address.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct VirtAddr(u64);
pub struct VirtAddr<T: VirtValidity>(u64, PhantomData<T>);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we set the default for T to Virt48?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's on my TODO list.


pub const trait VirtValidity: Clone + Copy + Eq + Ord + Hash {
fn truncate(addr: u64) -> u64;
}

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Virt48 {}

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Virt57 {}

impl const VirtValidity for Virt48 {
fn truncate(addr: u64) -> u64 {
// By doing the right shift as a signed operation (on a i64), it will
// sign extend the value, repeating the leftmost bit.
((addr << 16) as i64 >> 16) as u64
}
}

impl const VirtValidity for Virt57 {
fn truncate(addr: u64) -> u64 {
// By doing the right shift as a signed operation (on a i64), it will
// sign extend the value, repeating the leftmost bit.
((addr << 7) as i64 >> 7) as u64
}
}

/// A 64-bit physical memory address.
///
Expand Down Expand Up @@ -63,7 +91,7 @@ impl core::fmt::Debug for VirtAddrNotValid {
}
}

impl VirtAddr {
impl<V: VirtValidity> VirtAddr<V> {
/// Creates a new canonical virtual address.
///
/// The provided address should already be canonical. If you want to check
Expand All @@ -74,7 +102,10 @@ impl VirtAddr {
/// This function panics if the bits in the range 48 to 64 are invalid
/// (i.e. are not a proper sign extension of bit 47).
#[inline]
pub const fn new(addr: u64) -> VirtAddr {
pub const fn new(addr: u64) -> Self
where
V: [const] VirtValidity,
{
// TODO: Replace with .ok().expect(msg) when that works on stable.
match Self::try_new(addr) {
Ok(v) => v,
Expand All @@ -89,7 +120,10 @@ impl VirtAddr {
/// if bits 48 to 64 are a correct sign
/// extension (i.e. copies of bit 47).
#[inline]
pub const fn try_new(addr: u64) -> Result<VirtAddr, VirtAddrNotValid> {
pub const fn try_new(addr: u64) -> Result<Self, VirtAddrNotValid>
where
V: [const] VirtValidity,
{
let v = Self::new_truncate(addr);
if v.0 == addr {
Ok(v)
Expand All @@ -104,10 +138,13 @@ impl VirtAddr {
/// canonical, overwriting bits 48 to 64. If you want to check whether an
/// address is canonical, use [`new`](Self::new) or [`try_new`](Self::try_new).
#[inline]
pub const fn new_truncate(addr: u64) -> VirtAddr {
pub const fn new_truncate(addr: u64) -> Self
where
V: [const] VirtValidity,
{
// By doing the right shift as a signed operation (on a i64), it will
// sign extend the value, repeating the leftmost bit.
VirtAddr(((addr << 16) as i64 >> 16) as u64)
Self(V::truncate(addr), PhantomData)
}

/// Creates a new virtual address, without any checks.
Expand All @@ -116,14 +153,17 @@ impl VirtAddr {
///
/// You must make sure bits 48..64 are equal to bit 47. This is not checked.
#[inline]
pub const unsafe fn new_unsafe(addr: u64) -> VirtAddr {
VirtAddr(addr)
pub const unsafe fn new_unsafe(addr: u64) -> Self {
Self(addr, PhantomData)
}

/// Creates a virtual address that points to `0`.
#[inline]
pub const fn zero() -> VirtAddr {
VirtAddr(0)
pub const fn zero() -> Self
where
V: [const] VirtValidity,
{
VirtAddr::new(0)
}

/// Converts the address to an `u64`.
Expand Down Expand Up @@ -190,7 +230,10 @@ impl VirtAddr {
///
/// See the `align_down` function for more information.
#[inline]
pub(crate) const fn align_down_u64(self, align: u64) -> Self {
pub(crate) const fn align_down_u64(self, align: u64) -> Self
where
V: [const] VirtValidity,
{
VirtAddr::new_truncate(align_down(self.0, align))
}

Expand All @@ -205,7 +248,10 @@ impl VirtAddr {

/// Checks whether the virtual address has the demanded alignment.
#[inline]
pub(crate) const fn is_aligned_u64(self, align: u64) -> bool {
pub(crate) const fn is_aligned_u64(self, align: u64) -> bool
where
V: [const] VirtValidity,
{
self.align_down_u64(align).as_u64() == self.as_u64()
}

Expand Down Expand Up @@ -325,50 +371,50 @@ impl VirtAddr {
}
}

impl fmt::Debug for VirtAddr {
impl<V: VirtValidity> fmt::Debug for VirtAddr<V> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_tuple("VirtAddr")
.field(&format_args!("{:#x}", self.0))
.finish()
}
}

impl fmt::Binary for VirtAddr {
impl<V: VirtValidity> fmt::Binary for VirtAddr<V> {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Binary::fmt(&self.0, f)
}
}

impl fmt::LowerHex for VirtAddr {
impl<V: VirtValidity> fmt::LowerHex for VirtAddr<V> {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::LowerHex::fmt(&self.0, f)
}
}

impl fmt::Octal for VirtAddr {
impl<V: VirtValidity> fmt::Octal for VirtAddr<V> {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Octal::fmt(&self.0, f)
}
}

impl fmt::UpperHex for VirtAddr {
impl<V: VirtValidity> fmt::UpperHex for VirtAddr<V> {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::UpperHex::fmt(&self.0, f)
}
}

impl fmt::Pointer for VirtAddr {
impl<V: VirtValidity> fmt::Pointer for VirtAddr<V> {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Pointer::fmt(&(self.0 as *const ()), f)
}
}

impl Add<u64> for VirtAddr {
impl<V: VirtValidity> Add<u64> for VirtAddr<V> {
type Output = Self;

#[cfg_attr(not(feature = "step_trait"), allow(rustdoc::broken_intra_doc_links))]
Expand All @@ -393,7 +439,7 @@ impl Add<u64> for VirtAddr {
}
}

impl AddAssign<u64> for VirtAddr {
impl<V: VirtValidity> AddAssign<u64> for VirtAddr<V> {
#[cfg_attr(not(feature = "step_trait"), allow(rustdoc::broken_intra_doc_links))]
/// Add an offset to a virtual address.
///
Expand All @@ -411,7 +457,7 @@ impl AddAssign<u64> for VirtAddr {
}
}

impl Sub<u64> for VirtAddr {
impl<V: VirtValidity> Sub<u64> for VirtAddr<V> {
type Output = Self;

#[cfg_attr(not(feature = "step_trait"), allow(rustdoc::broken_intra_doc_links))]
Expand All @@ -436,7 +482,7 @@ impl Sub<u64> for VirtAddr {
}
}

impl SubAssign<u64> for VirtAddr {
impl<V: VirtValidity> SubAssign<u64> for VirtAddr<V> {
#[cfg_attr(not(feature = "step_trait"), allow(rustdoc::broken_intra_doc_links))]
/// Subtract an offset from a virtual address.
///
Expand All @@ -454,7 +500,7 @@ impl SubAssign<u64> for VirtAddr {
}
}

impl Sub<VirtAddr> for VirtAddr {
impl<V: VirtValidity> Sub<VirtAddr<V>> for VirtAddr<V> {
type Output = u64;

/// Returns the difference between two addresses.
Expand All @@ -463,15 +509,15 @@ impl Sub<VirtAddr> for VirtAddr {
///
/// This function will panic on overflow.
#[inline]
fn sub(self, rhs: VirtAddr) -> Self::Output {
fn sub(self, rhs: VirtAddr<V>) -> Self::Output {
self.as_u64()
.checked_sub(rhs.as_u64())
.expect("attempt to subtract with overflow")
}
}

#[cfg(feature = "step_trait")]
impl Step for VirtAddr {
impl<V: VirtValidity> Step for VirtAddr<V> {
#[inline]
fn steps_between(start: &Self, end: &Self) -> (usize, Option<usize>) {
Self::steps_between_impl(start, end)
Expand Down
4 changes: 3 additions & 1 deletion src/instructions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pub mod tlb;

use core::arch::asm;

use crate::addr::VirtValidity;

/// Halts the CPU until the next interrupt arrives.
#[inline]
pub fn hlt() {
Expand Down Expand Up @@ -48,7 +50,7 @@ pub fn bochs_breakpoint() {
/// Gets the current instruction pointer. Note that this is only approximate as it requires a few
/// instructions to execute.
#[inline(always)]
pub fn read_rip() -> crate::VirtAddr {
pub fn read_rip<V: VirtValidity>() -> crate::VirtAddr<V> {
let rip: u64;
unsafe {
asm!("lea {}, [rip]", out(reg) rip, options(nostack, nomem, preserves_flags));
Expand Down
5 changes: 3 additions & 2 deletions src/instructions/segmentation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

pub use crate::registers::segmentation::{Segment, Segment64, CS, DS, ES, FS, GS, SS};
use crate::{
addr::VirtValidity,
registers::model_specific::{FsBase, GsBase, Msr},
structures::gdt::SegmentSelector,
VirtAddr,
Expand Down Expand Up @@ -41,7 +42,7 @@ macro_rules! segment64_impl {
impl Segment64 for $type {
const BASE: Msr = <$base>::MSR;
#[inline]
fn read_base() -> VirtAddr {
fn read_base<V: VirtValidity>() -> VirtAddr<V> {
unsafe {
let val: u64;
asm!(concat!("rd", $name, "base {}"), out(reg) val, options(nomem, nostack, preserves_flags));
Expand All @@ -50,7 +51,7 @@ macro_rules! segment64_impl {
}

#[inline]
unsafe fn write_base(base: VirtAddr) {
unsafe fn write_base<V: VirtValidity>(base: VirtAddr<V>) {
unsafe{
asm!(concat!("wr", $name, "base {}"), in(reg) base.as_u64(), options(nostack, preserves_flags));
}
Expand Down
13 changes: 7 additions & 6 deletions src/instructions/tables.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Functions to load GDT, IDT, and TSS structures.

use crate::addr::VirtValidity;
use crate::structures::gdt::SegmentSelector;
use crate::VirtAddr;
use core::arch::asm;
Expand All @@ -18,7 +19,7 @@ pub use crate::structures::DescriptorTablePointer;
/// `DescriptorTablePointer` points to a valid GDT and that loading this
/// GDT is safe.
#[inline]
pub unsafe fn lgdt(gdt: &DescriptorTablePointer) {
pub unsafe fn lgdt<V: VirtValidity>(gdt: &DescriptorTablePointer<V>) {
unsafe {
asm!("lgdt [{}]", in(reg) gdt, options(readonly, nostack, preserves_flags));
}
Expand All @@ -36,16 +37,16 @@ pub unsafe fn lgdt(gdt: &DescriptorTablePointer) {
/// `DescriptorTablePointer` points to a valid IDT and that loading this
/// IDT is safe.
#[inline]
pub unsafe fn lidt(idt: &DescriptorTablePointer) {
pub unsafe fn lidt<V: VirtValidity>(idt: &DescriptorTablePointer<V>) {
unsafe {
asm!("lidt [{}]", in(reg) idt, options(readonly, nostack, preserves_flags));
}
}

/// Get the address of the current GDT.
#[inline]
pub fn sgdt() -> DescriptorTablePointer {
let mut gdt: DescriptorTablePointer = DescriptorTablePointer {
pub fn sgdt<V: VirtValidity>() -> DescriptorTablePointer<V> {
let mut gdt = DescriptorTablePointer {
limit: 0,
base: VirtAddr::new(0),
};
Expand All @@ -57,8 +58,8 @@ pub fn sgdt() -> DescriptorTablePointer {

/// Get the address of the current IDT.
#[inline]
pub fn sidt() -> DescriptorTablePointer {
let mut idt: DescriptorTablePointer = DescriptorTablePointer {
pub fn sidt<V: VirtValidity>() -> DescriptorTablePointer<V> {
let mut idt = DescriptorTablePointer {
limit: 0,
base: VirtAddr::new(0),
};
Expand Down
Loading
Loading