Skip to content

Commit 5298417

Browse files
hanno-beckerrod-chapman
authored andcommitted
Switch mlk_polyvec and mlk_polymat to struct wrappers
- Change mlk_polyvec back to struct `{ mlk_poly vec[MLKEM_K]; }` - Change mlk_polymat to struct `{ mlk_polyvec vec[MLKEM_K]; }` - Update all function signatures to use pointer style - Fix all implementations to use struct member access - Update tests, benchmarks, and CBMC harnesses - Add consistent const annotations Signed-off-by: Hanno Becker <beckphan@amazon.co.uk>
1 parent ad2e3b7 commit 5298417

20 files changed

Lines changed: 125 additions & 133 deletions

File tree

mlkem/src/indcpa.c

Lines changed: 45 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@
6161
* Implements @[FIPS203, Algorithm 13 (K-PKE.KeyGen), L19]
6262
*
6363
**************************************************/
64-
static void mlk_pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], mlk_polyvec pk,
64+
static void mlk_pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES],
65+
const mlk_polyvec *pk,
6566
const uint8_t seed[MLKEM_SYMBYTES])
6667
{
6768
mlk_assert_bound_2d(pk->vec, MLKEM_K, MLKEM_N, 0, MLKEM_Q);
@@ -85,7 +86,7 @@ static void mlk_pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], mlk_polyvec pk,
8586
* Implements @[FIPS203, Algorithm 14 (K-PKE.Encrypt), L2-3]
8687
*
8788
**************************************************/
88-
static void mlk_unpack_pk(mlk_polyvec pk, uint8_t seed[MLKEM_SYMBYTES],
89+
static void mlk_unpack_pk(mlk_polyvec *pk, uint8_t seed[MLKEM_SYMBYTES],
8990
const uint8_t packedpk[MLKEM_INDCPA_PUBLICKEYBYTES])
9091
{
9192
mlk_polyvec_frombytes(pk, packedpk);
@@ -110,7 +111,8 @@ static void mlk_unpack_pk(mlk_polyvec pk, uint8_t seed[MLKEM_SYMBYTES],
110111
* Implements @[FIPS203, Algorithm 13 (K-PKE.KeyGen), L20]
111112
*
112113
**************************************************/
113-
static void mlk_pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], mlk_polyvec sk)
114+
static void mlk_pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES],
115+
const mlk_polyvec *sk)
114116
{
115117
mlk_assert_bound_2d(sk->vec, MLKEM_K, MLKEM_N, 0, MLKEM_Q);
116118
mlk_polyvec_tobytes(r, sk);
@@ -130,7 +132,7 @@ static void mlk_pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], mlk_polyvec sk)
130132
* Implements @[FIPS203, Algorithm 15 (K-PKE.Decrypt), L5]
131133
*
132134
**************************************************/
133-
static void mlk_unpack_sk(mlk_polyvec sk,
135+
static void mlk_unpack_sk(mlk_polyvec *sk,
134136
const uint8_t packedsk[MLKEM_INDCPA_SECRETKEYBYTES])
135137
{
136138
mlk_polyvec_frombytes(sk, packedsk);
@@ -151,8 +153,8 @@ static void mlk_unpack_sk(mlk_polyvec sk,
151153
* Implements @[FIPS203, Algorithm 14 (K-PKE.Encrypt), L22-23]
152154
*
153155
**************************************************/
154-
static void mlk_pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES], mlk_polyvec b,
155-
mlk_poly *v)
156+
static void mlk_pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES],
157+
const mlk_polyvec *b, mlk_poly *v)
156158
{
157159
mlk_polyvec_compress_du(r, b);
158160
mlk_poly_compress_dv(r + MLKEM_POLYVECCOMPRESSEDBYTES_DU, v);
@@ -172,7 +174,7 @@ static void mlk_pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES], mlk_polyvec b,
172174
* Implements @[FIPS203, Algorithm 15 (K-PKE.Decrypt), L1-4]
173175
*
174176
**************************************************/
175-
static void mlk_unpack_ciphertext(mlk_polyvec b, mlk_poly *v,
177+
static void mlk_unpack_ciphertext(mlk_polyvec *b, mlk_poly *v,
176178
const uint8_t c[MLKEM_INDCPA_BYTES])
177179
{
178180
mlk_polyvec_decompress_du(b, c);
@@ -244,7 +246,7 @@ __contract__(
244246
*
245247
* Not static for benchmarking */
246248
MLK_INTERNAL_API
247-
void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
249+
void mlk_gen_matrix(mlk_polymat *a, const uint8_t seed[MLKEM_SYMBYTES],
248250
int transposed)
249251
{
250252
unsigned i, j;
@@ -277,7 +279,11 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
277279
}
278280
}
279281

280-
mlk_poly_rej_uniform_x4(&a[i], &a[i + 1], &a[i + 2], &a[i + 3], seed_ext);
282+
mlk_poly_rej_uniform_x4(&a->vec[i / MLKEM_K].vec[i % MLKEM_K],
283+
&a->vec[(i + 1) / MLKEM_K].vec[(i + 1) % MLKEM_K],
284+
&a->vec[(i + 2) / MLKEM_K].vec[(i + 2) % MLKEM_K],
285+
&a->vec[(i + 3) / MLKEM_K].vec[(i + 3) % MLKEM_K],
286+
seed_ext);
281287
}
282288
#else /* !MLK_CONFIG_SERIAL_FIPS202_ONLY */
283289
/* When using serial FIPS202, sample all entries individually. */
@@ -305,7 +311,7 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
305311
seed_ext[0][MLKEM_SYMBYTES + 1] = x;
306312
}
307313

308-
mlk_poly_rej_uniform(&a[i], seed_ext[0]);
314+
mlk_poly_rej_uniform(&a->vec[i / MLKEM_K].vec[i % MLKEM_K], seed_ext[0]);
309315
}
310316

311317
mlk_assert(i == MLKEM_K * MLKEM_K);
@@ -340,15 +346,16 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
340346
* Specification: Implements @[FIPS203, Section 2.4.7, Eq (2.12), (2.13)]
341347
*
342348
**************************************************/
343-
static void mlk_matvec_mul(mlk_polyvec out, const mlk_polymat a,
344-
const mlk_polyvec v, const mlk_polyvec_mulcache vc)
349+
static void mlk_matvec_mul(mlk_polyvec *out, const mlk_polymat *a,
350+
const mlk_polyvec *v, const mlk_polyvec_mulcache *vc)
345351
__contract__(
346352
requires(memory_no_alias(out, sizeof(mlk_polyvec)))
347353
requires(memory_no_alias(a, sizeof(mlk_polymat)))
348354
requires(memory_no_alias(v, sizeof(mlk_polyvec)))
349355
requires(memory_no_alias(vc, sizeof(mlk_polyvec_mulcache)))
350-
requires(forall(k0, 0, MLKEM_K * MLKEM_K,
351-
array_bound(a[k0].coeffs, 0, MLKEM_N, 0, MLKEM_UINT12_LIMIT)))
356+
requires(forall(k0, 0, MLKEM_K,
357+
forall(k1, 0, MLKEM_K,
358+
array_bound(a->vec[k0].vec[k1].coeffs, 0, MLKEM_N, 0, MLKEM_UINT12_LIMIT))))
352359
requires(forall(k1, 0, MLKEM_K,
353360
array_abs_bound(v[k1].coeffs, 0, MLKEM_N, MLK_NTT_BOUND)))
354361
requires(forall(k2, 0, MLKEM_K,
@@ -357,32 +364,17 @@ __contract__(
357364
ensures(forall(k3, 0, MLKEM_K,
358365
array_abs_bound(out[k3].coeffs, 0, MLKEM_N, INT16_MAX/2))))
359366
{
360-
/* Temporary on the "refine-bounds" branch - unroll to a simple
361-
* sequence of calls for each possible value of MLKEM_K to
362-
* simplify proof.
363-
*/
364-
mlk_polyvec_basemul_acc_montgomery_cached(&out[0], &a[0], v, vc);
365-
mlk_polyvec_basemul_acc_montgomery_cached(&out[1], &a[MLKEM_K], v, vc);
366-
367-
#if MLKEM_K == 3
368-
mlk_polyvec_basemul_acc_montgomery_cached(&out[2], &a[MLKEM_K * 2], v, vc);
369-
#elif MLKEM_K == 4
370-
mlk_polyvec_basemul_acc_montgomery_cached(&out[2], &a[MLKEM_K * 2], v, vc);
371-
mlk_polyvec_basemul_acc_montgomery_cached(&out[3], &a[MLKEM_K * 3], v, vc);
372-
#endif
373-
374-
/* unsigned i;
375-
* for (i = 0; i < MLKEM_K; i++)
376-
* __loop__(
377-
* assigns(i, object_whole(out))
378-
* invariant(i <= MLKEM_K)
379-
* invariant(forall(k, 0, i,
380-
* array_abs_bound(out[k].coeffs, 0, MLKEM_N, INT16_MAX/2))))
381-
* {
382-
* mlk_polyvec_basemul_acc_montgomery_cached(&out[i], &a[MLKEM_K * i], v,
383-
* vc);
384-
* }
385-
*/
367+
unsigned i;
368+
for (i = 0; i < MLKEM_K; i++)
369+
__loop__(
370+
assigns(i, object_whole(out))
371+
invariant(i <= MLKEM_K))
372+
invariant(forall(k, 0, i,
373+
array_abs_bound(out[k].coeffs, 0, MLKEM_N, INT16_MAX/2))))
374+
{
375+
mlk_polyvec_basemul_acc_montgomery_cached(&out->vec[i], &a->vec[i], v,
376+
vc);
377+
}
386378
}
387379

388380
/* Reference: `indcpa_keypair_derand()` in the reference implementation @[REF].
@@ -433,7 +425,7 @@ int mlk_indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
433425
*/
434426
MLK_CT_TESTING_DECLASSIFY(publicseed, MLKEM_SYMBYTES);
435427

436-
mlk_gen_matrix(a, publicseed, 0 /* no transpose */);
428+
mlk_gen_matrix(&a, publicseed, 0 /* no transpose */);
437429

438430
#if MLKEM_K == 2
439431
mlk_poly_getnoise_eta1_4x(&skpv->vec[0], &skpv->vec[1], &e->vec[0],
@@ -457,19 +449,19 @@ int mlk_indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
457449
noiseseed, 4, 5, 6, 7);
458450
#endif /* MLKEM_K == 4 */
459451

460-
mlk_polyvec_ntt(skpv);
461-
mlk_polyvec_ntt(e);
452+
mlk_polyvec_ntt(&skpv);
453+
mlk_polyvec_ntt(&e);
462454

463-
mlk_polyvec_mulcache_compute(skpv_cache, skpv);
464-
mlk_matvec_mul(pkpv, a, skpv, skpv_cache);
465-
mlk_polyvec_tomont(pkpv);
455+
mlk_polyvec_mulcache_compute(&skpv_cache, &skpv);
456+
mlk_matvec_mul(&pkpv, &a, &skpv, &skpv_cache);
457+
mlk_polyvec_tomont(&pkpv);
466458

467-
mlk_polyvec_add(pkpv, e);
468-
mlk_polyvec_reduce(pkpv);
469-
mlk_polyvec_reduce(skpv);
459+
mlk_polyvec_add(&pkpv, &e);
460+
mlk_polyvec_reduce(&pkpv);
461+
mlk_polyvec_reduce(&skpv);
470462

471-
mlk_pack_sk(sk, skpv);
472-
mlk_pack_pk(pk, pkpv, publicseed);
463+
mlk_pack_sk(sk, &skpv);
464+
mlk_pack_pk(pk, &pkpv, publicseed);
473465

474466
cleanup:
475467
/* Specification: Partially implements
@@ -528,7 +520,7 @@ int mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
528520
*/
529521
MLK_CT_TESTING_DECLASSIFY(seed, MLKEM_SYMBYTES);
530522

531-
mlk_gen_matrix(at, seed, 1 /* transpose */);
523+
mlk_gen_matrix(&at, seed, 1 /* transpose */);
532524

533525
#if MLKEM_K == 2
534526
mlk_poly_getnoise_eta1122_4x(&sp->vec[0], &sp->vec[1], &ep->vec[0],
@@ -552,7 +544,7 @@ int mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
552544
mlk_poly_getnoise_eta2(epp, coins, 8);
553545
#endif /* MLKEM_K == 4 */
554546

555-
mlk_polyvec_ntt(sp);
547+
mlk_polyvec_ntt(&sp);
556548

557549
mlk_polyvec_mulcache_compute(sp_cache, sp);
558550
mlk_matvec_mul(b, at, sp, sp_cache);

mlkem/src/indcpa.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
*
4040
**************************************************/
4141
MLK_INTERNAL_API
42-
void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
42+
void mlk_gen_matrix(mlk_polymat *a, const uint8_t seed[MLKEM_SYMBYTES],
4343
int transposed)
4444
__contract__(
4545
requires(memory_no_alias(a, sizeof(mlk_polymat)))

mlkem/src/kem.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ int crypto_kem_check_pk(const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES])
5353
goto cleanup;
5454
}
5555

56-
mlk_polyvec_frombytes(p, pk);
57-
mlk_polyvec_reduce(p);
58-
mlk_polyvec_tobytes(p_reencoded, p);
56+
mlk_polyvec_frombytes(&p, pk);
57+
mlk_polyvec_reduce(&p);
58+
mlk_polyvec_tobytes(p_reencoded, &p);
5959

6060
/* We use a constant-time memcmp here to avoid having to
6161
* declassify the PK before the PCT has succeeded. */

0 commit comments

Comments
 (0)