2424 !defined(MLK_CONFIG_MULTILEVEL_NO_SHARED)
2525/* simpasm: header-end */
2626
27- #define in %rsi
28- #define out %rdi
29- #define len %rdx
30- #define tab %rcx
31-
32- #define cnt %rax
33- #define ecnt %eax
34- #define pos %r8
35-
36- #define good %r11
37- #define pext_mask %r9
38- #define table_idx %r10
39-
40- #define bound %xmm0
41- #define temp0 %xmm1
42- #define temp1 %xmm3
43- #define vals %xmm2
44- #define shuffle_out_mask %xmm3
45- #define shuffle_in_mask %xmm4
46- #define and_mask %xmm5
27+ #define MLK_IN %rsi // base address of the 16-byte input loads
28+ #define MLK_OUT %rdi // not named in the code shown; %rdi is the implicit destination of the final rep movsb
29+ #define MLK_LEN %rdx // the loop repeats while MLK_LEN > MLK_POS (unsigned compare)
30+ #define MLK_TAB %rcx // base address of the 16-byte shuffle masks, indexed by MLK_TABLE_IDX
31+
32+ #define MLK_CNT %rax // running count of accepted values; %rax also stages constants before the loop
33+ #define MLK_ECNT %eax // low 32 bits of the MLK_CNT register
34+ #define MLK_POS %r8 // byte offset into the input, advanced by 12 per iteration
35+
36+ #define MLK_GOOD %r11 // 8-bit acceptance mask for the current block, later overwritten by its popcount
37+ #define MLK_PEXT_MASK %r9 // constant 0x5555: keeps one bit per 16-bit lane of the pmovmskb result
38+ #define MLK_TABLE_IDX %r10 // MLK_GOOD shifted left by 4, i.e. a byte offset into MLK_TAB
39+
40+ #define MLK_BOUND %xmm0 // 0x0D01 (3329) in every 16-bit lane
41+ #define MLK_TEMP0 %xmm1 // receives the pcmpgtw comparison result
42+ #define MLK_TEMP1 %xmm3 // copy of MLK_VALS shifted right by 4; shares %xmm3 with MLK_SHUFFLE_OUT_MASK
43+ #define MLK_VALS %xmm2 // the eight candidate values of the current block
44+ #define MLK_SHUFFLE_OUT_MASK %xmm3 // loaded from MLK_TAB only after MLK_TEMP1 was last read, so the alias is safe
45+ #define MLK_SHUFFLE_IN_MASK %xmm4 // constant byte shuffle spreading 12 input bytes over eight 16-bit lanes
46+ #define MLK_AND_MASK %xmm5 // 0x0FFF in every 16-bit lane
4747
4848// High level overview of the algorithm:
4949// For every 96 bits (12 bytes) of the input:
5050// 1. Split 96 bits into eight 12-bit integers where each integer
51- // occupies a corresponding 16-bit element of `vals ` xmm register,
52- // 2. Compute an 8-bit value `good ` such that
53- // good [i] = vals [i] < MLKEM_Q ? 1 : 0, for i in [0, 7],
54- // 3. Shuffle the elements in `vals ` such that all good elements
51+ // occupies a corresponding 16-bit element of `MLK_VALS ` xmm register,
52+ // 2. Compute an 8-bit value `MLK_GOOD ` such that
53+ // MLK_GOOD [i] = MLK_VALS [i] < MLKEM_Q ? 1 : 0, for i in [0, 7],
54+ // 3. Shuffle the elements in `MLK_VALS ` such that all good elements
5555// are ordered consecutively, and store them.
5656//
5757// Notes:
58- // - We exit early if we find the required number of good values,
58+ // - We exit early if we find the required number of good values,
5959// - We use the stack as a temporary storage and copy to the actual
6060// output buffer only in the end. This is because the algorithm
6161// can overwrite up to 14 bytes (we use 16B for alignment),
@@ -74,77 +74,82 @@ MLK_ASM_FN_SYMBOL(rej_uniform_asm)
7474 testq len, len
7575 jz rej_uniform_asm_end
7676
77+ // Return if input length is 0
78+ xorl ecnt, ecnt
79+ testq len, len
80+ jz rej_uniform_asm_end
81+
7883 // Broadcast MLKEM_Q (3329) to all 16-bit elements of MLK_BOUND reg.
7883 movq $0x0D010D010D010D01 , %rax
79- movq %rax , bound
80- pinsrq $1 , %rax , bound
84+ movq %rax , MLK_BOUND
85+ pinsrq $1 , %rax , MLK_BOUND
8186
82- // Broadcast 12-bit mask 0xFFF to all 16-bit elements of bound reg.
87+ // Broadcast 12-bit mask 0xFFF to all 16-bit elements of MLK_AND_MASK reg.
8388 movq $0x0FFF0FFF0FFF0FFF , %rax
84- movq %rax , and_mask
85- pinsrq $1 , %rax , and_mask
89+ movq %rax , MLK_AND_MASK
90+ pinsrq $1 , %rax , MLK_AND_MASK
8691
8792 // Load shuffle mask:
8893 // 0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11.
8994 movq $0x0504040302010100 , %rax
90- movq %rax , shuffle_in_mask
95+ movq %rax , MLK_SHUFFLE_IN_MASK
9196 movq $0x0B0A0A0908070706 , %rax
92- pinsrq $1 , %rax , shuffle_in_mask
97+ pinsrq $1 , %rax , MLK_SHUFFLE_IN_MASK
9398
94- movq $0 , cnt // cnt counts the number of good values we've found.
95- movq $0 , pos // pos is the current position in the input buffer.
96- movq $0x5555 , pext_mask // 0x5555 mask to extract every second bit.
99+ movq $0 , MLK_CNT // MLK_CNT counts the number of MLK_GOOD values we've found.
100+ movq $0 , MLK_POS // MLK_POS is the current position in the input buffer.
101+ movq $0x5555 , MLK_PEXT_MASK // 0x5555 mask to extract every second bit.
97102
98103rej_uniform_asm_loop_start:
99104 // 1. Split 96 bits into eight 12-bit integers, one per 16-bit element.
100- // We explain the algorithm by considering the lowest 64 bits of vals .
101- movdqu (in , pos ), vals
102- // vals : [ 63..48 | 47..32 | 31..16 | 15..0 ]
103- pshufb shuffle_in_mask, vals
104- // vals : [ 47..32 | 39..24 | 23..8 | 15..0 ]
105- movdqa vals, temp1
105+ // We explain the algorithm by considering the lowest 64 bits of MLK_VALS .
106+ movdqu (MLK_IN, MLK_POS ), MLK_VALS
107+ // MLK_VALS : [ 63..48 | 47..32 | 31..16 | 15..0 ]
108+ pshufb MLK_SHUFFLE_IN_MASK, MLK_VALS
109+ // MLK_VALS : [ 47..32 | 39..24 | 23..8 | 15..0 ]
110+ movdqa MLK_VALS, MLK_TEMP1
106111 // temp: [ 47..32 | 39..24 | 23..8 | 15..0 ]
107- psrlw $4 , temp1
112+ psrlw $4 , MLK_TEMP1
108113 // temp: [ 47..36 | 39..28 | 23..12 | 15..4 ]
109- pblendw $0xAA , temp1, vals
110- // vals : [ 47..36 | 39..24 | 23..12 | 15..0]
111- pand and_mask, vals
112- // vals : [ 47..36 | 35..24 | 23..12 | 12..0]
113-
114- // 2. Compute an 8-bit value `good ` such that
115- // good [i] = vals [i] < MLKEM_Q ? 1 : 0, for i in [0, 7],
116- movdqa bound , temp0
117- pcmpgtw vals, temp0
118- pmovmskb temp0, good
119- pext pext_mask, good, good
120-
121- // 3. Shuffle the elements in `vals ` such that all good elements
114+ pblendw $0xAA , MLK_TEMP1, MLK_VALS
115+ // MLK_VALS : [ 47..36 | 39..24 | 23..12 | 15..0]
116+ pand MLK_AND_MASK, MLK_VALS
117+ // MLK_VALS : [ 47..36 | 35..24 | 23..12 | 11..0]
118+
119+ // 2. Compute an 8-bit value `MLK_GOOD ` such that
120+ // MLK_GOOD [i] = MLK_VALS [i] < MLKEM_Q ? 1 : 0, for i in [0, 7],
121+ movdqa MLK_BOUND, MLK_TEMP0
122+ pcmpgtw MLK_VALS, MLK_TEMP0
123+ pmovmskb MLK_TEMP0, MLK_GOOD
124+ pext MLK_PEXT_MASK, MLK_GOOD, MLK_GOOD
125+
126+ // 3. Shuffle the elements in `MLK_VALS ` such that all good elements
122127 // are ordered consecutively, and store them.
123- movq good, table_idx
124- shl $4 , table_idx
125- movdqu (tab, table_idx ), shuffle_out_mask
126- pshufb shuffle_out_mask, vals
127- movdqu vals , (%rsp , cnt , 2 )
128+ movq MLK_GOOD, MLK_TABLE_IDX
129+ shl $4 , MLK_TABLE_IDX
130+ movdqu (MLK_TAB, MLK_TABLE_IDX ), MLK_SHUFFLE_OUT_MASK
131+ pshufb MLK_SHUFFLE_OUT_MASK, MLK_VALS
132+ movdqu MLK_VALS , (%rsp , MLK_CNT , 2 )
128133
129134 // Update the counter and check if we are done.
130- popcnt good, good
131- addq good, cnt
135+ popcnt MLK_GOOD, MLK_GOOD
136+ addq MLK_GOOD, MLK_CNT
132137
133- cmpq $256 , cnt
138+ cmpq $256 , MLK_CNT
134139 jnb rej_uniform_asm_final_copy
135140
136- addq $12 , pos
137- cmpq pos, len
141+ addq $12 , MLK_POS
142+ cmpq MLK_POS, MLK_LEN
138143 ja rej_uniform_asm_loop_start
139144
140145rej_uniform_asm_final_copy:
141- // Copy up to 256 values to the output: min(cnt , 256).
146+ // Copy up to 256 values to the output: min(MLK_CNT , 256).
142147 mov $256 , %rcx
143- cmp $256 , cnt
144- cmova %rcx , cnt
148+ cmp $256 , MLK_CNT
149+ cmova %rcx , MLK_CNT
145150
146151 movq %rsp , %rsi
147- movq cnt , %rcx
152+ movq MLK_CNT , %rcx
148153 shlq $1 , %rcx
149154 rep movsb
150155
@@ -155,23 +160,23 @@ rej_uniform_asm_end:
155160
156161/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.
157162 * Don't modify by hand -- this is auto-generated by scripts/autogen. */
158- #undef in
159- #undef out
160- #undef len
161- #undef tab
162- #undef cnt
163- #undef ecnt
164- #undef pos
165- #undef good
166- #undef pext_mask
167- #undef table_idx
168- #undef bound
169- #undef temp0
170- #undef temp1
171- #undef vals
172- #undef shuffle_out_mask
173- #undef shuffle_in_mask
174- #undef and_mask
163+ #undef MLK_IN
164+ #undef MLK_OUT
165+ #undef MLK_LEN
166+ #undef MLK_TAB
167+ #undef MLK_CNT
168+ #undef MLK_ECNT
169+ #undef MLK_POS
170+ #undef MLK_GOOD
171+ #undef MLK_PEXT_MASK
172+ #undef MLK_TABLE_IDX
173+ #undef MLK_BOUND
174+ #undef MLK_TEMP0
175+ #undef MLK_TEMP1
176+ #undef MLK_VALS
177+ #undef MLK_SHUFFLE_OUT_MASK
178+ #undef MLK_SHUFFLE_IN_MASK
179+ #undef MLK_AND_MASK
175180#undef MLK_STACK_SIZE
176181
177182/* simpasm: footer-start */
0 commit comments