Skip to content

Commit 28a092f

Browse files
committed
fix(permutation): make bitwise_permute constant-time
The previous `bitwise_permute` used `BitArray::get_unchecked(secret_index)` to read bits at secret-derived positions. Even though all sized bit arrays here fit in a single cache line (≤16 bytes), the secret-indexed load is still a microarchitectural hazard — port pressure and other intra-cache-line side channels can leak the index on some CPUs. Reimplement as: unpack to one byte per bit (MSB-first), permute through the now-constant-time `permute_array` primitive, repack. This routes the bit-level operation through the same scan-and-mask machinery that already protects element-wise permutation. Secret intermediates (`bytes`, `bits`, `permuted`) are wrapped in `Zeroizing<_>` for unwind safety, matching the pattern in elementwise. Drops the direct `bitvec` usage in this file; the `BitArray` machinery is no longer needed. Bench results extended to cover bitwise_permute. Overhead: 28×–199× across N=8..128, in line with the elementwise result.
1 parent 9bdec01 commit 28a092f

2 files changed

Lines changed: 91 additions & 16 deletions

File tree

packages/permutation/examples/bench_permute.rs

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
use std::hint::black_box;
66
use std::time::Instant;
77
use subtle::{ConditionallySelectable, ConstantTimeEq};
8-
use vitaminc_permutation::{PermutationKey, Permute};
8+
use vitaminc_permutation::{BitwisePermute, PermutationKey, Permute};
99
use vitaminc_random::{Generatable, SafeRand, SeedableRng};
1010

1111
fn permute_old<const N: usize>(key_bytes: &[u8; N], input: [u8; N]) -> [u8; N] {
@@ -70,11 +70,81 @@ where
7070
);
7171
}
7272

73+
/// Reference variable-time bit permutation, kept as the benchmark baseline.
///
/// Output bit `i` is a copy of input bit `key_bytes[i]`, where bits are
/// numbered MSB-first within each byte of `bytes`. The input bit is read at
/// an index taken directly from `key_bytes`, which is exactly the
/// key-dependent access the scan-based implementation avoids.
///
/// # Panics
///
/// Panics if `N > ARR * 8`, or if any entry of `key_bytes` addresses a bit
/// outside `bytes` (that is, `key_bytes[i] >= ARR * 8`).
fn bitwise_old<const N: usize, const ARR: usize>(
    key_bytes: &[u8; N],
    bytes: [u8; ARR],
) -> [u8; ARR] {
    // Both sides are const generics, so this check is resolved at
    // monomorphisation time and adds nothing to the timed loop.
    assert!(N <= ARR * 8, "N bits do not fit in ARR bytes");

    let mut out = [0u8; ARR];
    for (i, &k) in key_bytes.iter().enumerate() {
        let k = usize::from(k);
        // Key-indexed read: the byte and shift both depend on `k`.
        let bit = (bytes[k / 8] >> (7 - (k % 8))) & 1;
        out[i / 8] |= bit << (7 - (i % 8));
    }
    out
}
86+
87+
fn bench_bitwise<const N: usize, const ARR: usize, T>(label: &str, iters: u32, sample: T)
88+
where
89+
PermutationKey<N>: BitwisePermute<N, T> + Generatable + Permute<[u8; N]>,
90+
T: Copy,
91+
[u8; N]: AsRef<[u8]>,
92+
{
93+
let mut rng = SafeRand::from_seed([42; 32]);
94+
let key: PermutationKey<N> = PermutationKey::random(&mut rng).unwrap();
95+
let identity: [u8; N] = std::array::from_fn(|i| i as u8);
96+
let key_bytes: [u8; N] = key.permute(identity);
97+
let mut bytes = [0u8; ARR];
98+
for (i, b) in bytes.iter_mut().enumerate() {
99+
*b = if i % 2 == 0 { 0xAA } else { 0x55 };
100+
}
101+
102+
for _ in 0..1000 {
103+
black_box(bitwise_old::<N, ARR>(
104+
black_box(&key_bytes),
105+
black_box(bytes),
106+
));
107+
black_box(key.bitwise_permute(black_box(sample)));
108+
}
109+
110+
let t0 = Instant::now();
111+
for _ in 0..iters {
112+
black_box(bitwise_old::<N, ARR>(
113+
black_box(&key_bytes),
114+
black_box(bytes),
115+
));
116+
}
117+
let old_ns = t0.elapsed().as_nanos() as f64 / iters as f64;
118+
119+
let t0 = Instant::now();
120+
for _ in 0..iters {
121+
black_box(key.bitwise_permute(black_box(sample)));
122+
}
123+
let new_ns = t0.elapsed().as_nanos() as f64 / iters as f64;
124+
125+
println!(
126+
"{label:>12}: old(idx) {old_ns:>7.1} ns | new(scan) {new_ns:>7.1} ns | overhead {:>5.2}x",
127+
new_ns / old_ns
128+
);
129+
}
130+
73131
fn main() {
74132
let iters = 500_000;
133+
println!("== elementwise permute ==");
75134
bench::<8>("N=8", iters);
76135
bench::<16>("N=16", iters);
77136
bench::<32>("N=32", iters);
78137
bench::<64>("N=64", iters);
79138
bench::<128>("N=128", iters);
139+
140+
println!("\n== bitwise permute ==");
141+
bench_bitwise::<8, 1, u8>("N=8", iters, 0xAAu8);
142+
bench_bitwise::<16, 2, u16>("N=16", iters, 0xAAAAu16);
143+
bench_bitwise::<32, 4, u32>("N=32", iters, 0xAAAA_AAAAu32);
144+
bench_bitwise::<64, 8, u64>("N=64", iters, 0xAAAA_AAAA_AAAA_AAAAu64);
145+
bench_bitwise::<128, 16, u128>(
146+
"N=128",
147+
iters,
148+
0xAAAA_AAAA_AAAA_AAAA_AAAA_AAAA_AAAA_AAAAu128,
149+
);
80150
}

packages/permutation/src/bitwise.rs

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
use crate::PermutationKey;
2-
use bitvec::{array::BitArray, order::Msb0};
1+
use crate::{elementwise::permute_array, PermutationKey};
32
use std::num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8};
4-
use zeroize::Zeroize;
3+
use zeroize::{Zeroize, Zeroizing};
54

65
// TODO: Make this a private trait
76
// FIXME: This trait is backwards - self should be T and the argument should be a key
@@ -13,19 +12,25 @@ macro_rules! impl_bitwise_permutable {
1312
($N:literal, $int_type:ty, $array_size:expr) => {
1413
impl BitwisePermute<$N, $int_type> for PermutationKey<$N> {
1514
fn bitwise_permute(&self, mut input: $int_type) -> $int_type {
16-
let bytes = input.to_be_bytes();
17-
let arr: BitArray<[u8; $array_size], Msb0> = BitArray::new(bytes);
18-
let out: BitArray<[u8; $array_size], Msb0> = self.iter().enumerate().fold(
19-
BitArray::new([0; $array_size]),
20-
|mut out, (i, k)| {
21-
out.set(i, *unsafe { arr.get_unchecked(k) });
22-
out
23-
},
24-
);
25-
15+
// Unpack the input into one byte per bit (MSB-first), permute
16+
// the bit-vector through the constant-time `permute_array`
17+
// primitive, then re-pack. The previous implementation used
18+
// `bitvec::get_unchecked(secret_index)` which performs a
19+
// secret-dependent bit load; even though the bit array fits in
20+
// a single cache line, this defends against intra-cache-line
21+
// microarchitectural leaks (port pressure, etc.).
22+
let bytes = Zeroizing::new(input.to_be_bytes());
23+
let mut bits: Zeroizing<[u8; $N]> = Zeroizing::new([0u8; $N]);
24+
for j in 0..$N {
25+
bits[j] = (bytes[j / 8] >> (7 - (j % 8))) & 1;
26+
}
27+
let permuted = Zeroizing::new(permute_array(self, *bits));
28+
let mut out = [0u8; $array_size];
29+
for j in 0..$N {
30+
out[j / 8] |= (permuted[j] & 1) << (7 - (j % 8));
31+
}
2632
input.zeroize();
27-
28-
<$int_type>::from_be_bytes(out.into_inner())
33+
<$int_type>::from_be_bytes(out)
2934
}
3035
}
3136
};

0 commit comments

Comments
 (0)