Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions ctutils/src/byteutils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
//! Utilities for leveraging the optimized implementations of `Cmov`/`CmovEq` for types whose size
//! is 1-byte.

use crate::{Choice, traits::no_uninit::NoUninit};
use cmov::CmovEq;
use core::slice;

/// Perform constant-time equality comparison on slices of 1-byte sized types using the optimized
/// implementation of `CmovEq` for byte slices.
pub(crate) fn ct_eq<T: NoUninit>(a: &[T], b: &[T]) -> Choice {
assert_eq!(
size_of::<T>(),
1,
"this function is intended for 1-byte sized types"
);

// SAFETY:
// - We asserted above that `size_of::<T>() == size_of::<u8>() == 1`.
// - The `NoUninit` bound ensures the type does not contain uninitialized memory.
// - We don't need to worry about alignment because all types are 1-byte.
// - 1-byte is too small to contain a pointer/reference.
// - We source the slice length directly from the other valid slice.
#[allow(unsafe_code)]
let (a, b) = unsafe {
(
slice::from_raw_parts(a.as_ptr() as *const u8, a.len()),
slice::from_raw_parts(b.as_ptr() as *const u8, b.len()),
)
};

let mut ret = Choice::FALSE;
a.cmoveq(b, 1, &mut ret.0);
ret
}

#[cfg(test)]
mod tests {
use core::num::{NonZeroI8, NonZeroU8};

macro_rules! ct_eq_test {
($name:ident, $a:expr, $b:expr) => {
#[test]
fn $name() {
let x = $a;
let y = $b;

let a = [x, x, x];
let b = [x, x, y];
let c = [x, y, y];
let d = [y, y, y];

assert!(super::ct_eq(&a, &a).to_bool());
assert!(super::ct_eq(&b, &b).to_bool());
assert!(super::ct_eq(&c, &c).to_bool());
assert!(super::ct_eq(&d, &d).to_bool());

for rhs in &[b, c, d] {
assert!(!super::ct_eq(&a, rhs).to_bool());
}
}
};
}

ct_eq_test!(i8_ct_eq, 1i8, 2i8);
ct_eq_test!(u8_ct_eq, 1u8, 2u8);
ct_eq_test!(
non_zero_i8_ct_eq,
NonZeroI8::new(1i8).unwrap(),
NonZeroI8::new(2i8).unwrap()
);
ct_eq_test!(
non_zero_u8_ct_eq,
NonZeroU8::new(1u8).unwrap(),
NonZeroU8::new(2u8).unwrap()
);
}
1 change: 1 addition & 0 deletions ctutils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
extern crate alloc;

mod bytes;
mod byteutils;
mod choice;
mod ct_option;
mod traits;
Expand Down
1 change: 1 addition & 0 deletions ctutils/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ pub(crate) mod ct_lookup;
pub(crate) mod ct_lt;
pub(crate) mod ct_neg;
pub(crate) mod ct_select;
pub(crate) mod no_uninit;
58 changes: 32 additions & 26 deletions ctutils/src/traits/ct_eq.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::Choice;
use crate::{Choice, byteutils, traits::no_uninit::NoUninit};
use cmov::CmovEq;
use core::{
cmp,
Expand Down Expand Up @@ -109,37 +109,43 @@ impl CtEq for cmp::Ordering {
}
}

impl<T: CtEq> CtEq for [T] {
impl<T> CtEq for [T]
where
T: CtEq + NoUninit,
{
#[inline]
fn ct_eq(&self, other: &[T]) -> Choice {
const {
assert!(
size_of::<T>() != 1,
"use `BytesCtEq::bytes_ct_eq` when working with byte-sized values"
);
/// Iterate over every element in `a_slice` and `b_slice` comparing elements.
fn ct_eq_inner<T: CtEq>(a_slice: &[T], b_slice: &[T]) -> Choice {
let mut ret = a_slice.len().ct_eq(&b_slice.len());
for (a, b) in a_slice.iter().zip(b_slice.iter()) {
ret &= a.ct_eq(b);
}
ret
}

let mut ret = self.len().ct_eq(&other.len());
for (a, b) in self.iter().zip(other.iter()) {
ret &= a.ct_eq(b);
}
ret
}
if const { size_of::<T>() == 1 } {
let ret = byteutils::ct_eq(self, other);

#[inline]
fn ct_ne(&self, other: &[T]) -> Choice {
const {
assert!(
size_of::<T>() != 1,
"use `BytesCtEq::bytes_ct_ne` when working with byte-sized values"
// Double-check the result we get from `byteutils::ct_eq` is the same one we would've
// gotten had we invoked `CtEq` on each element (i.e. the slow path)
debug_assert_eq!(
ret.to_bool(),
ct_eq_inner(self, other).to_bool(),
"mismatch between fast and slow ct_eq implementations"
);
}

!self.ct_eq(other)
ret
} else {
ct_eq_inner(self, other)
}
}
}

impl<T: CtEq, const N: usize> CtEq for [T; N] {
impl<T, const N: usize> CtEq for [T; N]
where
T: CtEq + NoUninit,
{
#[inline]
fn ct_eq(&self, other: &[T; N]) -> Choice {
self.as_slice().ct_eq(other.as_slice())
Expand All @@ -161,7 +167,7 @@ where
#[cfg(feature = "alloc")]
impl<T> CtEq for Box<[T]>
where
T: CtEq,
T: CtEq + NoUninit,
{
#[inline]
#[track_caller]
Expand All @@ -173,7 +179,7 @@ where
#[cfg(feature = "alloc")]
impl<T> CtEq<[T]> for Box<[T]>
where
T: CtEq,
T: CtEq + NoUninit,
{
#[inline]
#[track_caller]
Expand All @@ -185,7 +191,7 @@ where
#[cfg(feature = "alloc")]
impl<T> CtEq for Vec<T>
where
T: CtEq,
T: CtEq + NoUninit,
{
#[inline]
#[track_caller]
Expand All @@ -197,7 +203,7 @@ where
#[cfg(feature = "alloc")]
impl<T> CtEq<[T]> for Vec<T>
where
T: CtEq,
T: CtEq + NoUninit,
{
#[inline]
#[track_caller]
Expand Down
65 changes: 65 additions & 0 deletions ctutils/src/traits/no_uninit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#![allow(
clippy::missing_safety_doc,
clippy::undocumented_unsafe_blocks,
unsafe_code
)]

use crate::{Choice, CtOption};
use core::{
cmp,
num::{
NonZeroI8, NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI128, NonZeroU8, NonZeroU16,
NonZeroU32, NonZeroU64, NonZeroU128,
},
};

#[cfg(feature = "alloc")]
use alloc::{boxed::Box, vec::Vec};

/// Marker trait for types which do not contain uninitialized memory.
pub unsafe trait NoUninit {}

// Impl `NoUninit` for the given type
macro_rules! impl_no_uninit {
( $($ty:ty),+ ) => {
$(
unsafe impl NoUninit for $ty {}
)+
};
}

impl_no_uninit!(
i8,
i16,
i32,
i64,
i128,
isize,
u8,
u16,
u32,
u64,
u128,
usize,
NonZeroI8,
NonZeroI16,
NonZeroI32,
NonZeroI64,
NonZeroI128,
NonZeroU8,
NonZeroU16,
NonZeroU32,
NonZeroU64,
NonZeroU128
);

unsafe impl NoUninit for Choice {}
unsafe impl NoUninit for cmp::Ordering {}
unsafe impl<T: NoUninit> NoUninit for CtOption<T> {}
unsafe impl<T: NoUninit> NoUninit for [T] {}
unsafe impl<T: NoUninit, const N: usize> NoUninit for [T; N] {}

#[cfg(feature = "alloc")]
unsafe impl<T: NoUninit> NoUninit for Box<T> {}

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For Box it should be fine since I think it's guaranteed that it's just a pointer, but you don't even need the T: NoUninit bound here :). the implicit T: Sized matters though since fat pointer layout isn't guaranteed

#[cfg(feature = "alloc")]
unsafe impl<T: NoUninit> NoUninit for Vec<T> {}

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This impl is not sound, vec does not have layout guarantees and there could be padding under -Zrandomize-layout for example