Skip to content

Commit 55f0028

Browse files
jakemas and claude committed
x86_64: Replace rej_uniform intrinsics with assembly + HOL Light proof
Replaces the x86_64 AVX2 rej_uniform implementation (previously written in C with intrinsics) with a hand-written assembly routine, and adds a functional correctness proof in HOL Light on top of the s2n-bignum infrastructure.

Highlights:
- dev/x86_64/src/rej_uniform_avx2_asm.S and mldsa/src/native/x86_64/src/rej_uniform_avx2_asm.S: new .S file exposing mld_rej_uniform_avx2_asm (replaces the intrinsics-based rej_uniform_avx2.c).
- proofs/hol_light/x86_64/mldsa/rej_uniform_avx2_asm.S and proofs/hol_light/x86_64/proofs/rej_uniform_avx2_asm.ml: HOL Light proof of MLDSA_REJ_UNIFORM_{,NOIBT_}SUBROUTINE_CORRECT, with no remaining CHEATs.
- proofs/cbmc/rej_uniform_native_x86_64/: CBMC contract proof (249/249 passing).
- CI: hol_light.yml and Makefile updated for the new bytecode dump and autogen instruction-decode format; s2n-bignum pin bumped to include the supporting tactics.

Naming follows the asm-suffix convention introduced on main (eada109 / e810d00): symbol mld_rej_uniform_avx2_asm, label prefix rej_uniform_avx2_asm_.

Signed-off-by: Jake Massimo <jakemas@amazon.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 0423f1d commit 55f0028

21 files changed

Lines changed: 4820 additions & 267 deletions

.github/workflows/hol_light.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ jobs:
203203
needs: ["mldsa_specs.ml", "mldsa_utils.ml", "mldsa_zetas.ml"]
204204
- name: intt_avx2_asm
205205
needs: ["mldsa_specs.ml", "mldsa_utils.ml", "mldsa_zetas.ml"]
206+
- name: rej_uniform_avx2_asm
207+
needs: ["mldsa_rej_uniform_table.ml"]
206208
- name: nttunpack_avx2_asm
207209
needs: ["mldsa_specs.ml"]
208210
- name: pointwise_avx2_asm

BIBLIOGRAPHY.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ source code and documentation.
234234
- [dev/x86_64/src/poly_use_hint_88_avx2.c](dev/x86_64/src/poly_use_hint_88_avx2.c)
235235
- [dev/x86_64/src/polyz_unpack_17_avx2.c](dev/x86_64/src/polyz_unpack_17_avx2.c)
236236
- [dev/x86_64/src/polyz_unpack_19_avx2.c](dev/x86_64/src/polyz_unpack_19_avx2.c)
237-
- [dev/x86_64/src/rej_uniform_avx2.c](dev/x86_64/src/rej_uniform_avx2.c)
237+
- [dev/x86_64/src/rej_uniform_avx2_asm.S](dev/x86_64/src/rej_uniform_avx2_asm.S)
238238
- [dev/x86_64/src/rej_uniform_eta2_avx2.c](dev/x86_64/src/rej_uniform_eta2_avx2.c)
239239
- [dev/x86_64/src/rej_uniform_eta4_avx2.c](dev/x86_64/src/rej_uniform_eta4_avx2.c)
240240
- [mldsa/src/native/x86_64/src/intt_avx2_asm.S](mldsa/src/native/x86_64/src/intt_avx2_asm.S)
@@ -252,7 +252,7 @@ source code and documentation.
252252
- [mldsa/src/native/x86_64/src/poly_use_hint_88_avx2.c](mldsa/src/native/x86_64/src/poly_use_hint_88_avx2.c)
253253
- [mldsa/src/native/x86_64/src/polyz_unpack_17_avx2.c](mldsa/src/native/x86_64/src/polyz_unpack_17_avx2.c)
254254
- [mldsa/src/native/x86_64/src/polyz_unpack_19_avx2.c](mldsa/src/native/x86_64/src/polyz_unpack_19_avx2.c)
255-
- [mldsa/src/native/x86_64/src/rej_uniform_avx2.c](mldsa/src/native/x86_64/src/rej_uniform_avx2.c)
255+
- [mldsa/src/native/x86_64/src/rej_uniform_avx2_asm.S](mldsa/src/native/x86_64/src/rej_uniform_avx2_asm.S)
256256
- [mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2.c](mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2.c)
257257
- [mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2.c](mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2.c)
258258
- [proofs/hol_light/x86_64/mldsa/intt_avx2_asm.S](proofs/hol_light/x86_64/mldsa/intt_avx2_asm.S)
@@ -262,6 +262,7 @@ source code and documentation.
262262
- [proofs/hol_light/x86_64/mldsa/pointwise_acc_l5_avx2_asm.S](proofs/hol_light/x86_64/mldsa/pointwise_acc_l5_avx2_asm.S)
263263
- [proofs/hol_light/x86_64/mldsa/pointwise_acc_l7_avx2_asm.S](proofs/hol_light/x86_64/mldsa/pointwise_acc_l7_avx2_asm.S)
264264
- [proofs/hol_light/x86_64/mldsa/pointwise_avx2_asm.S](proofs/hol_light/x86_64/mldsa/pointwise_avx2_asm.S)
265+
- [proofs/hol_light/x86_64/mldsa/rej_uniform_avx2_asm.S](proofs/hol_light/x86_64/mldsa/rej_uniform_avx2_asm.S)
265266

266267
### `Round3_Spec`
267268

dev/x86_64/meta.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ static MLD_INLINE int mld_rej_uniform_native(int32_t *r, unsigned len,
8080
}
8181

8282
/* Safety: outlen is at most MLDSA_N and, hence, this cast is safe. */
83-
return (int)mld_rej_uniform_avx2(r, buf);
83+
return (int)mld_rej_uniform_avx2_asm(r, buf,
84+
(const uint8_t *)mld_rej_uniform_table);
8485
}
8586

8687
#if !defined(MLD_CONFIG_NO_KEYPAIR_API)

dev/x86_64/src/arith_native_x86_64.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,21 @@ __contract__(
7373
r[i] == old(*(int32_t (*)[MLDSA_N])r)[j])))
7474
);
7575

76-
#define mld_rej_uniform_avx2 MLD_NAMESPACE(mld_rej_uniform_avx2)
76+
#define mld_rej_uniform_avx2_asm MLD_NAMESPACE(rej_uniform_avx2_asm)
77+
/* This must be kept in sync with the HOL-Light specification
78+
* in proofs/hol_light/x86_64/proofs/rej_uniform_avx2_asm.ml */
7779
MLD_MUST_CHECK_RETURN_VALUE
78-
unsigned mld_rej_uniform_avx2(int32_t *r,
79-
const uint8_t buf[MLD_AVX2_REJ_UNIFORM_BUFLEN]);
80+
unsigned mld_rej_uniform_avx2_asm(
81+
int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_BUFLEN],
82+
const uint8_t *table)
83+
__contract__(
84+
requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N))
85+
requires(memory_no_alias(buf, MLD_AVX2_REJ_UNIFORM_BUFLEN))
86+
requires(memory_no_alias(table, 256 * sizeof(uint64_t)))
87+
assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N))
88+
ensures(return_value <= MLDSA_N)
89+
ensures(array_bound(r, 0, return_value, 0, MLDSA_Q))
90+
);
8091

8192
#if !defined(MLD_CONFIG_NO_KEYPAIR_API)
8293
#define mld_rej_uniform_eta2_avx2 MLD_NAMESPACE(mld_rej_uniform_eta2_avx2)

dev/x86_64/src/rej_uniform_avx2.c

Lines changed: 0 additions & 126 deletions
This file was deleted.
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
/*
2+
* Copyright (c) The mldsa-native project authors
3+
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
4+
*/
5+
6+
/* References
7+
* ==========
8+
*
9+
* - [REF_AVX2]
10+
* CRYSTALS-Dilithium optimized AVX2 implementation
11+
* Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé
12+
* https://github.com/pq-crystals/dilithium/tree/master/avx2
13+
*/
14+
15+
/*
16+
* This file is derived from the public domain
17+
* AVX2 Dilithium implementation @[REF_AVX2].
18+
*/
19+
20+
#include "../../../common.h"
21+
#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \
22+
!defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
23+
/* simpasm: header-end */
24+
25+
#define out %rdi /* argument 1: output coefficient array r */
26+
#define in %rsi /* argument 2: input byte buffer buf */
27+
#define tab %rdx /* argument 3: rejection sampling lookup table */
28+
29+
#define ctr %eax /* coefficients written so far; doubles as the return value */
30+
#define pos %ecx /* current byte offset into the input buffer */
31+
32+
#define good %r8d /* SIMD loop: acceptance bitmask; scalar loop: candidate value; setup: scratch */
33+
#define cnt %r9d /* SIMD loop: number of accepted lanes; scalar loop: third input byte */
34+
#define tmp %r10 /* scratch for the 64-bit shuffle-index immediates */
35+
36+
#define idx8 %ymm0 /* vpshufb byte indices */
37+
#define mask %ymm1 /* 0x7FFFFF in every 32-bit lane */
38+
#define bound %ymm2 /* 8380417 (MLDSA_Q) in every 32-bit lane */
39+
#define data %ymm3 /* current batch of 8 candidate values */
40+
#define cmp_result %ymm4 /* data - bound; later reused for the permutation indices */
41+
42+
.text
43+
44+
/*
45+
* unsigned mld_rej_uniform_avx2_asm(int32_t *r, const uint8_t *buf,
46+
* const uint8_t *table)
47+
*
48+
* Rejection sampling of uniform polynomial coefficients.
49+
* Extracts 23-bit values from a byte buffer and accepts those < MLDSA_Q.
50+
*
51+
* Arguments: out (rdi): pointer to output coefficient array r (256 x int32_t)
52+
* in (rsi): pointer to input byte buffer buf (840 bytes)
53+
* tab (rdx): pointer to rejection sampling lookup table (256x8 bytes)
54+
*
55+
* Returns: ctr (eax): number of valid coefficients written to r
56+
*/
57+
.balign 4
58+
.global MLD_ASM_NAMESPACE(rej_uniform_avx2_asm)
59+
MLD_ASM_FN_SYMBOL(rej_uniform_avx2_asm)
60+
61+
/*
62+
* Construct the shuffle mask for extracting 8 x 23-bit values from 24 bytes.
63+
*
64+
* After vpermq with 0x94, the 32 loaded bytes are rearranged as:
65+
* Low 128 bits: bytes [0..15] (original 64-bit lanes 0, 1)
66+
* High 128 bits: bytes [8..23] (original 64-bit lanes 1, 2)
67+
*
68+
* vpshufb then picks 3-byte groups and zero-pads each to a 32-bit lane:
69+
* Low half: {0,1,2,FF, 3,4,5,FF, 6,7,8,FF, 9,10,11,FF}
70+
* High half: {4,5,6,FF, 7,8,9,FF, 10,11,12,FF, 13,14,15,FF}
71+
*
72+
* This extracts 8 non-overlapping 3-byte windows from the first 24 input bytes.
73+
*/
74+
movq $0xFF050403FF020100, tmp // low half, dwords 0-1: {0,1,2,FF, 3,4,5,FF}
75+
vmovq tmp, %xmm0 // %xmm0 is the low 128 bits of idx8
76+
movq $0xFF0B0A09FF080706, tmp // low half, dwords 2-3: {6,7,8,FF, 9,10,11,FF}
77+
vpinsrq $1, tmp, %xmm0, %xmm0
78+
movq $0xFF090807FF060504, tmp // high half, dwords 0-1: {4,5,6,FF, 7,8,9,FF}
79+
vmovq tmp, %xmm3 // %xmm3 (low 128 bits of data) is only a temporary here
80+
movq $0xFF0F0E0DFF0C0B0A, tmp // high half, dwords 2-3: {10,11,12,FF, 13,14,15,FF}
81+
vpinsrq $1, tmp, %xmm3, %xmm3
82+
vinserti128 $1, %xmm3, idx8, idx8 // idx8 high half = %xmm3; low half unchanged
83+
84+
// Construct broadcast constants (good is used as a scratch register here)
85+
movl $0x7FFFFF, good
86+
vmovd good, %xmm1
87+
vpbroadcastd %xmm1, mask // mask: 23-bit extraction
88+
89+
movl $8380417, good // MLDSA_Q
90+
vmovd good, %xmm2
91+
vpbroadcastd %xmm2, bound // bound: rejection threshold
92+
93+
// Initialize counters
94+
xorl ctr, ctr // ctr = 0 (also zeroes the upper half of %rax)
95+
xorl pos, pos // pos = 0 (also zeroes the upper half of %rcx)
96+
97+
/*
98+
* Main SIMD loop: process 24 input bytes into up to 8 coefficients
99+
* per iteration. Loops while ctr <= MLDSA_N - 8 and pos <= BUFLEN - 32.
100+
*/
101+
rej_uniform_avx2_asm_loop:
102+
cmpl $248, ctr // MLDSA_N - 8: leave room for a full 8-lane store
103+
ja rej_uniform_avx2_asm_scalar
104+
cmpl $808, pos // MLD_AVX2_REJ_UNIFORM_BUFLEN - 32: keep the 32-byte load in bounds
105+
ja rej_uniform_avx2_asm_scalar
106+
107+
vmovdqu (in, %rcx), data // load 32 bytes from buf[pos]; %rcx is pos zero-extended
108+
addl $24, pos // advance pos by the 24 bytes actually consumed
109+
vpermq $0x94, data, data // rearrange 64-bit lanes: [2,1,1,0] (listed high to low)
110+
vpshufb idx8, data, data // extract 8 x 3-byte groups
111+
vpand mask, data, data // mask to 23 bits
112+
113+
vpsubd bound, data, cmp_result // d - Q: negative if d < Q (valid)
114+
vmovmskps cmp_result, good // extract sign bits as 8-bit mask
115+
116+
popcntl good, cnt // count valid coefficients
117+
118+
vmovq (tab, %r8, 8), %xmm4 // load 8-byte permutation table[good]; %r8 is good zero-extended
119+
vpmovzxbd %xmm4, cmp_result // zero-extend to 8 dword indices (cmp_result is reused)
120+
vpermd data, cmp_result, data // compact valid coefficients to front
121+
122+
vmovdqu data, (out, %rax, 4) // store all 8 lanes at r[ctr]; %rax is ctr zero-extended
123+
addl cnt, ctr // ctr += popcount(good); only the first cnt stored lanes are counted
124+
125+
jmp rej_uniform_avx2_asm_loop
126+
127+
/*
128+
* Scalar fallback loop: process remaining bytes one coefficient at a time.
129+
* Loops while ctr < MLDSA_N and pos <= BUFLEN - 3.
130+
*/
131+
rej_uniform_avx2_asm_scalar:
132+
cmpl $256, ctr // MLDSA_N: stop once the output is full
133+
jae rej_uniform_avx2_asm_done
134+
cmpl $837, pos // MLD_AVX2_REJ_UNIFORM_BUFLEN - 3: need 3 readable bytes at pos
135+
ja rej_uniform_avx2_asm_done
136+
137+
movzwl (in, %rcx), good // load buf[pos] | buf[pos+1] << 8 (little-endian 16-bit load)
138+
movzbl 2(in, %rcx), cnt // load third byte buf[pos+2]; cnt is reused as scratch
139+
shll $16, cnt // move the third byte to bits 16..23
140+
orl cnt, good // good = 24-bit little-endian value
141+
andl $0x7FFFFF, good // mask to 23 bits
142+
addl $3, pos // advance pos (bytes are consumed even if the candidate is rejected)
143+
144+
cmpl $8380417, good // MLDSA_Q
145+
jae rej_uniform_avx2_asm_scalar // reject if >= Q
146+
147+
movl good, (out, %rax, 4) // store valid coefficient at r[ctr]
148+
addl $1, ctr // ctr++
149+
jmp rej_uniform_avx2_asm_scalar
150+
151+
rej_uniform_avx2_asm_done:
152+
vzeroupper // zero the upper halves of the YMM registers (presumably to avoid AVX/SSE transition penalties -- confirm)
153+
ret // return value is ctr, already in %eax
154+
155+
/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.
156+
* Don't modify by hand -- this is auto-generated by scripts/autogen. */
157+
#undef out
158+
#undef in
159+
#undef tab
160+
#undef ctr
161+
#undef pos
162+
#undef good
163+
#undef cnt
164+
#undef tmp
165+
#undef idx8
166+
#undef mask
167+
#undef bound
168+
#undef data
169+
#undef cmp_result
170+
171+
/* simpasm: footer-start */
172+
#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \
173+
*/

mldsa/mldsa_native.c

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@
9090
#include "src/native/x86_64/src/poly_use_hint_88_avx2.c"
9191
#include "src/native/x86_64/src/polyz_unpack_17_avx2.c"
9292
#include "src/native/x86_64/src/polyz_unpack_19_avx2.c"
93-
#include "src/native/x86_64/src/rej_uniform_avx2.c"
9493
#include "src/native/x86_64/src/rej_uniform_eta2_avx2.c"
9594
#include "src/native/x86_64/src/rej_uniform_eta4_avx2.c"
9695
#include "src/native/x86_64/src/rej_uniform_table.c"
@@ -758,7 +757,7 @@
758757
#undef mld_poly_use_hint_88_avx2
759758
#undef mld_polyz_unpack_17_avx2
760759
#undef mld_polyz_unpack_19_avx2
761-
#undef mld_rej_uniform_avx2
760+
#undef mld_rej_uniform_avx2_asm
762761
#undef mld_rej_uniform_eta2_avx2
763762
#undef mld_rej_uniform_eta4_avx2
764763
#undef mld_rej_uniform_table

0 commit comments

Comments
 (0)