Skip to content

Commit 7026afd

Browse files
committed
feat: add auxialiary, compress and encode functions
1 parent 4891d87 commit 7026afd

3 files changed

Lines changed: 90 additions & 44 deletions

File tree

src/kem/kyber/auxiliary.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1+
//! Contains auxiliary cryptographic functions for Kyber KEM.
2+
//! - [`prf`] - Pseudorandom function
3+
//! - [`h`], [`g`] - Hash function
4+
//! - [`Xof`] - Extendable output function
15
use sha3::{
26
digest::{ExtendableOutput, Update, XofReader},
37
Digest, Shake128, Shake256,
48
};
59

6-
pub fn prf<const eta: usize>(s: [u8; 32], b: u8) -> [u8; 64 * eta] {
7-
// concat s and b
10+
pub fn prf<const ETA: usize>(s: &[u8], b: u8) -> [u8; 64 * ETA] {
11+
assert!(s.len() == 32);
12+
813
let mut hasher = Shake256::default();
914
hasher.update(&s);
1015
hasher.update(&[b]);
11-
let mut res = [0u8; 64 * eta];
16+
let mut res = [0u8; 64 * ETA];
1217
XofReader::read(&mut hasher.finalize_xof(), &mut res);
1318
res
1419
}

src/kem/kyber/compress.rs

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use super::MlKemField;
1+
use super::{MlKemField, PolyVec};
22
use crate::{
33
algebra::Finite,
44
polynomial::{Monomial, Polynomial},
@@ -7,57 +7,69 @@ use crate::{
77
/// Compresses a number x to a number in the range [0, 2^d) using the formula round((2^d / q) * x)
88
/// mod 2^d.
99
/// round(a / b) = floor((a + b/2) / b)
10-
pub fn compress_fieldelement<const D: usize>(x: &MlKemField) -> MlKemField {
10+
pub fn compress_fieldelement<const d: usize>(x: &MlKemField) -> MlKemField {
1111
// TODO: Implement using barrett reduction
1212
let q_half = (MlKemField::ORDER + 1) >> 1;
13-
MlKemField::new((((x.value << D) + q_half) / MlKemField::ORDER) % (1 << D))
13+
MlKemField::new((((x.value << d) + q_half) / MlKemField::ORDER) % (1 << d))
1414
}
1515

1616
/// Decompresses a number y to a number in the range [0, q) using the formula round((q / 2^d)) * y.
17-
pub fn decompress_fieldelement<const D: usize>(y: &MlKemField) -> MlKemField {
18-
let d_pow_half = 1 << (D - 1);
17+
pub fn decompress_fieldelement<const d: usize>(y: &MlKemField) -> MlKemField {
18+
let d_pow_half = 1 << (d - 1);
1919
let quotient = MlKemField::ORDER * y.value + d_pow_half;
20-
MlKemField::new(quotient >> D)
20+
MlKemField::new(quotient >> d)
2121
}
2222

23-
pub fn poly_compress<const D: usize>(
23+
pub fn poly_compress<const D: usize, const d: usize>(
2424
poly: &Polynomial<Monomial, MlKemField, D>,
25-
) -> [MlKemField; D] {
25+
) -> Polynomial<Monomial, MlKemField, D> {
2626
// TODO: remove unwrap
27-
poly
27+
let coeffs = poly
2828
.coefficients
2929
.iter()
30-
.map(compress_fieldelement::<8>)
30+
.map(compress_fieldelement::<d>)
3131
.collect::<Vec<MlKemField>>()
3232
.try_into()
33-
.unwrap()
33+
.unwrap();
34+
35+
Polynomial::<Monomial, MlKemField, D>::new(coeffs)
3436
}
3537

36-
pub fn poly_decompress<const D: usize>(
38+
pub fn poly_decompress<const D: usize, const d: usize>(
3739
poly: &[MlKemField; D],
3840
) -> Polynomial<Monomial, MlKemField, D> {
3941
let mut coefficients = [MlKemField::default(); D];
4042
for (i, x) in poly.iter().enumerate() {
41-
coefficients[i] = decompress_fieldelement::<8>(x);
43+
coefficients[i] = decompress_fieldelement::<d>(x);
4244
}
4345
Polynomial::<Monomial, MlKemField, D>::new(coefficients)
4446
}
4547

46-
// pub fn polyvec_compress<const D: usize, const K: usize>(
47-
// poly_vec: &PolyVec<Monomial, D, K>,
48-
// ) -> [[MlKemField; D]; K] {
49-
// let mut res = [[MlKemField::default(); D]; K];
48+
pub fn polyvec_compress<const D: usize, const d: usize, const K: usize>(
49+
poly_vec: &PolyVec<Monomial, D, K>,
50+
) -> PolyVec<Monomial, D, K> {
51+
let mut res = Vec::with_capacity(K);
5052

51-
// for (i, poly) in poly_vec.vec.iter().enumerate() {
52-
// res[i] = poly_compress(poly);
53-
// }
53+
for poly in poly_vec.vec.iter() {
54+
res.push(poly_compress::<D, d>(poly));
55+
}
5456

55-
// res
56-
// }
57+
let res = res.try_into().unwrap();
58+
PolyVec::new(res)
59+
}
5760

58-
// pub fn polyvec_decompress<const D: usize, const K: usize>(
61+
pub fn polyvec_decompress<const D: usize, const d: usize, const K: usize>(
62+
poly_vec: &PolyVec<Monomial, D, K>,
63+
) -> PolyVec<Monomial, D, K> {
64+
let mut res = Vec::with_capacity(K);
5965

60-
// )
66+
for poly in poly_vec.vec.iter() {
67+
res.push(poly_decompress::<D, d>(&poly.coefficients));
68+
}
69+
70+
let res = res.try_into().unwrap();
71+
PolyVec::new(res)
72+
}
6173

6274
#[test]
6375
fn test_compress_decompress() {

src/kem/kyber/encode.rs

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,42 @@
1-
use super::MlKemField;
2-
use crate::algebra::field::Field;
1+
use super::{MlKemField, PolyVec};
2+
use crate::{
3+
algebra::field::Field,
4+
polynomial::{Basis, Polynomial},
5+
};
36

4-
/// Encodes a field element into a byte array where each field element is represented by D bits.
7+
/// Encodes a field element into a byte array where each field element is represented by d bits.
58
/// Converts the field element into a binary representation and then packs the bits into bytes.
6-
fn byte_encode<const D: usize>(f: [MlKemField; 256]) -> [u8; 32 * D]
7-
where [(); 256 * D]: {
8-
let mut encoded_bits = [0u8; 256 * D];
9+
pub fn byte_encode<const d: usize, const D: usize>(f: &[MlKemField; D]) -> Vec<u8> {
10+
let mut encoded_bits = Vec::with_capacity(D * d);
911

1012
for (i, x) in f.iter().enumerate() {
1113
let mut val = x.value;
12-
for j in 0..D {
13-
encoded_bits[i * D + j] = (val & 1) as u8;
14+
for j in 0..d {
15+
encoded_bits[i * d + j] = (val & 1) as u8;
1416
val >>= 1;
1517
}
1618
}
1719

18-
let mut encoded_bytes = [0u8; 32 * D];
20+
let mut encoded_bytes = Vec::with_capacity(D / 8 * d);
1921
for (i, chunk) in encoded_bits.chunks(8).enumerate() {
2022
encoded_bytes[i] = chunk.iter().enumerate().fold(0, |acc, (j, &b)| acc | (b << j));
2123
}
2224

2325
encoded_bytes
2426
}
2527

28+
pub fn byte_encode_polyvec<B: Basis, const D: usize, const K: usize, const d: usize>(
29+
f: PolyVec<B, D, K>,
30+
) -> Vec<u8> {
31+
let mut encoded_bytes = Vec::with_capacity(D / 8 * d * K);
32+
for (i, poly) in f.vec.iter().enumerate() {
33+
let encoded = byte_encode::<d, D>(&poly.coefficients);
34+
encoded_bytes[i * D / 8 * d..(i + 1) * D / 8 * d].copy_from_slice(&encoded);
35+
}
36+
37+
encoded_bytes
38+
}
39+
2640
/// Encodes a field element into a byte array where each field element is represented by D bits.
2741
/// Converts the field element into a binary representation and then packs the bits into bytes.
2842
fn byte_encode_optimized<const D: usize>(f: [MlKemField; 256]) -> [u8; 32 * D]
@@ -51,27 +65,42 @@ where [(); 256 * D]: {
5165

5266
/// Decodes a byte array into a field element where each field element is represented by D bits.
5367
/// Unpacks the bytes into bits and then converts the bits into a field element.
54-
fn byte_decode<const D: usize>(encoded_bytes: [u8; 32 * D]) -> [MlKemField; 256]
55-
where [(); 256 * D]: {
56-
let mut encoded_bits = [0u8; 256 * D];
68+
pub fn byte_decode<const d: usize, const D: usize>(encoded_bytes: &[u8]) -> [MlKemField; D] {
69+
let mut encoded_bits = Vec::with_capacity(256 * d);
5770
for (i, &byte) in encoded_bytes.iter().enumerate() {
5871
for j in 0..8 {
59-
encoded_bits[i * 8 + j] = (byte >> j) & 1;
72+
encoded_bits.push((byte >> j) & 1);
6073
}
6174
}
6275

63-
let mut f = [MlKemField::ZERO; 256];
76+
let mask: usize = (1 << d) - 1;
77+
let mut f = [MlKemField::ZERO; D];
6478
for (i, chunk) in encoded_bits.chunks(D).enumerate() {
6579
let mut val = 0;
6680
for (j, &bit) in chunk.iter().enumerate() {
67-
val |= (bit as usize) << j;
81+
val |= ((bit as usize) << j) & mask;
6882
}
6983
f[i].value = val;
7084
}
7185

7286
f
7387
}
7488

89+
pub fn byte_decode_polyvec<B: Basis, const D: usize, const K: usize, const d: usize>(
90+
encoded_bytes: &[u8],
91+
basis: B,
92+
) -> PolyVec<B, D, K> {
93+
let mut f = Vec::with_capacity(K);
94+
95+
for bytes in encoded_bytes.chunks(32 * d) {
96+
let coeffs = byte_decode::<d, D>(bytes.try_into().unwrap());
97+
f.push(Polynomial { coefficients: coeffs, basis: basis.clone() })
98+
}
99+
100+
let f = f.try_into().unwrap();
101+
PolyVec::new(f)
102+
}
103+
75104
#[cfg(test)]
76105
mod tests {
77106
use super::*;
@@ -88,15 +117,15 @@ mod tests {
88117
#[test]
89118
fn test_byte_encode() {
90119
let f = generate_test_data();
91-
let encoded = byte_encode::<8>(f);
120+
let encoded = byte_encode::<8, 256>(&f);
92121
let encoded_optimized = byte_encode_optimized::<8>(f);
93122
assert_eq!(encoded, encoded_optimized);
94123
}
95124

96125
#[bench]
97126
fn bench_byte_encode(b: &mut test::Bencher) {
98127
let f = generate_test_data();
99-
b.iter(|| byte_encode::<8>(f));
128+
b.iter(|| byte_encode::<8, 256>(&f));
100129
}
101130

102131
#[bench]

0 commit comments

Comments
 (0)