Skip to content

Commit 46a43fb

Browse files
authored
cmov: impl CmovEq for [u8; N] (#1353)
Similar to the `Cmov` implementation added in #1350, this adds a `CmovEq` implementation for `[u8; N]` which breaks the byte array up into word-sized chunks, interprets them as a native word integer type (either `u32` or `u64`, à la #1350), then calls into the `CmovEq` impl for that type. It uses a similar strategy to the one for slices: we iterate over the chunks calling `CmovEq::cmovne` on each, and if any are non-equal the condition will be moved. This also simplifies the implementation of `Cmov` for `[u8; N]` by vendoring the MSRV 1.88 core functions `[T]::as_chunks(_mut)` as `utils::slice_as_chunks(_mut)`, and using those to write a simpler implementation which also benefits from receiving a word-sized byte array instead of a slice: the array can be passed directly to `Word::from_ne_bytes`, avoiding the previous usage of potentially panicking slice conversions. As noted in the TODO, we can get rid of the vendored functions when we bump MSRV to 1.88.
1 parent f2a5213 commit 46a43fb

5 files changed

Lines changed: 244 additions & 125 deletions

File tree

cmov/src/array.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
use crate::{
2+
Cmov, CmovEq, Condition,
3+
utils::{slice_as_chunks, slice_as_chunks_mut},
4+
};
5+
6+
// Uses 64-bit words on 64-bit targets, 32-bit everywhere else
7+
#[cfg(not(target_pointer_width = "64"))]
8+
type Word = u32;
9+
#[cfg(target_pointer_width = "64")]
10+
type Word = u64;
11+
const WORD_SIZE: usize = size_of::<Word>();
12+
13+
/// Optimized implementation for byte arrays which coalesces them into word-sized chunks first,
14+
/// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
15+
///
16+
/// With compile-time knowledge of `N`, the compiler should also be able to unroll the loops in
17+
/// cases where efficiency would benefit, reducing the implementation to a sequence of word-sized
18+
/// [`Cmov`] ops (and if `N` isn't word-aligned, followed by a series of 1-byte ops).
19+
impl<const N: usize> Cmov for [u8; N] {
20+
#[inline]
21+
fn cmovnz(&mut self, value: &Self, condition: Condition) {
22+
let (self_chunks, self_remainder) = slice_as_chunks_mut::<u8, WORD_SIZE>(self);
23+
let (value_chunks, value_remainder) = slice_as_chunks::<u8, WORD_SIZE>(value);
24+
25+
for (self_chunk, value_chunk) in self_chunks.iter_mut().zip(value_chunks.iter()) {
26+
let mut a = Word::from_ne_bytes(*self_chunk);
27+
let b = Word::from_ne_bytes(*value_chunk);
28+
a.cmovnz(&b, condition);
29+
self_chunk.copy_from_slice(&a.to_ne_bytes());
30+
}
31+
32+
// Process the remainder a byte-at-a-time.
33+
for (a, b) in self_remainder.iter_mut().zip(value_remainder.iter()) {
34+
a.cmovnz(b, condition);
35+
}
36+
}
37+
}
38+
39+
impl<const N: usize> CmovEq for [u8; N] {
40+
fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
41+
let (self_chunks, self_remainder) = slice_as_chunks::<u8, WORD_SIZE>(self);
42+
let (rhs_chunks, rhs_remainder) = slice_as_chunks::<u8, WORD_SIZE>(rhs);
43+
44+
for (self_chunk, rhs_chunk) in self_chunks.iter().zip(rhs_chunks.iter()) {
45+
let a = Word::from_ne_bytes(*self_chunk);
46+
let b = Word::from_ne_bytes(*rhs_chunk);
47+
a.cmovne(&b, input, output);
48+
}
49+
50+
// Process the remainder a byte-at-a-time.
51+
for (a, b) in self_remainder.iter().zip(rhs_remainder.iter()) {
52+
a.cmovne(b, input, output);
53+
}
54+
}
55+
}

cmov/src/lib.rs

Lines changed: 14 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,12 @@
2828
)]
2929

3030
#[macro_use]
31-
mod macros;
31+
mod utils;
3232

3333
#[cfg(not(miri))]
3434
#[cfg(target_arch = "aarch64")]
3535
mod aarch64;
36+
mod array;
3637
#[cfg(any(
3738
not(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64")),
3839
miri
@@ -49,16 +50,12 @@ pub type Condition = u8;
4950
pub trait Cmov {
5051
/// Move if non-zero.
5152
///
52-
/// Uses a `test` instruction to check if the given `condition` value is
53-
/// equal to zero, conditionally moves `value` to `self` when `condition` is
54-
/// not equal to zero.
53+
/// Moves `value` to `self` in constant-time if `condition` is non-zero.
5554
fn cmovnz(&mut self, value: &Self, condition: Condition);
5655

5756
/// Move if zero.
5857
///
59-
/// Uses a `cmp` instruction to check if the given `condition` value is
60-
/// equal to zero, and if so, conditionally moves `value` to `self`
61-
/// when `condition` is equal to zero.
58+
/// Moves `value` to `self` in constant-time if `condition` is equal to zero.
6259
fn cmovz(&mut self, value: &Self, condition: Condition) {
6360
let nz = masknz!(condition: Condition);
6461
self.cmovnz(value, !nz)
@@ -67,21 +64,23 @@ pub trait Cmov {
6764

6865
/// Conditional move with equality comparison
6966
pub trait CmovEq {
70-
/// Move if both inputs are equal.
71-
///
72-
/// Uses a `xor` instruction to compare the two values, and
73-
/// conditionally moves `input` to `output` when they are equal.
74-
fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition);
75-
7667
/// Move if both inputs are not equal.
7768
///
78-
/// Uses a `xor` instruction to compare the two values, and
79-
/// conditionally moves `input` to `output` when they are not equal.
69+
/// Moves `input` to `output` in constant-time if `self` and `rhs` are NOT equal.
8070
fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
8171
let mut tmp = 1u8;
8272
self.cmoveq(rhs, 0u8, &mut tmp);
8373
tmp.cmoveq(&1u8, input, output);
8474
}
75+
76+
/// Move if both inputs are equal.
77+
///
78+
/// Moves `input` to `output` in constant-time if `self` and `rhs` are equal.
79+
fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
80+
let mut tmp = 1u8;
81+
self.cmovne(rhs, 0u8, &mut tmp);
82+
tmp.cmoveq(&1, input, output);
83+
}
8584
}
8685

8786
impl Cmov for u8 {
@@ -200,58 +199,7 @@ macro_rules! impl_cmov_traits_for_signed_ints {
200199

201200
impl_cmov_traits_for_signed_ints!(i8 => u8, i16 => u16, i32 => u32, i64 => u64, i128 => u128);
202201

203-
/// Optimized implementation for byte arrays which coalesces them into word-sized chunks first,
204-
/// then performs [`Cmov`] at the word-level to cut down on the total number of instructions.
205-
///
206-
/// With compile-time knowledge of `N`, the compiler should also be able to unroll the loops in
207-
/// cases where efficiency would benefit, reducing the implementation to a sequence of word-sized
208-
/// [`Cmov`] ops (and if `N` isn't word-aligned, followed by a series of 1-byte ops).
209-
impl<const N: usize> Cmov for [u8; N] {
210-
#[inline]
211-
fn cmovnz(&mut self, value: &Self, condition: Condition) {
212-
// Uses 64-bit words on 64-bit targets, 32-bit everywhere else
213-
#[cfg(not(target_pointer_width = "64"))]
214-
type Chunk = u32;
215-
#[cfg(target_pointer_width = "64")]
216-
type Chunk = u64;
217-
const CHUNK_SIZE: usize = size_of::<Chunk>();
218-
219-
// Load a chunk from a byte slice
220-
// TODO(tarcieri): use `array_chunks` when stable (rust-lang/rust##100450)
221-
#[inline]
222-
fn load_chunk(slice: &[u8]) -> Chunk {
223-
Chunk::from_ne_bytes(slice.try_into().expect("should be the right size"))
224-
}
225-
226-
let mut self_chunks = self.chunks_exact_mut(CHUNK_SIZE);
227-
let mut value_chunks = value.chunks_exact(CHUNK_SIZE);
228-
229-
// Process as much input as we can a `Chunk`-at-a-time.
230-
for (self_chunk, value_chunk) in self_chunks.by_ref().zip(value_chunks.by_ref()) {
231-
let mut a = load_chunk(self_chunk);
232-
let b = load_chunk(value_chunk);
233-
a.cmovnz(&b, condition);
234-
self_chunk.copy_from_slice(&a.to_ne_bytes());
235-
}
236-
237-
// Process the remainder a byte-at-a-time.
238-
for (a, b) in self_chunks
239-
.into_remainder()
240-
.iter_mut()
241-
.zip(value_chunks.remainder().iter())
242-
{
243-
a.cmovnz(b, condition);
244-
}
245-
}
246-
}
247-
248202
impl<T: CmovEq> CmovEq for [T] {
249-
fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) {
250-
let mut tmp = 1u8;
251-
self.cmovne(rhs, 0u8, &mut tmp);
252-
tmp.cmoveq(&1, input, output);
253-
}
254-
255203
fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) {
256204
// Short-circuit the comparison if the slices are of different lengths, and set the output
257205
// condition to the input condition.

cmov/src/macros.rs

Lines changed: 0 additions & 58 deletions
This file was deleted.

cmov/src/utils.rs

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
//! Internal utilities: the `masknz!` macro and vendored slice-chunking helpers.
2+
3+
use core::slice;
4+
5+
/// Expands to an all-ones value of type `$int` when `$value` is non-zero, and to `0` otherwise.
///
/// The intermediate 1-bit result is passed through `core::hint::black_box` to coerce our desired
/// codegen, based on real-world observations of the assembly generated by Rust/LLVM.
///
/// See also:
/// - CVE-2026-23519
/// - RustCrypto/utils#1332
macro_rules! masknz {
    ($value:tt : $int:ident) => {{
        let input: $int = $value;

        // OR-ing a value with its wrapping negation sets the MSB iff the value is non-zero.
        let msb_flagged: $int = input | input.wrapping_neg();

        // Extract the MSB, using `black_box` to obscure that we're computing a 1-bit value.
        let bit: $int = core::hint::black_box(msb_flagged >> ($int::BITS - 1));

        // Negating `1` wraps around to `$int::MAX`; negating `0` leaves `0`.
        bit.wrapping_neg()
    }};
}
25+
26+
/// Rust core `[T]::as_chunks` vendored because of its 1.88 MSRV.
27+
/// TODO(tarcieri): use upstream function when we bump MSRV
28+
#[inline]
29+
#[track_caller]
30+
#[must_use]
31+
#[allow(clippy::integer_division_remainder_used)]
32+
pub(crate) fn slice_as_chunks<T, const N: usize>(slice: &[T]) -> (&[[T; N]], &[T]) {
33+
assert!(N != 0, "chunk size must be non-zero");
34+
let len_rounded_down = slice.len() / N * N;
35+
// SAFETY: The rounded-down value is always the same or smaller than the
36+
// original length, and thus must be in-bounds of the slice.
37+
let (multiple_of_n, remainder) = unsafe { slice.split_at_unchecked(len_rounded_down) };
38+
// SAFETY: We already panicked for zero, and ensured by construction
39+
// that the length of the subslice is a multiple of N.
40+
let array_slice = unsafe { slice_as_chunks_unchecked(multiple_of_n) };
41+
(array_slice, remainder)
42+
}
43+
44+
/// Rust core `[T]::as_chunks_mut` vendored because of its 1.88 MSRV.
45+
/// TODO(tarcieri): use upstream function when we bump MSRV
46+
#[inline]
47+
#[track_caller]
48+
#[must_use]
49+
#[allow(clippy::integer_division_remainder_used)]
50+
pub(crate) fn slice_as_chunks_mut<T, const N: usize>(slice: &mut [T]) -> (&mut [[T; N]], &mut [T]) {
51+
assert!(N != 0, "chunk size must be non-zero");
52+
let len_rounded_down = slice.len() / N * N;
53+
// SAFETY: The rounded-down value is always the same or smaller than the
54+
// original length, and thus must be in-bounds of the slice.
55+
let (multiple_of_n, remainder) = unsafe { slice.split_at_mut_unchecked(len_rounded_down) };
56+
// SAFETY: We already panicked for zero, and ensured by construction
57+
// that the length of the subslice is a multiple of N.
58+
let array_slice = unsafe { slice_as_chunks_unchecked_mut(multiple_of_n) };
59+
(array_slice, remainder)
60+
}
61+
62+
/// Vendored copy of Rust core's `[T]::as_chunks_unchecked`, needed because upstream has a 1.88
/// MSRV.
///
/// Reinterprets `slice` as a slice of `N`-element arrays. Divisibility of the length is only
/// checked by a debug assertion.
///
/// # Safety
///
/// The caller must guarantee that `N` is non-zero and exactly divides `slice.len()`.
///
/// TODO(tarcieri): use upstream function when we bump MSRV
#[inline]
#[must_use]
#[track_caller]
#[allow(clippy::integer_division_remainder_used)]
unsafe fn slice_as_chunks_unchecked<T, const N: usize>(slice: &[T]) -> &[[T; N]] {
    // Debug-only sanity checks of the caller's contract.
    const { debug_assert!(N != 0) };
    debug_assert_eq!(slice.len() % N, 0);

    let chunk_count = slice.len() / N;
    let first_chunk = slice.as_ptr().cast::<[T; N]>();

    // SAFETY: Per the caller's contract the slice holds exactly `chunk_count * N` elements, which
    // we reinterpret as `chunk_count` many chunks of `N` elements each.
    unsafe { slice::from_raw_parts(first_chunk, chunk_count) }
}
78+
79+
/// Vendored copy of Rust core's `[T]::as_chunks_unchecked_mut`, needed because upstream has a
/// 1.88 MSRV.
///
/// Reinterprets `slice` as a mutable slice of `N`-element arrays. Divisibility of the length is
/// only checked by a debug assertion.
///
/// # Safety
///
/// The caller must guarantee that `N` is non-zero and exactly divides `slice.len()`.
///
/// TODO(tarcieri): use upstream function when we bump MSRV
#[inline]
#[must_use]
#[track_caller]
#[allow(clippy::integer_division_remainder_used)]
unsafe fn slice_as_chunks_unchecked_mut<T, const N: usize>(slice: &mut [T]) -> &mut [[T; N]] {
    // Debug-only sanity checks of the caller's contract.
    const { debug_assert!(N != 0) };
    debug_assert_eq!(slice.len() % N, 0);

    let chunk_count = slice.len() / N;
    let first_chunk = slice.as_mut_ptr().cast::<[T; N]>();

    // SAFETY: Per the caller's contract the slice holds exactly `chunk_count * N` elements, which
    // we reinterpret as `chunk_count` many chunks of `N` elements each.
    unsafe { slice::from_raw_parts_mut(first_chunk, chunk_count) }
}
95+
96+
#[cfg(test)]
mod tests {
    /// Number of values to spot check at each end of an integer type's range.
    const TEST_LIMIT: u8 = 128;

    /// Generates one `#[test]` per `name: type` pair, checking that `masknz!` yields `0` for zero
    /// and an all-ones mask for non-zero inputs near both ends of the type's range.
    macro_rules! masknz_test {
        ( $($test_fn:ident : $ty:ident),+ ) => {
            $(
                #[test]
                fn $test_fn() {
                    let span = $ty::from(TEST_LIMIT);

                    // Zero is the one input which must produce an empty mask.
                    assert_eq!(masknz!(0: $ty), 0);

                    // Lower values (`1..=TEST_LIMIT`), followed by the upper values
                    // (the last `TEST_LIMIT + 1` values of the type).
                    let lower = 1..=span;
                    let upper = ($ty::MAX - span)..=$ty::MAX;
                    for i in lower.chain(upper) {
                        assert_eq!(masknz!(i: $ty), $ty::MAX);
                    }
                }
            )+
        }
    }

    // Cover every unsigned width we might use the macro with (we only use u8, u32, and u64)
    masknz_test!(
        masknz_u8: u8,
        masknz_u16: u16,
        masknz_u32: u32,
        masknz_u64: u64,
        masknz_u128: u128
    );
}

0 commit comments

Comments
 (0)