Skip to content

Commit 93ee2cb

Browse files
committed
feat: add sampling and ntt algorithms
1 parent 7026afd commit 93ee2cb

5 files changed

Lines changed: 56 additions & 15 deletions

File tree

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ exclude =["CHANGELOG.md", "src/tree/ConstructMerkleTree.gif"]
1212
rand ="0.8"
1313
itertools="0.13"
1414
hex ="0.4"
15-
sha3 = "0.10.8"
15+
sha3 ="0.10.8"
1616

1717
[dev-dependencies]
1818
rstest ="0.23"

src/kem/kyber/ntt.rs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::ops::Mul;
22

3-
use super::{MlKemField, Ntt};
3+
use super::{MlKemField, Ntt, PolyVec};
44
use crate::{
55
algebra::{field::Field, Finite},
66
polynomial::{Monomial, Polynomial},
@@ -71,21 +71,38 @@ impl<const D: usize> Polynomial<Monomial, MlKemField, D> {
7171
}
7272
}
7373

74-
impl Mul<Polynomial<Ntt, MlKemField, 256>> for Polynomial<Ntt, MlKemField, 256> {
75-
type Output = Self;
74+
impl<const D: usize, const K: usize> PolyVec<Monomial, D, K> {
75+
pub fn ntt(self) -> PolyVec<Ntt, D, K> {
76+
let ntt_vec = self.vec.iter().map(|poly| poly.ntt()).collect::<Vec<_>>().try_into().unwrap();
7677

77-
fn mul(self, rhs: Polynomial<Ntt, MlKemField, 256>) -> Self::Output {
78-
let mut res_coeffs = [MlKemField::ZERO; 256];
78+
PolyVec::<Ntt, D, K> { vec: ntt_vec }
79+
}
80+
}
81+
82+
impl<const D: usize, const K: usize> PolyVec<Ntt, D, K> {
83+
pub fn ntt_inv(self) -> PolyVec<Monomial, D, K> {
84+
let ntt_inv_vec =
85+
self.vec.iter().map(|poly| poly.ntt_inv()).collect::<Vec<_>>().try_into().unwrap();
86+
87+
PolyVec { vec: ntt_inv_vec }
88+
}
89+
}
90+
91+
impl<const D: usize> Mul<&Polynomial<Ntt, MlKemField, D>> for &Polynomial<Ntt, MlKemField, D> {
92+
type Output = Polynomial<Ntt, MlKemField, D>;
93+
94+
fn mul(self, rhs: &Polynomial<Ntt, MlKemField, D>) -> Self::Output {
95+
let mut res_coeffs = [MlKemField::ZERO; D];
7996

80-
for i in 0..128 {
97+
for i in 0..D >> 1 {
8198
let (a0, a1) = (self.coefficients[2 * i], self.coefficients[2 * i + 1]);
8299
let (b0, b1) = (rhs.coefficients[2 * i], rhs.coefficients[2 * i + 1]);
83100
let (c0, c1) = (a0 * b0 + GAMMA[i] * a1 * b1, a0 * b1 + a1 * b0);
84101
res_coeffs[2 * i] = c0;
85102
res_coeffs[2 * i + 1] = c1;
86103
}
87104

88-
Polynomial::<Ntt, MlKemField, 256> { coefficients: res_coeffs, basis: Ntt }
105+
Polynomial::<Ntt, MlKemField, D> { coefficients: res_coeffs, basis: Ntt }
89106
}
90107
}
91108

src/kem/kyber/sampling.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
use sha3::digest::XofReader;
2-
31
use super::{auxiliary::Xof, MlKemField};
42
use crate::algebra::{field::Field, Finite};
53

6-
pub fn sample_ntt(input: &[u8]) -> [MlKemField; 256] {
7-
assert!(input.len() == 34);
4+
pub fn sample_ntt(rho: &[u8], j: u8, i: u8) -> [MlKemField; 256] {
5+
assert!(rho.len() == 32);
6+
let mut input = [0u8; 34];
7+
input[..32].copy_from_slice(rho);
8+
input[32] = j;
9+
input[33] = i;
10+
811
let mut ntt = [MlKemField::ZERO; 256];
912

10-
let mut xof = Xof::init().absorb(input);
13+
let mut xof = Xof::init().absorb(&input);
1114
let mut j = 0;
1215
while j < 256 {
1316
let mut buf = [0u8; 3];

src/polynomial/arithmetic.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,27 @@ impl<F: FiniteField, const D: usize, const D2: usize> Add<Polynomial<Monomial, F
3434
}
3535
}
3636

37+
impl<F: FiniteField, const D: usize, const D2: usize> Add<&Polynomial<Monomial, F, D2>>
38+
for &Polynomial<Monomial, F, D>
39+
{
40+
type Output = Polynomial<Monomial, F, D>;
41+
42+
/// Implements addition of two polynomials by adding their coefficients.
43+
/// Note: degree of first operand > deg of second operand.
44+
fn add(self, rhs: &Polynomial<Monomial, F, D2>) -> Self::Output {
45+
let coefficients = self
46+
.coefficients
47+
.iter()
48+
.zip(rhs.coefficients.iter().chain(std::iter::repeat(&F::ZERO)))
49+
.map(|(&a, &b)| a + b)
50+
.take(D)
51+
.collect::<Vec<F>>()
52+
.try_into()
53+
.unwrap_or_else(|v: Vec<F>| panic!("Expected a Vec of length {} but it was {}", D, v.len()));
54+
Self::Output { coefficients, basis: self.basis }
55+
}
56+
}
57+
3758
impl<F: FiniteField, const D: usize, const D2: usize> AddAssign<Polynomial<Monomial, F, D2>>
3859
for Polynomial<Monomial, F, D>
3960
{

src/polynomial/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
//! - Includes Discrete Fourier Transform (DFT) for polynomials in the [`Monomial`] basis to convert
1818
//! into the [`Lagrange`] basis via evaluation at the roots of unity.
1919
20-
use std::array;
20+
use std::{array, fmt::Debug};
2121

2222
use super::*;
2323
use crate::algebra::field::FiniteField;
@@ -45,7 +45,7 @@ pub struct Polynomial<B: Basis, F: FiniteField, const D: usize> {
4545

4646
/// [`Basis`] trait is used to specify the basis of the polynomial.
4747
/// The basis can be [`Monomial`] or [`Lagrange`]. This is a type-state pattern for [`Polynomial`].
48-
pub trait Basis {
48+
pub trait Basis: Debug + Clone {
4949
/// The associated data type for the basis.
5050
type Data;
5151
}

0 commit comments

Comments
 (0)