|
| 1 | +/* |
| 2 | + * Copyright (c) The mldsa-native project authors |
| 3 | + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT |
| 4 | + */ |
| 5 | + |
| 6 | +/* References |
| 7 | + * ========== |
| 8 | + * |
| 9 | + * - [REF_AVX2] |
| 10 | + * CRYSTALS-Dilithium optimized AVX2 implementation |
| 11 | + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé |
| 12 | + * https://github.com/pq-crystals/dilithium/tree/master/avx2 |
| 13 | + */ |
| 14 | + |
| 15 | +/* |
| 16 | + * This file is derived from the public domain |
| 17 | + * AVX2 Dilithium implementation @[REF_AVX2]. |
| 18 | + */ |
| 19 | + |
| 20 | +#include "../../../common.h" |
| 21 | +#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ |
| 22 | + !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) |
| 23 | +/* simpasm: header-end */ |
| 24 | + |
| 25 | +#define out %rdi |
| 26 | +#define in %rsi |
| 27 | +#define tab %rdx |
| 28 | + |
| 29 | +#define ctr %eax |
| 30 | +#define pos %ecx |
| 31 | + |
| 32 | +#define good %r8d |
| 33 | +#define cnt %r9d |
| 34 | +#define tmp %r10 |
| 35 | + |
| 36 | +#define idx8 %ymm0 |
| 37 | +#define mask %ymm1 |
| 38 | +#define bound %ymm2 |
| 39 | +#define data %ymm3 |
| 40 | +#define cmp_result %ymm4 |
| 41 | + |
| 42 | + .text |
| 43 | + |
| 44 | +/* |
| 45 | + * unsigned mld_rej_uniform_avx2(int32_t *r, const uint8_t *buf, |
| 46 | + * const uint8_t *table) |
| 47 | + * |
| 48 | + * Rejection sampling of uniform polynomial coefficients. |
| 49 | + * Extracts 23-bit values from a byte buffer and accepts those < MLDSA_Q. |
| 50 | + * |
| 51 | + * Arguments: out (rdi): pointer to output coefficient array r (256 x int32_t) |
| 52 | + * in (rsi): pointer to input byte buffer buf (840 bytes) |
| 53 | + * tab (rdx): pointer to rejection sampling lookup table (256x8 bytes) |
| 54 | + * |
| 55 | + * Returns: ctr (eax): number of valid coefficients written to r |
| 56 | + */ |
| 57 | + .balign 4 |
| 58 | + .global MLD_ASM_NAMESPACE(mld_rej_uniform_avx2) |
| 59 | +MLD_ASM_FN_SYMBOL(mld_rej_uniform_avx2) |
| 60 | + |
| 61 | +/* |
| 62 | + * Construct idx8, the vpshufb mask that gathers 8 x 3-byte groups from 24 bytes (the top bit of each group is cleared later by vpand with mask). |
| 63 | + * |
| 64 | + * After vpermq with 0x94, the 32 loaded bytes are rearranged as: |
| 65 | + * Low 128 bits: bytes [0..15] (original 64-bit lanes 0, 1) |
| 66 | + * High 128 bits: bytes [8..23] (original 64-bit lanes 1, 2) |
| 67 | + * |
| 68 | + * vpshufb then picks 3-byte groups and zero-pads each to a 32-bit lane (index 0xFF yields a zero byte; indices are relative to each 128-bit half): |
| 69 | + * Low half: {0,1,2,FF, 3,4,5,FF, 6,7,8,FF, 9,10,11,FF} -> original bytes 0..11 |
| 70 | + * High half: {4,5,6,FF, 7,8,9,FF, 10,11,12,FF, 13,14,15,FF} -> original bytes 12..23 |
| 71 | + * |
| 72 | + * This extracts 8 non-overlapping 3-byte windows from the first 24 input bytes. |
| 73 | + */ |
| 74 | + movq $0xFF050403FF020100, tmp |
| 75 | + vmovq tmp, %xmm0 |
| 76 | + movq $0xFF0B0A09FF080706, tmp |
| 77 | + vpinsrq $1, tmp, %xmm0, %xmm0 |
| 78 | + movq $0xFF090807FF060504, tmp |
| 79 | + vmovq tmp, %xmm3 /* %xmm3 (low half of data) is only scratch here */ |
| 80 | + movq $0xFF0F0E0DFF0C0B0A, tmp |
| 81 | + vpinsrq $1, tmp, %xmm3, %xmm3 |
| 82 | + vinserti128 $1, %xmm3, idx8, idx8 /* idx8 = {low: %xmm0, high: %xmm3} */ |
| 83 | + |
| 84 | +// Construct broadcast constants (good serves only as a scratch GPR here) |
| 85 | + movl $0x7FFFFF, good |
| 86 | + vmovd good, %xmm1 |
| 87 | + vpbroadcastd %xmm1, mask // mask: 0x7FFFFF in all 8 dword lanes (keeps the low 23 bits) |
| 88 | + |
| 89 | + movl $8380417, good // MLDSA_Q |
| 90 | + vmovd good, %xmm2 |
| 91 | + vpbroadcastd %xmm2, bound // bound: MLDSA_Q in all 8 dword lanes (rejection threshold) |
| 92 | + |
| 93 | +// Initialize counters (32-bit writes also zero the upper halves of %rax/%rcx, which index memory below) |
| 94 | + xorl ctr, ctr // ctr = 0 |
| 95 | + xorl pos, pos // pos = 0 |
| 96 | + |
| 97 | +/* |
| 98 | + * Main SIMD loop: process 24 input bytes into up to 8 coefficients |
| 99 | + * per iteration. Loops while ctr <= MLDSA_N - 8 and pos <= BUFLEN - 32. |
| 100 | + */ |
| 101 | +mld_rej_uniform_avx2_loop: |
| 102 | + cmpl $248, ctr // MLDSA_N - 8: room for a full 8-lane store at r[ctr] |
| 103 | + ja mld_rej_uniform_avx2_scalar |
| 104 | + cmpl $808, pos // MLD_AVX2_REJ_UNIFORM_BUFLEN - 32: room for a full 32-byte load at buf[pos] |
| 105 | + ja mld_rej_uniform_avx2_scalar |
| 106 | + |
| 107 | + vmovdqu (in, %rcx), data // load 32 bytes from buf[pos] (%rcx = pos); only the first 24 are used |
| 108 | + addl $24, pos // pos += 24 (bytes consumed this iteration) |
| 109 | + vpermq $0x94, data, data // rearrange 64-bit lanes: [2,1,1,0] |
| 110 | + vpshufb idx8, data, data // extract 8 x 3-byte groups |
| 111 | + vpand mask, data, data // mask to 23 bits |
| 112 | + |
| 113 | + vpsubd bound, data, cmp_result // d - Q: negative if d < Q (valid); no wrap since d < 2^23 |
| 114 | + vmovmskps cmp_result, good // good = sign bits of the 8 lanes (bit i set <=> lane i < Q) |
| 115 | + |
| 116 | + popcntl good, cnt // count valid coefficients |
| 117 | + |
| 118 | + vmovq (tab, %r8, 8), %xmm4 // load 8-byte entry tab[good] (%r8 = good; %xmm4 = low half of cmp_result) |
| 119 | + vpmovzxbd %xmm4, cmp_result // zero-extend to 8 dword indices |
| 120 | + vpermd data, cmp_result, data // compact valid coefficients to front |
| 121 | + |
| 122 | + vmovdqu data, (out, %rax, 4) // store all 8 lanes at r[ctr] (%rax = ctr); lanes past cnt are overwritten by later stores |
| 123 | + addl cnt, ctr // ctr += popcount(good) |
| 124 | + |
| 125 | + jmp mld_rej_uniform_avx2_loop |
| 126 | + |
| 127 | +/* |
| 128 | + * Scalar fallback loop: process remaining bytes one coefficient at a time. |
| 129 | + * Loops while ctr < MLDSA_N and pos <= BUFLEN - 3. |
| 130 | + */ |
| 131 | +mld_rej_uniform_avx2_scalar: |
| 132 | + cmpl $256, ctr // MLDSA_N |
| 133 | + jae mld_rej_uniform_avx2_done |
| 134 | + cmpl $837, pos // MLD_AVX2_REJ_UNIFORM_BUFLEN - 3 |
| 135 | + ja mld_rej_uniform_avx2_done |
| 136 | + |
| 137 | + movzwl (in, %rcx), good // good = buf[pos] | buf[pos+1] << 8 (zero-extended 16-bit load) |
| 138 | + movzbl 2(in, %rcx), cnt // cnt = buf[pos+2] (cnt reused as scratch) |
| 139 | + shll $16, cnt |
| 140 | + orl cnt, good |
| 141 | + andl $0x7FFFFF, good // mask to 23 bits |
| 142 | + addl $3, pos // advance pos |
| 143 | + |
| 144 | + cmpl $8380417, good // MLDSA_Q |
| 145 | + jae mld_rej_uniform_avx2_scalar // reject if >= Q (unsigned); pos has already advanced |
| 146 | + |
| 147 | + movl good, (out, %rax, 4) // store valid coefficient |
| 148 | + addl $1, ctr // ctr++ |
| 149 | + jmp mld_rej_uniform_avx2_scalar |
| 150 | + |
| 151 | +mld_rej_uniform_avx2_done: |
| 152 | + vzeroupper /* clear upper YMM halves before returning to non-AVX code */ |
| 153 | + ret /* result: ctr in %eax */ |
| 154 | + |
| 155 | +/* To facilitate single-compilation-unit (SCU) builds, undefine all macros. |
| 156 | + * Don't modify by hand -- this is auto-generated by scripts/autogen. */ |
| 157 | +#undef out |
| 158 | +#undef in |
| 159 | +#undef tab |
| 160 | +#undef ctr |
| 161 | +#undef pos |
| 162 | +#undef good |
| 163 | +#undef cnt |
| 164 | +#undef tmp |
| 165 | +#undef idx8 |
| 166 | +#undef mask |
| 167 | +#undef bound |
| 168 | +#undef data |
| 169 | +#undef cmp_result |
| 170 | + |
| 171 | +/* simpasm: footer-start */ |
| 172 | +#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \ |
| 173 | + */ |
0 commit comments