6161 * Implements @[FIPS203, Algorithm 13 (K-PKE.KeyGen), L19]
6262 *
6363 **************************************************/
64- static void mlk_pack_pk (uint8_t r [MLKEM_INDCPA_PUBLICKEYBYTES ], mlk_polyvec pk ,
64+ static void mlk_pack_pk (uint8_t r [MLKEM_INDCPA_PUBLICKEYBYTES ],
65+ const mlk_polyvec * pk ,
6566 const uint8_t seed [MLKEM_SYMBYTES ])
6667{
6768 mlk_assert_bound_2d (pk -> vec , MLKEM_K , MLKEM_N , 0 , MLKEM_Q );
@@ -85,7 +86,7 @@ static void mlk_pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], mlk_polyvec pk,
8586 * Implements @[FIPS203, Algorithm 14 (K-PKE.Encrypt), L2-3]
8687 *
8788 **************************************************/
88- static void mlk_unpack_pk (mlk_polyvec pk , uint8_t seed [MLKEM_SYMBYTES ],
89+ static void mlk_unpack_pk (mlk_polyvec * pk , uint8_t seed [MLKEM_SYMBYTES ],
8990 const uint8_t packedpk [MLKEM_INDCPA_PUBLICKEYBYTES ])
9091{
9192 mlk_polyvec_frombytes (pk , packedpk );
@@ -110,7 +111,8 @@ static void mlk_unpack_pk(mlk_polyvec pk, uint8_t seed[MLKEM_SYMBYTES],
110111 * Implements @[FIPS203, Algorithm 13 (K-PKE.KeyGen), L20]
111112 *
112113 **************************************************/
113- static void mlk_pack_sk (uint8_t r [MLKEM_INDCPA_SECRETKEYBYTES ], mlk_polyvec sk )
114+ static void mlk_pack_sk (uint8_t r [MLKEM_INDCPA_SECRETKEYBYTES ],
115+ const mlk_polyvec * sk )
114116{
115117 mlk_assert_bound_2d (sk -> vec , MLKEM_K , MLKEM_N , 0 , MLKEM_Q );
116118 mlk_polyvec_tobytes (r , sk );
@@ -130,7 +132,7 @@ static void mlk_pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], mlk_polyvec sk)
130132 * Implements @[FIPS203, Algorithm 15 (K-PKE.Decrypt), L5]
131133 *
132134 **************************************************/
133- static void mlk_unpack_sk (mlk_polyvec sk ,
135+ static void mlk_unpack_sk (mlk_polyvec * sk ,
134136 const uint8_t packedsk [MLKEM_INDCPA_SECRETKEYBYTES ])
135137{
136138 mlk_polyvec_frombytes (sk , packedsk );
@@ -151,8 +153,8 @@ static void mlk_unpack_sk(mlk_polyvec sk,
151153 * Implements @[FIPS203, Algorithm 14 (K-PKE.Encrypt), L22-23]
152154 *
153155 **************************************************/
154- static void mlk_pack_ciphertext (uint8_t r [MLKEM_INDCPA_BYTES ], mlk_polyvec b ,
155- mlk_poly * v )
156+ static void mlk_pack_ciphertext (uint8_t r [MLKEM_INDCPA_BYTES ],
157+ const mlk_polyvec * b , mlk_poly * v )
156158{
157159 mlk_polyvec_compress_du (r , b );
158160 mlk_poly_compress_dv (r + MLKEM_POLYVECCOMPRESSEDBYTES_DU , v );
@@ -172,7 +174,7 @@ static void mlk_pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES], mlk_polyvec b,
172174 * Implements @[FIPS203, Algorithm 15 (K-PKE.Decrypt), L1-4]
173175 *
174176 **************************************************/
175- static void mlk_unpack_ciphertext (mlk_polyvec b , mlk_poly * v ,
177+ static void mlk_unpack_ciphertext (mlk_polyvec * b , mlk_poly * v ,
176178 const uint8_t c [MLKEM_INDCPA_BYTES ])
177179{
178180 mlk_polyvec_decompress_du (b , c );
@@ -244,7 +246,7 @@ __contract__(
244246 *
245247 * Not static for benchmarking */
246248MLK_INTERNAL_API
247- void mlk_gen_matrix (mlk_polymat a , const uint8_t seed [MLKEM_SYMBYTES ],
249+ void mlk_gen_matrix (mlk_polymat * a , const uint8_t seed [MLKEM_SYMBYTES ],
248250 int transposed )
249251{
250252 unsigned i , j ;
@@ -277,7 +279,11 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
277279 }
278280 }
279281
280- mlk_poly_rej_uniform_x4 (& a [i ], & a [i + 1 ], & a [i + 2 ], & a [i + 3 ], seed_ext );
282+ mlk_poly_rej_uniform_x4 (& a -> vec [i / MLKEM_K ].vec [i % MLKEM_K ],
283+ & a -> vec [(i + 1 ) / MLKEM_K ].vec [(i + 1 ) % MLKEM_K ],
284+ & a -> vec [(i + 2 ) / MLKEM_K ].vec [(i + 2 ) % MLKEM_K ],
285+ & a -> vec [(i + 3 ) / MLKEM_K ].vec [(i + 3 ) % MLKEM_K ],
286+ seed_ext );
281287 }
282288#else /* !MLK_CONFIG_SERIAL_FIPS202_ONLY */
283289 /* When using serial FIPS202, sample all entries individually. */
@@ -305,7 +311,7 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
305311 seed_ext [0 ][MLKEM_SYMBYTES + 1 ] = x ;
306312 }
307313
308- mlk_poly_rej_uniform (& a [ i ], seed_ext [0 ]);
314+ mlk_poly_rej_uniform (& a -> vec [ i / MLKEM_K ]. vec [ i % MLKEM_K ], seed_ext [0 ]);
309315 }
310316
311317 mlk_assert (i == MLKEM_K * MLKEM_K );
@@ -340,15 +346,16 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
340346 * Specification: Implements @[FIPS203, Section 2.4.7, Eq (2.12), (2.13)]
341347 *
342348 **************************************************/
343- static void mlk_matvec_mul (mlk_polyvec out , const mlk_polymat a ,
344- const mlk_polyvec v , const mlk_polyvec_mulcache vc )
349+ static void mlk_matvec_mul (mlk_polyvec * out , const mlk_polymat * a ,
350+ const mlk_polyvec * v , const mlk_polyvec_mulcache * vc )
345351__contract__ (
346352 requires (memory_no_alias (out , sizeof (mlk_polyvec )))
347353 requires (memory_no_alias (a , sizeof (mlk_polymat )))
348354 requires (memory_no_alias (v , sizeof (mlk_polyvec )))
349355 requires (memory_no_alias (vc , sizeof (mlk_polyvec_mulcache )))
350- requires (forall (k0 , 0 , MLKEM_K * MLKEM_K ,
351- array_bound (a [k0 ].coeffs , 0 , MLKEM_N , 0 , MLKEM_UINT12_LIMIT )))
356+ requires (forall (k0 , 0 , MLKEM_K ,
357+ forall (k1 , 0 , MLKEM_K ,
358+ array_bound (a - > vec [k0 ].vec [k1 ].coeffs , 0 , MLKEM_N , 0 , MLKEM_UINT12_LIMIT ))))
352359 requires (forall (k1 , 0 , MLKEM_K ,
353360 array_abs_bound (v [k1 ].coeffs , 0 , MLKEM_N , MLK_NTT_BOUND )))
354361 requires (forall (k2 , 0 , MLKEM_K ,
@@ -357,32 +364,17 @@ __contract__(
357364 ensures (forall (k3 , 0 , MLKEM_K ,
358365 array_abs_bound (out [k3 ].coeffs , 0 , MLKEM_N , INT16_MAX /2 ))))
359366{
360- /* Temporary on the "refine-bounds" branch - unroll to a simple
361- * sequence of calls for each possible value of MLKEM_K to
362- * simplify proof.
363- */
364- mlk_polyvec_basemul_acc_montgomery_cached (& out [0 ], & a [0 ], v , vc );
365- mlk_polyvec_basemul_acc_montgomery_cached (& out [1 ], & a [MLKEM_K ], v , vc );
366-
367- #if MLKEM_K == 3
368- mlk_polyvec_basemul_acc_montgomery_cached (& out [2 ], & a [MLKEM_K * 2 ], v , vc );
369- #elif MLKEM_K == 4
370- mlk_polyvec_basemul_acc_montgomery_cached (& out [2 ], & a [MLKEM_K * 2 ], v , vc );
371- mlk_polyvec_basemul_acc_montgomery_cached (& out [3 ], & a [MLKEM_K * 3 ], v , vc );
372- #endif
373-
374- /* unsigned i;
375- * for (i = 0; i < MLKEM_K; i++)
376- * __loop__(
377- * assigns(i, object_whole(out))
378- * invariant(i <= MLKEM_K)
379- * invariant(forall(k, 0, i,
380- * array_abs_bound(out[k].coeffs, 0, MLKEM_N, INT16_MAX/2))))
381- * {
382- * mlk_polyvec_basemul_acc_montgomery_cached(&out[i], &a[MLKEM_K * i], v,
383- * vc);
384- * }
385- */
367+ unsigned i ;
368+ for (i = 0 ; i < MLKEM_K ; i ++ )
369+ __loop__ (
370+ assigns (i , object_whole (out ))
371+ invariant (i <= MLKEM_K ))
372+ invariant (forall (k , 0 , i ,
373+ array_abs_bound (out [k ].coeffs , 0 , MLKEM_N , INT16_MAX /2 ))))
374+ {
375+ mlk_polyvec_basemul_acc_montgomery_cached (& out -> vec [i ], & a -> vec [i ], v ,
376+ vc );
377+ }
386378}
387379
388380/* Reference: `indcpa_keypair_derand()` in the reference implementation @[REF].
@@ -433,7 +425,7 @@ int mlk_indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
433425 */
434426 MLK_CT_TESTING_DECLASSIFY (publicseed , MLKEM_SYMBYTES );
435427
436- mlk_gen_matrix (a , publicseed , 0 /* no transpose */ );
428+ mlk_gen_matrix (& a , publicseed , 0 /* no transpose */ );
437429
438430#if MLKEM_K == 2
439431 mlk_poly_getnoise_eta1_4x (& skpv -> vec [0 ], & skpv -> vec [1 ], & e -> vec [0 ],
@@ -457,19 +449,19 @@ int mlk_indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
457449 noiseseed , 4 , 5 , 6 , 7 );
458450#endif /* MLKEM_K == 4 */
459451
460- mlk_polyvec_ntt (skpv );
461- mlk_polyvec_ntt (e );
452+ mlk_polyvec_ntt (& skpv );
453+ mlk_polyvec_ntt (& e );
462454
463- mlk_polyvec_mulcache_compute (skpv_cache , skpv );
464- mlk_matvec_mul (pkpv , a , skpv , skpv_cache );
465- mlk_polyvec_tomont (pkpv );
455+ mlk_polyvec_mulcache_compute (& skpv_cache , & skpv );
456+ mlk_matvec_mul (& pkpv , & a , & skpv , & skpv_cache );
457+ mlk_polyvec_tomont (& pkpv );
466458
467- mlk_polyvec_add (pkpv , e );
468- mlk_polyvec_reduce (pkpv );
469- mlk_polyvec_reduce (skpv );
459+ mlk_polyvec_add (& pkpv , & e );
460+ mlk_polyvec_reduce (& pkpv );
461+ mlk_polyvec_reduce (& skpv );
470462
471- mlk_pack_sk (sk , skpv );
472- mlk_pack_pk (pk , pkpv , publicseed );
463+ mlk_pack_sk (sk , & skpv );
464+ mlk_pack_pk (pk , & pkpv , publicseed );
473465
474466cleanup :
475467 /* Specification: Partially implements
@@ -528,7 +520,7 @@ int mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
528520 */
529521 MLK_CT_TESTING_DECLASSIFY (seed , MLKEM_SYMBYTES );
530522
531- mlk_gen_matrix (at , seed , 1 /* transpose */ );
523+ mlk_gen_matrix (& at , seed , 1 /* transpose */ );
532524
533525#if MLKEM_K == 2
534526 mlk_poly_getnoise_eta1122_4x (& sp -> vec [0 ], & sp -> vec [1 ], & ep -> vec [0 ],
@@ -552,7 +544,7 @@ int mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
552544 mlk_poly_getnoise_eta2 (epp , coins , 8 );
553545#endif /* MLKEM_K == 4 */
554546
555- mlk_polyvec_ntt (sp );
547+ mlk_polyvec_ntt (& sp );
556548
557549 mlk_polyvec_mulcache_compute (sp_cache , sp );
558550 mlk_matvec_mul (b , at , sp , sp_cache );
0 commit comments