22
33use hybrid_array:: typenum:: { U1 , U4 , U10 , U12 } ;
44use module_lattice:: algebra:: { Elem , NttPolynomial , NttVector , Polynomial , Vector } ;
5- use module_lattice:: encode:: { byte_decode , byte_encode , Encode } ;
5+ use module_lattice:: encode:: { Encode , byte_decode , byte_encode } ;
66
77// Field used by ML-KEM.
88module_lattice:: define_field!( KyberField , u16 , u32 , u64 , 3329 ) ;
@@ -14,88 +14,69 @@ module_lattice::define_field!(KyberField, u16, u32, u64, 3329);
1414#[ test]
1515fn byte_encode_decode_d1_roundtrip ( ) {
1616 // D=1: Single bit encoding
17- let mut vals = [ Elem :: < KyberField > :: new ( 0 ) ; 256 ] ;
18- for i in 0 ..256 {
19- vals[ i] = Elem :: new ( ( i % 2 ) as u16 ) ;
20- }
17+ let vals: [ Elem < KyberField > ; 256 ] = core:: array:: from_fn ( |i| Elem :: new ( ( i % 2 ) as u16 ) ) ;
2118
2219 let encoded = byte_encode :: < KyberField , U1 > ( & vals. into ( ) ) ;
2320 let decoded = byte_decode :: < KyberField , U1 > ( & encoded) ;
2421
25- for i in 0 .. 256 {
26- assert_eq ! ( decoded [ i ] . 0 , vals [ i ] . 0 , "Mismatch at index {i}" ) ;
22+ for ( i , ( dec , val ) ) in decoded . iter ( ) . zip ( vals . iter ( ) ) . enumerate ( ) {
23+ assert_eq ! ( dec . 0 , val . 0 , "Mismatch at index {i}" ) ;
2724 }
2825}
2926
3027#[ test]
3128fn byte_encode_decode_d4_roundtrip ( ) {
3229 // D=4: 4-bit encoding
33- let mut vals = [ Elem :: < KyberField > :: new ( 0 ) ; 256 ] ;
34- for i in 0 ..256 {
35- vals[ i] = Elem :: new ( ( i % 16 ) as u16 ) ;
36- }
30+ let vals: [ Elem < KyberField > ; 256 ] = core:: array:: from_fn ( |i| Elem :: new ( ( i % 16 ) as u16 ) ) ;
3731
3832 let encoded = byte_encode :: < KyberField , U4 > ( & vals. into ( ) ) ;
3933 let decoded = byte_decode :: < KyberField , U4 > ( & encoded) ;
4034
41- for i in 0 .. 256 {
42- assert_eq ! ( decoded [ i ] . 0 , vals [ i ] . 0 , "Mismatch at index {i}" ) ;
35+ for ( i , ( dec , val ) ) in decoded . iter ( ) . zip ( vals . iter ( ) ) . enumerate ( ) {
36+ assert_eq ! ( dec . 0 , val . 0 , "Mismatch at index {i}" ) ;
4337 }
4438}
4539
4640#[ test]
4741fn byte_encode_decode_d10_roundtrip ( ) {
4842 // D=10: 10-bit encoding
49- let mut vals = [ Elem :: < KyberField > :: new ( 0 ) ; 256 ] ;
50- for i in 0 ..256 {
51- vals[ i] = Elem :: new ( ( i % 1024 ) as u16 ) ;
52- }
43+ let vals: [ Elem < KyberField > ; 256 ] = core:: array:: from_fn ( |i| Elem :: new ( ( i % 1024 ) as u16 ) ) ;
5344
5445 let encoded = byte_encode :: < KyberField , U10 > ( & vals. into ( ) ) ;
5546 let decoded = byte_decode :: < KyberField , U10 > ( & encoded) ;
5647
57- for i in 0 .. 256 {
58- assert_eq ! ( decoded [ i ] . 0 , vals [ i ] . 0 , "Mismatch at index {i}" ) ;
48+ for ( i , ( dec , val ) ) in decoded . iter ( ) . zip ( vals . iter ( ) ) . enumerate ( ) {
49+ assert_eq ! ( dec . 0 , val . 0 , "Mismatch at index {i}" ) ;
5950 }
6051}
6152
6253#[ test]
6354fn byte_encode_decode_d12_roundtrip ( ) {
6455 // D=12: 12-bit encoding (special case with modular reduction)
65- let mut vals = [ Elem :: < KyberField > :: new ( 0 ) ; 256 ] ;
66- for i in 0 ..256 {
67- // Values up to q-1 (3328)
68- vals[ i] = Elem :: new ( ( i * 13 ) as u16 % 3329 ) ;
69- }
56+ // Values up to q-1 (3328)
57+ let vals: [ Elem < KyberField > ; 256 ] = core:: array:: from_fn ( |i| Elem :: new ( ( i * 13 ) as u16 % 3329 ) ) ;
7058
7159 let encoded = byte_encode :: < KyberField , U12 > ( & vals. into ( ) ) ;
7260 let decoded = byte_decode :: < KyberField , U12 > ( & encoded) ;
7361
74- for i in 0 .. 256 {
75- assert_eq ! ( decoded [ i ] . 0 , vals [ i ] . 0 , "Mismatch at index {i}" ) ;
62+ for ( i , ( dec , val ) ) in decoded . iter ( ) . zip ( vals . iter ( ) ) . enumerate ( ) {
63+ assert_eq ! ( dec . 0 , val . 0 , "Mismatch at index {i}" ) ;
7664 }
7765}
7866
7967#[ test]
8068fn byte_encode_decode_d12_modular_reduction ( ) {
8169 // Test that D=12 properly reduces values >= Q
82- let mut vals = [ Elem :: < KyberField > :: new ( 0 ) ; 256 ] ;
83-
8470 // Fill with values near and above Q
85- for i in 0 ..256 {
86- vals[ i] = Elem :: new ( 3329 + ( i as u16 ) % 100 ) ; // Values >= Q
87- }
71+ let vals: [ Elem < KyberField > ; 256 ] =
72+ core:: array:: from_fn ( |i| Elem :: new ( 3329 + ( i as u16 ) % 100 ) ) ;
8873
8974 let encoded = byte_encode :: < KyberField , U12 > ( & vals. into ( ) ) ;
9075 let decoded = byte_decode :: < KyberField , U12 > ( & encoded) ;
9176
9277 // After decode, values should be reduced mod Q
93- for i in 0 ..256 {
94- assert ! (
95- decoded[ i] . 0 < 3329 ,
96- "Value at {i} not reduced: {}" ,
97- decoded[ i] . 0
98- ) ;
78+ for ( i, dec) in decoded. iter ( ) . enumerate ( ) {
79+ assert ! ( dec. 0 < 3329 , "Value at {i} not reduced: {}" , dec. 0 ) ;
9980 }
10081}
10182
@@ -106,8 +87,8 @@ fn byte_encode_zero_values() {
10687 let encoded = byte_encode :: < KyberField , U4 > ( & vals. into ( ) ) ;
10788 let decoded = byte_decode :: < KyberField , U4 > ( & encoded) ;
10889
109- for i in 0 .. 256 {
110- assert_eq ! ( decoded [ i ] . 0 , 0 ) ;
90+ for dec in & decoded {
91+ assert_eq ! ( dec . 0 , 0 ) ;
11192 }
11293}
11394
@@ -119,8 +100,8 @@ fn byte_encode_max_values() {
119100 let encoded = byte_encode :: < KyberField , U4 > ( & vals. into ( ) ) ;
120101 let decoded = byte_decode :: < KyberField , U4 > ( & encoded) ;
121102
122- for i in 0 .. 256 {
123- assert_eq ! ( decoded [ i ] . 0 , 15 ) ;
103+ for dec in & decoded {
104+ assert_eq ! ( dec . 0 , 15 ) ;
124105 }
125106}
126107
@@ -130,10 +111,7 @@ fn byte_encode_max_values() {
130111
131112#[ test]
132113fn polynomial_encode_decode_roundtrip ( ) {
133- let mut coeffs = [ Elem :: < KyberField > :: new ( 0 ) ; 256 ] ;
134- for i in 0 ..256 {
135- coeffs[ i] = Elem :: new ( ( i * 7 ) as u16 % 16 ) ;
136- }
114+ let coeffs: [ Elem < KyberField > ; 256 ] = core:: array:: from_fn ( |i| Elem :: new ( ( i * 7 ) as u16 % 16 ) ) ;
137115 let p = Polynomial :: < KyberField > :: new ( coeffs. into ( ) ) ;
138116
139117 let encoded = <Polynomial < KyberField > as Encode < U4 > >:: encode ( & p) ;
@@ -144,10 +122,8 @@ fn polynomial_encode_decode_roundtrip() {
144122
145123#[ test]
146124fn polynomial_encode_decode_d12 ( ) {
147- let mut coeffs = [ Elem :: < KyberField > :: new ( 0 ) ; 256 ] ;
148- for i in 0 ..256 {
149- coeffs[ i] = Elem :: new ( ( i * 13 ) as u16 % 3329 ) ;
150- }
125+ let coeffs: [ Elem < KyberField > ; 256 ] =
126+ core:: array:: from_fn ( |i| Elem :: new ( ( i * 13 ) as u16 % 3329 ) ) ;
151127 let p = Polynomial :: < KyberField > :: new ( coeffs. into ( ) ) ;
152128
153129 let encoded = <Polynomial < KyberField > as Encode < U12 > >:: encode ( & p) ;
@@ -164,12 +140,8 @@ fn polynomial_encode_decode_d12() {
164140fn vector_encode_decode_roundtrip ( ) {
165141 use hybrid_array:: typenum:: U2 ;
166142
167- let mut coeffs1 = [ Elem :: < KyberField > :: new ( 0 ) ; 256 ] ;
168- let mut coeffs2 = [ Elem :: < KyberField > :: new ( 0 ) ; 256 ] ;
169- for i in 0 ..256 {
170- coeffs1[ i] = Elem :: new ( ( i * 3 ) as u16 % 16 ) ;
171- coeffs2[ i] = Elem :: new ( ( i * 5 ) as u16 % 16 ) ;
172- }
143+ let coeffs1: [ Elem < KyberField > ; 256 ] = core:: array:: from_fn ( |i| Elem :: new ( ( i * 3 ) as u16 % 16 ) ) ;
144+ let coeffs2: [ Elem < KyberField > ; 256 ] = core:: array:: from_fn ( |i| Elem :: new ( ( i * 5 ) as u16 % 16 ) ) ;
173145
174146 let p1 = Polynomial :: < KyberField > :: new ( coeffs1. into ( ) ) ;
175147 let p2 = Polynomial :: < KyberField > :: new ( coeffs2. into ( ) ) ;
@@ -187,10 +159,7 @@ fn vector_encode_decode_roundtrip() {
187159
188160#[ test]
189161fn ntt_polynomial_encode_decode_roundtrip ( ) {
190- let mut coeffs = [ Elem :: < KyberField > :: new ( 0 ) ; 256 ] ;
191- for i in 0 ..256 {
192- coeffs[ i] = Elem :: new ( ( i * 7 ) as u16 % 16 ) ;
193- }
162+ let coeffs: [ Elem < KyberField > ; 256 ] = core:: array:: from_fn ( |i| Elem :: new ( ( i * 7 ) as u16 % 16 ) ) ;
194163 let p = NttPolynomial :: < KyberField > :: new ( coeffs. into ( ) ) ;
195164
196165 let encoded = <NttPolynomial < KyberField > as Encode < U4 > >:: encode ( & p) ;
@@ -201,10 +170,8 @@ fn ntt_polynomial_encode_decode_roundtrip() {
201170
202171#[ test]
203172fn ntt_polynomial_encode_decode_d12 ( ) {
204- let mut coeffs = [ Elem :: < KyberField > :: new ( 0 ) ; 256 ] ;
205- for i in 0 ..256 {
206- coeffs[ i] = Elem :: new ( ( i * 13 ) as u16 % 3329 ) ;
207- }
173+ let coeffs: [ Elem < KyberField > ; 256 ] =
174+ core:: array:: from_fn ( |i| Elem :: new ( ( i * 13 ) as u16 % 3329 ) ) ;
208175 let p = NttPolynomial :: < KyberField > :: new ( coeffs. into ( ) ) ;
209176
210177 let encoded = <NttPolynomial < KyberField > as Encode < U12 > >:: encode ( & p) ;
@@ -221,12 +188,8 @@ fn ntt_polynomial_encode_decode_d12() {
221188fn ntt_vector_encode_decode_roundtrip ( ) {
222189 use hybrid_array:: typenum:: U2 ;
223190
224- let mut coeffs1 = [ Elem :: < KyberField > :: new ( 0 ) ; 256 ] ;
225- let mut coeffs2 = [ Elem :: < KyberField > :: new ( 0 ) ; 256 ] ;
226- for i in 0 ..256 {
227- coeffs1[ i] = Elem :: new ( ( i * 3 ) as u16 % 16 ) ;
228- coeffs2[ i] = Elem :: new ( ( i * 5 ) as u16 % 16 ) ;
229- }
191+ let coeffs1: [ Elem < KyberField > ; 256 ] = core:: array:: from_fn ( |i| Elem :: new ( ( i * 3 ) as u16 % 16 ) ) ;
192+ let coeffs2: [ Elem < KyberField > ; 256 ] = core:: array:: from_fn ( |i| Elem :: new ( ( i * 5 ) as u16 % 16 ) ) ;
230193
231194 let p1 = NttPolynomial :: < KyberField > :: new ( coeffs1. into ( ) ) ;
232195 let p2 = NttPolynomial :: < KyberField > :: new ( coeffs2. into ( ) ) ;
@@ -269,7 +232,7 @@ fn encoded_vector_size() {
269232 // D=4, K=3: 128 bytes per polynomial * 3 = 384 bytes
270233 let coeffs = [ Elem :: < KyberField > :: new ( 0 ) ; 256 ] ;
271234 let p = Polynomial :: < KyberField > :: new ( coeffs. into ( ) ) ;
272- let v: Vector < KyberField , U3 > = Vector :: new ( [ p. clone ( ) , p. clone ( ) , p] . into ( ) ) ;
235+ let v: Vector < KyberField , U3 > = Vector :: new ( [ p, p, p] . into ( ) ) ;
273236
274237 let encoded = <Vector < KyberField , U3 > as Encode < U4 > >:: encode ( & v) ;
275238 assert_eq ! ( encoded. len( ) , 384 ) ;
@@ -282,31 +245,26 @@ fn encoded_vector_size() {
282245#[ test]
283246fn byte_encode_alternating_bits ( ) {
284247 // Test alternating patterns to catch bit manipulation issues
285- let mut vals = [ Elem :: < KyberField > :: new ( 0 ) ; 256 ] ;
286- for i in 0 ..256 {
287- vals[ i] = Elem :: new ( if i % 2 == 0 { 0b0101 } else { 0b1010 } ) ;
288- }
248+ let vals: [ Elem < KyberField > ; 256 ] =
249+ core:: array:: from_fn ( |i| Elem :: new ( if i % 2 == 0 { 0b0101 } else { 0b1010 } ) ) ;
289250
290251 let encoded = byte_encode :: < KyberField , U4 > ( & vals. into ( ) ) ;
291252 let decoded = byte_decode :: < KyberField , U4 > ( & encoded) ;
292253
293- for i in 0 .. 256 {
294- assert_eq ! ( decoded [ i ] . 0 , vals [ i ] . 0 , "Mismatch at index {i}" ) ;
254+ for ( i , ( dec , val ) ) in decoded . iter ( ) . zip ( vals . iter ( ) ) . enumerate ( ) {
255+ assert_eq ! ( dec . 0 , val . 0 , "Mismatch at index {i}" ) ;
295256 }
296257}
297258
298259#[ test]
299260fn byte_encode_sequential_values ( ) {
300261 // Sequential values to catch ordering issues
301- let mut vals = [ Elem :: < KyberField > :: new ( 0 ) ; 256 ] ;
302- for i in 0 ..256 {
303- vals[ i] = Elem :: new ( i as u16 % 16 ) ;
304- }
262+ let vals: [ Elem < KyberField > ; 256 ] = core:: array:: from_fn ( |i| Elem :: new ( i as u16 % 16 ) ) ;
305263
306264 let encoded = byte_encode :: < KyberField , U4 > ( & vals. into ( ) ) ;
307265 let decoded = byte_decode :: < KyberField , U4 > ( & encoded) ;
308266
309- for i in 0 .. 256 {
310- assert_eq ! ( decoded [ i ] . 0 , vals [ i ] . 0 , "Mismatch at index {i}" ) ;
267+ for ( i , ( dec , val ) ) in decoded . iter ( ) . zip ( vals . iter ( ) ) . enumerate ( ) {
268+ assert_eq ! ( dec . 0 , val . 0 , "Mismatch at index {i}" ) ;
311269 }
312270}
0 commit comments