Skip to content

Commit 62bb328

Browse files
committed
[WIP] ctutils: automatically use optimized 1-byte CtEq
As an alternative to the `BytesCtEq` approach to using the optimized impl of `CmovEq` for `[u8]`, this implementation automatically uses it for all 1-byte types when the `CtEq` impl for `[T]` is invoked (or anything that calls it, like the impls on `[T; N]`, `Box<[T]>`, and `Vec<T>`. To ensure we're not casting from a slice of a type containing uninitialized memory to `[u8]`, this bounds all such `CtEq` impls on a newly introduced `unsafe trait NoUninit`, which is currently not exposed in the public API except through the bounds. The trait has been impl'd for all of the types we impl other traits for in this crate.
1 parent 2a8b0b1 commit 62bb328

5 files changed

Lines changed: 171 additions & 26 deletions

File tree

ctutils/src/byteutils.rs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
//! Utilities for leveraging the optimized implementations of `Cmov`/`CmovEq` for types whose size
2+
//! is 1-byte.
3+
4+
use crate::{Choice, traits::no_uninit::NoUninit};
5+
use cmov::CmovEq;
6+
use core::slice;
7+
8+
/// Perform constant-time equality comparison on slices of 1-byte sized types using the optimized
9+
/// implementation of `CmovEq` for byte slices.
10+
pub(crate) fn ct_eq<T: NoUninit>(a: &[T], b: &[T]) -> Choice {
11+
assert_eq!(
12+
size_of::<T>(),
13+
1,
14+
"this function is intended for 1-byte sized types"
15+
);
16+
17+
// SAFETY:
18+
// - We asserted above that `size_of::<T>() == size_of::<u8>() == 1`.
19+
// - The `NoUninit` bound ensures the type does not contain uninitialized memory.
20+
// - We don't need to worry about alignment because all types are 1-byte.
21+
// - 1-byte is too small to contain a pointer/reference.
22+
// - We source the slice length directly from the other valid slice.
23+
#[allow(unsafe_code)]
24+
let (a, b) = unsafe {
25+
(
26+
slice::from_raw_parts(a.as_ptr() as *const u8, a.len()),
27+
slice::from_raw_parts(b.as_ptr() as *const u8, b.len()),
28+
)
29+
};
30+
31+
let mut ret = Choice::FALSE;
32+
a.cmoveq(b, 1, &mut ret.0);
33+
ret
34+
}
35+
36+
#[cfg(test)]
37+
mod tests {
38+
use core::num::{NonZeroI8, NonZeroU8};
39+
40+
macro_rules! ct_eq_test {
41+
($name:ident, $a:expr, $b:expr) => {
42+
#[test]
43+
fn $name() {
44+
let x = $a;
45+
let y = $b;
46+
47+
let a = [x, x, x];
48+
let b = [x, x, y];
49+
let c = [x, y, y];
50+
let d = [y, y, y];
51+
52+
assert!(super::ct_eq(&a, &a).to_bool());
53+
assert!(super::ct_eq(&b, &b).to_bool());
54+
assert!(super::ct_eq(&c, &c).to_bool());
55+
assert!(super::ct_eq(&d, &d).to_bool());
56+
57+
for rhs in &[b, c, d] {
58+
assert!(!super::ct_eq(&a, rhs).to_bool());
59+
}
60+
}
61+
};
62+
}
63+
64+
ct_eq_test!(i8_ct_eq, 1i8, 2i8);
65+
ct_eq_test!(u8_ct_eq, 1u8, 2u8);
66+
ct_eq_test!(
67+
non_zero_i8_ct_eq,
68+
NonZeroI8::new(1i8).unwrap(),
69+
NonZeroI8::new(2i8).unwrap()
70+
);
71+
ct_eq_test!(
72+
non_zero_u8_ct_eq,
73+
NonZeroU8::new(1u8).unwrap(),
74+
NonZeroU8::new(2u8).unwrap()
75+
);
76+
}

ctutils/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
extern crate alloc;
9393

9494
mod bytes;
95+
mod byteutils;
9596
mod choice;
9697
mod ct_option;
9798
mod traits;

ctutils/src/traits.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ pub(crate) mod ct_lookup;
1111
pub(crate) mod ct_lt;
1212
pub(crate) mod ct_neg;
1313
pub(crate) mod ct_select;
14+
pub(crate) mod no_uninit;

ctutils/src/traits/ct_eq.rs

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::Choice;
1+
use crate::{Choice, byteutils, traits::no_uninit::NoUninit};
22
use cmov::CmovEq;
33
use core::{
44
cmp,
@@ -109,37 +109,43 @@ impl CtEq for cmp::Ordering {
109109
}
110110
}
111111

112-
impl<T: CtEq> CtEq for [T] {
112+
impl<T> CtEq for [T]
113+
where
114+
T: CtEq + NoUninit,
115+
{
113116
#[inline]
114117
fn ct_eq(&self, other: &[T]) -> Choice {
115-
const {
116-
assert!(
117-
size_of::<T>() != 1,
118-
"use `BytesCtEq::bytes_ct_eq` when working with byte-sized values"
119-
);
118+
/// Iterate over every element in `a_slice` and `b_slice` comparing elements.
119+
fn ct_eq_inner<T: CtEq>(a_slice: &[T], b_slice: &[T]) -> Choice {
120+
let mut ret = a_slice.len().ct_eq(&b_slice.len());
121+
for (a, b) in a_slice.iter().zip(b_slice.iter()) {
122+
ret &= a.ct_eq(b);
123+
}
124+
ret
120125
}
121126

122-
let mut ret = self.len().ct_eq(&other.len());
123-
for (a, b) in self.iter().zip(other.iter()) {
124-
ret &= a.ct_eq(b);
125-
}
126-
ret
127-
}
127+
if const { size_of::<T>() == 1 } {
128+
let ret = byteutils::ct_eq(self, other);
128129

129-
#[inline]
130-
fn ct_ne(&self, other: &[T]) -> Choice {
131-
const {
132-
assert!(
133-
size_of::<T>() != 1,
134-
"use `BytesCtEq::bytes_ct_ne` when working with byte-sized values"
130+
// Double-check the result we get from `byteutils::ct_eq` is the same one we would've
131+
// gotten had we invoked `CtEq` on each element (i.e. the slow path)
132+
debug_assert_eq!(
133+
ret.to_bool(),
134+
ct_eq_inner(self, other).to_bool(),
135+
"mismatch between fast and slow ct_eq implementations"
135136
);
136-
}
137137

138-
!self.ct_eq(other)
138+
ret
139+
} else {
140+
ct_eq_inner(self, other)
141+
}
139142
}
140143
}
141144

142-
impl<T: CtEq, const N: usize> CtEq for [T; N] {
145+
impl<T, const N: usize> CtEq for [T; N]
146+
where
147+
T: CtEq + NoUninit,
148+
{
143149
#[inline]
144150
fn ct_eq(&self, other: &[T; N]) -> Choice {
145151
self.as_slice().ct_eq(other.as_slice())
@@ -161,7 +167,7 @@ where
161167
#[cfg(feature = "alloc")]
162168
impl<T> CtEq for Box<[T]>
163169
where
164-
T: CtEq,
170+
T: CtEq + NoUninit,
165171
{
166172
#[inline]
167173
#[track_caller]
@@ -173,7 +179,7 @@ where
173179
#[cfg(feature = "alloc")]
174180
impl<T> CtEq<[T]> for Box<[T]>
175181
where
176-
T: CtEq,
182+
T: CtEq + NoUninit,
177183
{
178184
#[inline]
179185
#[track_caller]
@@ -185,7 +191,7 @@ where
185191
#[cfg(feature = "alloc")]
186192
impl<T> CtEq for Vec<T>
187193
where
188-
T: CtEq,
194+
T: CtEq + NoUninit,
189195
{
190196
#[inline]
191197
#[track_caller]
@@ -197,7 +203,7 @@ where
197203
#[cfg(feature = "alloc")]
198204
impl<T> CtEq<[T]> for Vec<T>
199205
where
200-
T: CtEq,
206+
T: CtEq + NoUninit,
201207
{
202208
#[inline]
203209
#[track_caller]

ctutils/src/traits/no_uninit.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#![allow(unsafe_code)]
2+
3+
use crate::{Choice, CtOption};
4+
use core::{
5+
cmp,
6+
num::{
7+
NonZeroI8, NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI128, NonZeroU8, NonZeroU16,
8+
NonZeroU32, NonZeroU64, NonZeroU128,
9+
},
10+
};
11+
12+
#[cfg(feature = "alloc")]
13+
use alloc::{boxed::Box, vec::Vec};
14+
15+
/// Marker trait for types which do not contain uninitialized memory.
16+
pub unsafe trait NoUninit {}
17+
18+
// Impl `NoUninit` for the given type
19+
macro_rules! impl_no_uninit {
20+
( $($ty:ty),+ ) => {
21+
$(
22+
unsafe impl NoUninit for $ty {}
23+
)+
24+
};
25+
}
26+
27+
impl_no_uninit!(
28+
i8,
29+
i16,
30+
i32,
31+
i64,
32+
i128,
33+
isize,
34+
u8,
35+
u16,
36+
u32,
37+
u64,
38+
u128,
39+
usize,
40+
NonZeroI8,
41+
NonZeroI16,
42+
NonZeroI32,
43+
NonZeroI64,
44+
NonZeroI128,
45+
NonZeroU8,
46+
NonZeroU16,
47+
NonZeroU32,
48+
NonZeroU64,
49+
NonZeroU128
50+
);
51+
52+
unsafe impl NoUninit for Choice {}
53+
unsafe impl NoUninit for cmp::Ordering {}
54+
unsafe impl<T: NoUninit> NoUninit for CtOption<T> {}
55+
unsafe impl<T: NoUninit> NoUninit for [T] {}
56+
unsafe impl<T: NoUninit, const N: usize> NoUninit for [T; N] {}
57+
58+
#[cfg(feature = "alloc")]
59+
unsafe impl<T: NoUninit> NoUninit for Box<T> {}
60+
#[cfg(feature = "alloc")]
61+
unsafe impl<T: NoUninit> NoUninit for Vec<T> {}

0 commit comments

Comments
 (0)