Skip to content

Commit dd334d9

Browse files
committed
module-lattice: fix ci failures
1 parent c138f6b commit dd334d9

File tree

2 files changed

+41
-83
lines changed

2 files changed

+41
-83
lines changed

module-lattice/tests/algebra.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ fn ntt_polynomial_from_array() {
322322

323323
let coeffs: Array<Elem<KyberField>, hybrid_array::typenum::U256> =
324324
core::array::from_fn(|i| Elem::new((i % 3329) as u16)).into();
325-
let p: NttPolynomial<KyberField> = coeffs.clone().into();
325+
let p: NttPolynomial<KyberField> = coeffs.into();
326326

327327
assert_eq!(p.0[0].0, 0);
328328
assert_eq!(p.0[1].0, 1);

module-lattice/tests/encode.rs

Lines changed: 40 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
33
use hybrid_array::typenum::{U1, U4, U10, U12};
44
use module_lattice::algebra::{Elem, NttPolynomial, NttVector, Polynomial, Vector};
5-
use module_lattice::encode::{byte_decode, byte_encode, Encode};
5+
use module_lattice::encode::{Encode, byte_decode, byte_encode};
66

77
// Field used by ML-KEM.
88
module_lattice::define_field!(KyberField, u16, u32, u64, 3329);
@@ -14,88 +14,69 @@ module_lattice::define_field!(KyberField, u16, u32, u64, 3329);
1414
#[test]
1515
fn byte_encode_decode_d1_roundtrip() {
1616
// D=1: Single bit encoding
17-
let mut vals = [Elem::<KyberField>::new(0); 256];
18-
for i in 0..256 {
19-
vals[i] = Elem::new((i % 2) as u16);
20-
}
17+
let vals: [Elem<KyberField>; 256] = core::array::from_fn(|i| Elem::new((i % 2) as u16));
2118

2219
let encoded = byte_encode::<KyberField, U1>(&vals.into());
2320
let decoded = byte_decode::<KyberField, U1>(&encoded);
2421

25-
for i in 0..256 {
26-
assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}");
22+
for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() {
23+
assert_eq!(dec.0, val.0, "Mismatch at index {i}");
2724
}
2825
}
2926

3027
#[test]
3128
fn byte_encode_decode_d4_roundtrip() {
3229
// D=4: 4-bit encoding
33-
let mut vals = [Elem::<KyberField>::new(0); 256];
34-
for i in 0..256 {
35-
vals[i] = Elem::new((i % 16) as u16);
36-
}
30+
let vals: [Elem<KyberField>; 256] = core::array::from_fn(|i| Elem::new((i % 16) as u16));
3731

3832
let encoded = byte_encode::<KyberField, U4>(&vals.into());
3933
let decoded = byte_decode::<KyberField, U4>(&encoded);
4034

41-
for i in 0..256 {
42-
assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}");
35+
for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() {
36+
assert_eq!(dec.0, val.0, "Mismatch at index {i}");
4337
}
4438
}
4539

4640
#[test]
4741
fn byte_encode_decode_d10_roundtrip() {
4842
// D=10: 10-bit encoding
49-
let mut vals = [Elem::<KyberField>::new(0); 256];
50-
for i in 0..256 {
51-
vals[i] = Elem::new((i % 1024) as u16);
52-
}
43+
let vals: [Elem<KyberField>; 256] = core::array::from_fn(|i| Elem::new((i % 1024) as u16));
5344

5445
let encoded = byte_encode::<KyberField, U10>(&vals.into());
5546
let decoded = byte_decode::<KyberField, U10>(&encoded);
5647

57-
for i in 0..256 {
58-
assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}");
48+
for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() {
49+
assert_eq!(dec.0, val.0, "Mismatch at index {i}");
5950
}
6051
}
6152

6253
#[test]
6354
fn byte_encode_decode_d12_roundtrip() {
6455
// D=12: 12-bit encoding (special case with modular reduction)
65-
let mut vals = [Elem::<KyberField>::new(0); 256];
66-
for i in 0..256 {
67-
// Values up to q-1 (3328)
68-
vals[i] = Elem::new((i * 13) as u16 % 3329);
69-
}
56+
// Values up to q-1 (3328)
57+
let vals: [Elem<KyberField>; 256] = core::array::from_fn(|i| Elem::new((i * 13) as u16 % 3329));
7058

7159
let encoded = byte_encode::<KyberField, U12>(&vals.into());
7260
let decoded = byte_decode::<KyberField, U12>(&encoded);
7361

74-
for i in 0..256 {
75-
assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}");
62+
for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() {
63+
assert_eq!(dec.0, val.0, "Mismatch at index {i}");
7664
}
7765
}
7866

7967
#[test]
8068
fn byte_encode_decode_d12_modular_reduction() {
8169
// Test that D=12 properly reduces values >= Q
82-
let mut vals = [Elem::<KyberField>::new(0); 256];
83-
8470
// Fill with values near and above Q
85-
for i in 0..256 {
86-
vals[i] = Elem::new(3329 + (i as u16) % 100); // Values >= Q
87-
}
71+
let vals: [Elem<KyberField>; 256] =
72+
core::array::from_fn(|i| Elem::new(3329 + (i as u16) % 100));
8873

8974
let encoded = byte_encode::<KyberField, U12>(&vals.into());
9075
let decoded = byte_decode::<KyberField, U12>(&encoded);
9176

9277
// After decode, values should be reduced mod Q
93-
for i in 0..256 {
94-
assert!(
95-
decoded[i].0 < 3329,
96-
"Value at {i} not reduced: {}",
97-
decoded[i].0
98-
);
78+
for (i, dec) in decoded.iter().enumerate() {
79+
assert!(dec.0 < 3329, "Value at {i} not reduced: {}", dec.0);
9980
}
10081
}
10182

@@ -106,8 +87,8 @@ fn byte_encode_zero_values() {
10687
let encoded = byte_encode::<KyberField, U4>(&vals.into());
10788
let decoded = byte_decode::<KyberField, U4>(&encoded);
10889

109-
for i in 0..256 {
110-
assert_eq!(decoded[i].0, 0);
90+
for dec in &decoded {
91+
assert_eq!(dec.0, 0);
11192
}
11293
}
11394

@@ -119,8 +100,8 @@ fn byte_encode_max_values() {
119100
let encoded = byte_encode::<KyberField, U4>(&vals.into());
120101
let decoded = byte_decode::<KyberField, U4>(&encoded);
121102

122-
for i in 0..256 {
123-
assert_eq!(decoded[i].0, 15);
103+
for dec in &decoded {
104+
assert_eq!(dec.0, 15);
124105
}
125106
}
126107

@@ -130,10 +111,7 @@ fn byte_encode_max_values() {
130111

131112
#[test]
132113
fn polynomial_encode_decode_roundtrip() {
133-
let mut coeffs = [Elem::<KyberField>::new(0); 256];
134-
for i in 0..256 {
135-
coeffs[i] = Elem::new((i * 7) as u16 % 16);
136-
}
114+
let coeffs: [Elem<KyberField>; 256] = core::array::from_fn(|i| Elem::new((i * 7) as u16 % 16));
137115
let p = Polynomial::<KyberField>::new(coeffs.into());
138116

139117
let encoded = <Polynomial<KyberField> as Encode<U4>>::encode(&p);
@@ -144,10 +122,8 @@ fn polynomial_encode_decode_roundtrip() {
144122

145123
#[test]
146124
fn polynomial_encode_decode_d12() {
147-
let mut coeffs = [Elem::<KyberField>::new(0); 256];
148-
for i in 0..256 {
149-
coeffs[i] = Elem::new((i * 13) as u16 % 3329);
150-
}
125+
let coeffs: [Elem<KyberField>; 256] =
126+
core::array::from_fn(|i| Elem::new((i * 13) as u16 % 3329));
151127
let p = Polynomial::<KyberField>::new(coeffs.into());
152128

153129
let encoded = <Polynomial<KyberField> as Encode<U12>>::encode(&p);
@@ -164,12 +140,8 @@ fn polynomial_encode_decode_d12() {
164140
fn vector_encode_decode_roundtrip() {
165141
use hybrid_array::typenum::U2;
166142

167-
let mut coeffs1 = [Elem::<KyberField>::new(0); 256];
168-
let mut coeffs2 = [Elem::<KyberField>::new(0); 256];
169-
for i in 0..256 {
170-
coeffs1[i] = Elem::new((i * 3) as u16 % 16);
171-
coeffs2[i] = Elem::new((i * 5) as u16 % 16);
172-
}
143+
let coeffs1: [Elem<KyberField>; 256] = core::array::from_fn(|i| Elem::new((i * 3) as u16 % 16));
144+
let coeffs2: [Elem<KyberField>; 256] = core::array::from_fn(|i| Elem::new((i * 5) as u16 % 16));
173145

174146
let p1 = Polynomial::<KyberField>::new(coeffs1.into());
175147
let p2 = Polynomial::<KyberField>::new(coeffs2.into());
@@ -187,10 +159,7 @@ fn vector_encode_decode_roundtrip() {
187159

188160
#[test]
189161
fn ntt_polynomial_encode_decode_roundtrip() {
190-
let mut coeffs = [Elem::<KyberField>::new(0); 256];
191-
for i in 0..256 {
192-
coeffs[i] = Elem::new((i * 7) as u16 % 16);
193-
}
162+
let coeffs: [Elem<KyberField>; 256] = core::array::from_fn(|i| Elem::new((i * 7) as u16 % 16));
194163
let p = NttPolynomial::<KyberField>::new(coeffs.into());
195164

196165
let encoded = <NttPolynomial<KyberField> as Encode<U4>>::encode(&p);
@@ -201,10 +170,8 @@ fn ntt_polynomial_encode_decode_roundtrip() {
201170

202171
#[test]
203172
fn ntt_polynomial_encode_decode_d12() {
204-
let mut coeffs = [Elem::<KyberField>::new(0); 256];
205-
for i in 0..256 {
206-
coeffs[i] = Elem::new((i * 13) as u16 % 3329);
207-
}
173+
let coeffs: [Elem<KyberField>; 256] =
174+
core::array::from_fn(|i| Elem::new((i * 13) as u16 % 3329));
208175
let p = NttPolynomial::<KyberField>::new(coeffs.into());
209176

210177
let encoded = <NttPolynomial<KyberField> as Encode<U12>>::encode(&p);
@@ -221,12 +188,8 @@ fn ntt_polynomial_encode_decode_d12() {
221188
fn ntt_vector_encode_decode_roundtrip() {
222189
use hybrid_array::typenum::U2;
223190

224-
let mut coeffs1 = [Elem::<KyberField>::new(0); 256];
225-
let mut coeffs2 = [Elem::<KyberField>::new(0); 256];
226-
for i in 0..256 {
227-
coeffs1[i] = Elem::new((i * 3) as u16 % 16);
228-
coeffs2[i] = Elem::new((i * 5) as u16 % 16);
229-
}
191+
let coeffs1: [Elem<KyberField>; 256] = core::array::from_fn(|i| Elem::new((i * 3) as u16 % 16));
192+
let coeffs2: [Elem<KyberField>; 256] = core::array::from_fn(|i| Elem::new((i * 5) as u16 % 16));
230193

231194
let p1 = NttPolynomial::<KyberField>::new(coeffs1.into());
232195
let p2 = NttPolynomial::<KyberField>::new(coeffs2.into());
@@ -269,7 +232,7 @@ fn encoded_vector_size() {
269232
// D=4, K=3: 128 bytes per polynomial * 3 = 384 bytes
270233
let coeffs = [Elem::<KyberField>::new(0); 256];
271234
let p = Polynomial::<KyberField>::new(coeffs.into());
272-
let v: Vector<KyberField, U3> = Vector::new([p.clone(), p.clone(), p].into());
235+
let v: Vector<KyberField, U3> = Vector::new([p, p, p].into());
273236

274237
let encoded = <Vector<KyberField, U3> as Encode<U4>>::encode(&v);
275238
assert_eq!(encoded.len(), 384);
@@ -282,31 +245,26 @@ fn encoded_vector_size() {
282245
#[test]
283246
fn byte_encode_alternating_bits() {
284247
// Test alternating patterns to catch bit manipulation issues
285-
let mut vals = [Elem::<KyberField>::new(0); 256];
286-
for i in 0..256 {
287-
vals[i] = Elem::new(if i % 2 == 0 { 0b0101 } else { 0b1010 });
288-
}
248+
let vals: [Elem<KyberField>; 256] =
249+
core::array::from_fn(|i| Elem::new(if i % 2 == 0 { 0b0101 } else { 0b1010 }));
289250

290251
let encoded = byte_encode::<KyberField, U4>(&vals.into());
291252
let decoded = byte_decode::<KyberField, U4>(&encoded);
292253

293-
for i in 0..256 {
294-
assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}");
254+
for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() {
255+
assert_eq!(dec.0, val.0, "Mismatch at index {i}");
295256
}
296257
}
297258

298259
#[test]
299260
fn byte_encode_sequential_values() {
300261
// Sequential values to catch ordering issues
301-
let mut vals = [Elem::<KyberField>::new(0); 256];
302-
for i in 0..256 {
303-
vals[i] = Elem::new(i as u16 % 16);
304-
}
262+
let vals: [Elem<KyberField>; 256] = core::array::from_fn(|i| Elem::new(i as u16 % 16));
305263

306264
let encoded = byte_encode::<KyberField, U4>(&vals.into());
307265
let decoded = byte_decode::<KyberField, U4>(&encoded);
308266

309-
for i in 0..256 {
310-
assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}");
267+
for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() {
268+
assert_eq!(dec.0, val.0, "Mismatch at index {i}");
311269
}
312270
}

0 commit comments

Comments
 (0)