Skip to content

Commit 6539a79

Browse files
committed
x86_64: Replace rej_uniform intrinsics with assembly
Replace the AVX2 intrinsics implementation of rej_uniform with hand-written x86_64 assembly, resolving #926. The assembly follows the same algorithmic structure as the intrinsics version: load 32 bytes, use vpermq to rearrange the 64-bit lanes, use vpshufb to extract 8 x 3-byte groups, mask each to 23 bits, compare against MLDSA_Q, then use the lookup table to compact the valid coefficients. Key design decisions: (1) the table is passed as a parameter (consistent with the aarch64 approach), avoiding external symbol references for simpasm compatibility; (2) all constants are constructed from immediates (no .rodata section), enabling future HOL-Light formal verification; (3) register names are #defines with #undef cleanup for SCU builds; (4) a CBMC contract is placed on the assembly function declaration (following mlkem-native); (5) vzeroupper is issued at function exit to avoid AVX-SSE transition penalties. Also adds poly_uniform to the component benchmark. Signed-off-by: jakemas <jakemas@amazon.com>
1 parent 9ee2f35 commit 6539a79

17 files changed

Lines changed: 788 additions & 260 deletions

BIBLIOGRAPHY.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ source code and documentation.
227227
- [dev/x86_64/src/poly_use_hint_88_avx2.c](dev/x86_64/src/poly_use_hint_88_avx2.c)
228228
- [dev/x86_64/src/polyz_unpack_17_avx2.c](dev/x86_64/src/polyz_unpack_17_avx2.c)
229229
- [dev/x86_64/src/polyz_unpack_19_avx2.c](dev/x86_64/src/polyz_unpack_19_avx2.c)
230-
- [dev/x86_64/src/rej_uniform_avx2.c](dev/x86_64/src/rej_uniform_avx2.c)
230+
- [dev/x86_64/src/rej_uniform_avx2.S](dev/x86_64/src/rej_uniform_avx2.S)
231231
- [dev/x86_64/src/rej_uniform_eta2_avx2.c](dev/x86_64/src/rej_uniform_eta2_avx2.c)
232232
- [dev/x86_64/src/rej_uniform_eta4_avx2.c](dev/x86_64/src/rej_uniform_eta4_avx2.c)
233233
- [mldsa/src/native/x86_64/src/intt.S](mldsa/src/native/x86_64/src/intt.S)
@@ -245,11 +245,12 @@ source code and documentation.
245245
- [mldsa/src/native/x86_64/src/poly_use_hint_88_avx2.c](mldsa/src/native/x86_64/src/poly_use_hint_88_avx2.c)
246246
- [mldsa/src/native/x86_64/src/polyz_unpack_17_avx2.c](mldsa/src/native/x86_64/src/polyz_unpack_17_avx2.c)
247247
- [mldsa/src/native/x86_64/src/polyz_unpack_19_avx2.c](mldsa/src/native/x86_64/src/polyz_unpack_19_avx2.c)
248-
- [mldsa/src/native/x86_64/src/rej_uniform_avx2.c](mldsa/src/native/x86_64/src/rej_uniform_avx2.c)
248+
- [mldsa/src/native/x86_64/src/rej_uniform_avx2.S](mldsa/src/native/x86_64/src/rej_uniform_avx2.S)
249249
- [mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2.c](mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2.c)
250250
- [mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2.c](mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2.c)
251251
- [proofs/hol_light/x86_64/mldsa/mldsa_intt.S](proofs/hol_light/x86_64/mldsa/mldsa_intt.S)
252252
- [proofs/hol_light/x86_64/mldsa/mldsa_ntt.S](proofs/hol_light/x86_64/mldsa/mldsa_ntt.S)
253+
- [proofs/hol_light/x86_64/mldsa/mldsa_rej_uniform.S](proofs/hol_light/x86_64/mldsa/mldsa_rej_uniform.S)
253254

254255
### `Round3_Spec`
255256

dev/x86_64/meta.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ static MLD_INLINE int mld_rej_uniform_native(int32_t *r, unsigned len,
8080
}
8181

8282
/* Safety: outlen is at most MLDSA_N and, hence, this cast is safe. */
83-
return (int)mld_rej_uniform_avx2(r, buf);
83+
return (int)mld_rej_uniform_avx2(r, buf,
84+
(const uint8_t *)mld_rej_uniform_table);
8485
}
8586

8687
#if defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || MLDSA_ETA == 2

dev/x86_64/src/arith_native_x86_64.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ void mld_nttunpack_avx2(int32_t *r);
6565
#define mld_rej_uniform_avx2 MLD_NAMESPACE(mld_rej_uniform_avx2)
6666
MLD_MUST_CHECK_RETURN_VALUE
6767
unsigned mld_rej_uniform_avx2(int32_t *r,
68-
const uint8_t buf[MLD_AVX2_REJ_UNIFORM_BUFLEN]);
68+
const uint8_t buf[MLD_AVX2_REJ_UNIFORM_BUFLEN],
69+
const uint8_t *table);
6970

7071
#define mld_rej_uniform_eta2_avx2 MLD_NAMESPACE(mld_rej_uniform_eta2_avx2)
7172
MLD_MUST_CHECK_RETURN_VALUE

dev/x86_64/src/rej_uniform_avx2.S

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
/*
2+
* Copyright (c) The mldsa-native project authors
3+
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
4+
*/
5+
6+
/* References
7+
* ==========
8+
*
9+
* - [REF_AVX2]
10+
* CRYSTALS-Dilithium optimized AVX2 implementation
11+
* Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé
12+
* https://github.com/pq-crystals/dilithium/tree/master/avx2
13+
*/
14+
15+
/*
16+
* This file is derived from the public domain
17+
* AVX2 Dilithium implementation @[REF_AVX2].
18+
*/
19+
20+
#include "../../../common.h"
21+
#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \
22+
!defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
23+
/* simpasm: header-end */
24+
25+
#define out %rdi /* arg 1: output coefficient array r */
26+
#define in %rsi /* arg 2: input byte buffer buf */
27+
#define tab %rdx /* arg 3: rejection sampling lookup table */
28+
29+
#define ctr %eax /* coefficients written so far; also the return value */
30+
#define pos %ecx /* byte offset of the next unread input in buf */
31+
32+
#define good %r8d /* validity bitmask (SIMD) or candidate value (scalar); also scratch for constants */
33+
#define cnt %r9d /* popcount of good (SIMD); scratch for the third byte (scalar) */
34+
#define tmp %r10 /* scratch for 64-bit immediates */
35+
36+
#define idx8 %ymm0 /* vpshufb control: 8 x 3-byte group -> 32-bit lane */
37+
#define mask %ymm1 /* 0x7FFFFF broadcast to all 8 lanes */
38+
#define bound %ymm2 /* MLDSA_Q broadcast to all 8 lanes */
39+
#define data %ymm3 /* loaded bytes, then the 8 candidate coefficients */
40+
#define cmp_result %ymm4 /* data - bound, then reused for the permutation indices */
41+
42+
.text
43+
44+
/*
45+
* unsigned mld_rej_uniform_avx2(int32_t *r, const uint8_t *buf,
46+
* const uint8_t *table)
47+
*
48+
* Rejection sampling of uniform polynomial coefficients.
49+
* Reads the buffer in 3-byte little-endian groups, masks each group to
* 23 bits, and keeps the values that are < MLDSA_Q (8380417).
50+
*
51+
* Arguments: out (rdi): pointer to output coefficient array r (256 x int32_t)
52+
* in (rsi): pointer to input byte buffer buf (840 bytes)
53+
* tab (rdx): pointer to rejection sampling lookup table (256x8 bytes),
* indexed by the 8-bit validity mask of one SIMD batch
54+
*
55+
* Returns: ctr (eax): number of valid coefficients written to r (at most 256)
56+
*
* Scratch: rax, rcx, r8, r9, r10, ymm0-ymm4 and the flags are overwritten.
* Uses popcnt in addition to the AVX2 instructions.
*/
57+
.balign 4
58+
.global MLD_ASM_NAMESPACE(mld_rej_uniform_avx2)
59+
MLD_ASM_FN_SYMBOL(mld_rej_uniform_avx2)
60+
61+
/*
62+
* Construct the shuffle mask for extracting 8 x 23-bit values from 24 bytes.
63+
*
64+
* After vpermq with 0x94, the 32 loaded bytes are rearranged as:
65+
* Low 128 bits: bytes [0..15] (original 64-bit lanes 0, 1)
66+
* High 128 bits: bytes [8..23] (original 64-bit lanes 1, 2)
67+
*
68+
* vpshufb then picks 3-byte groups and zero-pads each to a 32-bit lane
* (an index byte of FF has its top bit set, which makes vpshufb write 0):
69+
* Low half: {0,1,2,FF, 3,4,5,FF, 6,7,8,FF, 9,10,11,FF}
70+
* High half: {4,5,6,FF, 7,8,9,FF, 10,11,12,FF, 13,14,15,FF}
71+
*
72+
* This extracts 8 non-overlapping 3-byte windows from the first 24 input bytes.
73+
*/
74+
movq $0xFF050403FF020100, tmp // low half, dwords 0-1: {0,1,2,FF, 3,4,5,FF}
75+
vmovq tmp, %xmm0 // %xmm0 is the low 128 bits of idx8
76+
movq $0xFF0B0A09FF080706, tmp // low half, dwords 2-3: {6,7,8,FF, 9,10,11,FF}
77+
vpinsrq $1, tmp, %xmm0, %xmm0
78+
movq $0xFF090807FF060504, tmp // high half, dwords 0-1: {4,5,6,FF, 7,8,9,FF}
79+
vmovq tmp, %xmm3 // %xmm3 (low half of data) is scratch here
80+
movq $0xFF0F0E0DFF0C0B0A, tmp // high half, dwords 2-3: {10,11,12,FF, 13,14,15,FF}
81+
vpinsrq $1, tmp, %xmm3, %xmm3
82+
vinserti128 $1, %xmm3, idx8, idx8 // idx8 = high: xmm3, low: xmm0
83+
84+
// Construct broadcast constants (good is only a scratch register here)
85+
movl $0x7FFFFF, good
86+
vmovd good, %xmm1
87+
vpbroadcastd %xmm1, mask // mask: 23-bit extraction
88+
89+
movl $8380417, good // MLDSA_Q
90+
vmovd good, %xmm2
91+
vpbroadcastd %xmm2, bound // bound: rejection threshold
92+
93+
// Initialize counters
94+
xorl ctr, ctr // ctr = 0 (32-bit write also clears the upper half of %rax)
95+
xorl pos, pos // pos = 0 (32-bit write also clears the upper half of %rcx)
96+
97+
/*
98+
* Main SIMD loop: process 24 input bytes into up to 8 coefficients
99+
* per iteration. Loops while ctr <= MLDSA_N - 8 and pos <= BUFLEN - 32.
100+
*/
101+
mld_rej_uniform_avx2_loop:
102+
cmpl $248, ctr // MLDSA_N - 8: room left for a full 8-lane store
103+
ja mld_rej_uniform_avx2_scalar
104+
cmpl $808, pos // MLD_AVX2_REJ_UNIFORM_BUFLEN - 32: room left for a 32-byte load
105+
ja mld_rej_uniform_avx2_scalar
106+
107+
vmovdqu (in, %rcx), data // load 32 bytes from buf[pos]; only the first 24 are consumed
108+
addl $24, pos // advance pos by the 24 bytes consumed
109+
vpermq $0x94, data, data // rearrange 64-bit lanes: [2,1,1,0]
110+
vpshufb idx8, data, data // extract 8 x 3-byte groups
111+
vpand mask, data, data // mask to 23 bits
112+
113+
vpsubd bound, data, cmp_result // d - Q: negative if d < Q (valid)
114+
vmovmskps cmp_result, good // extract sign bits as 8-bit mask: bit i set iff lane i is valid
115+
116+
popcntl good, cnt // count valid coefficients
117+
118+
vmovq (tab, %r8, 8), %xmm4 // load 8-byte permutation from table[good] (%r8 is good, zero-extended)
119+
vpmovzxbd %xmm4, cmp_result // zero-extend to 8 dword indices
120+
vpermd data, cmp_result, data // compact valid coefficients to front
121+
122+
vmovdqu data, (out, %rax, 4) // store all 8 lanes at r[ctr]; lanes past cnt lie beyond the new ctr
123+
addl cnt, ctr // ctr += popcount(good)
124+
125+
jmp mld_rej_uniform_avx2_loop
126+
127+
/*
128+
* Scalar fallback loop: process remaining bytes one coefficient at a time.
129+
* Loops while ctr < MLDSA_N and pos <= BUFLEN - 3.
130+
*/
131+
mld_rej_uniform_avx2_scalar:
132+
cmpl $256, ctr // MLDSA_N
133+
jae mld_rej_uniform_avx2_done
134+
cmpl $837, pos // MLD_AVX2_REJ_UNIFORM_BUFLEN - 3
135+
ja mld_rej_uniform_avx2_done
136+
137+
movzwl (in, %rcx), good // load buf[pos] and buf[pos+1] as a little-endian 16-bit value
138+
movzbl 2(in, %rcx), cnt // load third byte buf[pos+2] (cnt is scratch here)
139+
shll $16, cnt // move the third byte to bits 16..23
140+
orl cnt, good // good = 24-bit little-endian value
141+
andl $0x7FFFFF, good // mask to 23 bits
142+
addl $3, pos // advance pos
143+
144+
cmpl $8380417, good // MLDSA_Q
145+
jae mld_rej_uniform_avx2_scalar // reject if >= Q (pos has already been advanced)
146+
147+
movl good, (out, %rax, 4) // store valid coefficient
148+
addl $1, ctr // ctr++
149+
jmp mld_rej_uniform_avx2_scalar
150+
151+
mld_rej_uniform_avx2_done:
152+
vzeroupper // clear upper ymm halves to avoid AVX-SSE transition penalties
153+
ret // return value is ctr, already in %eax
154+
155+
/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.
156+
* Don't modify by hand -- this is auto-generated by scripts/autogen. */
157+
#undef out
158+
#undef in
159+
#undef tab
160+
#undef ctr
161+
#undef pos
162+
#undef good
163+
#undef cnt
164+
#undef tmp
165+
#undef idx8
166+
#undef mask
167+
#undef bound
168+
#undef data
169+
#undef cmp_result
170+
171+
/* simpasm: footer-start */
172+
#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \
173+
*/

dev/x86_64/src/rej_uniform_avx2.c

Lines changed: 0 additions & 126 deletions
This file was deleted.

mldsa/mldsa_native.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@
9090
#include "src/native/x86_64/src/poly_use_hint_88_avx2.c"
9191
#include "src/native/x86_64/src/polyz_unpack_17_avx2.c"
9292
#include "src/native/x86_64/src/polyz_unpack_19_avx2.c"
93-
#include "src/native/x86_64/src/rej_uniform_avx2.c"
9493
#include "src/native/x86_64/src/rej_uniform_eta2_avx2.c"
9594
#include "src/native/x86_64/src/rej_uniform_eta4_avx2.c"
9695
#include "src/native/x86_64/src/rej_uniform_table.c"

mldsa/mldsa_native_asm.S

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
#include "src/native/x86_64/src/pointwise_acc_l4.S"
8787
#include "src/native/x86_64/src/pointwise_acc_l5.S"
8888
#include "src/native/x86_64/src/pointwise_acc_l7.S"
89+
#include "src/native/x86_64/src/rej_uniform_avx2.S"
8990
#endif /* MLD_SYS_X86_64 */
9091
#endif /* MLD_CONFIG_USE_NATIVE_BACKEND_ARITH */
9192

mldsa/src/native/x86_64/meta.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ static MLD_INLINE int mld_rej_uniform_native(int32_t *r, unsigned len,
8080
}
8181

8282
/* Safety: outlen is at most MLDSA_N and, hence, this cast is safe. */
83-
return (int)mld_rej_uniform_avx2(r, buf);
83+
return (int)mld_rej_uniform_avx2(r, buf,
84+
(const uint8_t *)mld_rej_uniform_table);
8485
}
8586

8687
#if defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || MLDSA_ETA == 2

mldsa/src/native/x86_64/src/arith_native_x86_64.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ void mld_nttunpack_avx2(int32_t *r);
6565
#define mld_rej_uniform_avx2 MLD_NAMESPACE(mld_rej_uniform_avx2)
6666
MLD_MUST_CHECK_RETURN_VALUE
6767
unsigned mld_rej_uniform_avx2(int32_t *r,
68-
const uint8_t buf[MLD_AVX2_REJ_UNIFORM_BUFLEN]);
68+
const uint8_t buf[MLD_AVX2_REJ_UNIFORM_BUFLEN],
69+
const uint8_t *table);
6970

7071
#define mld_rej_uniform_eta2_avx2 MLD_NAMESPACE(mld_rej_uniform_eta2_avx2)
7172
MLD_MUST_CHECK_RETURN_VALUE

0 commit comments

Comments
 (0)