Skip to content

Commit 75e788a

Browse files
committed
Fix x1 mve keccak
1 parent 32b7252 commit 75e788a

4 files changed

Lines changed: 97 additions & 63 deletions

File tree

mlkem/src/fips202/keccakf1600.c

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,11 @@
3434
#define MLK_KECCAK_NROUNDS 24
3535
#define MLK_KECCAK_ROL(a, offset) ((a << offset) ^ (a >> (64 - offset)))
3636

37-
void mlk_keccakf1600_extract_bytes(uint64_t *state, unsigned char *data,
37+
MLK_STATIC_TESTABLE
38+
void mlk_keccakf1600_extract_bytes_c(uint64_t *state, unsigned char *data,
3839
unsigned offset, unsigned length)
3940
{
4041
unsigned i;
41-
#if defined(MLK_USE_FIPS202_X1_EXTRACT_BYTES_NATIVE)
42-
if(mlk_keccakf1600_extract_bytes_x1_native(state, data, offset, length) == MLK_NATIVE_FUNC_SUCCESS)
43-
{
44-
return;
45-
}
46-
#endif
4742
#if defined(MLK_SYS_LITTLE_ENDIAN)
4843
uint8_t *state_ptr = (uint8_t *)state + offset;
4944
for (i = 0; i < length; i++)
@@ -62,15 +57,23 @@ void mlk_keccakf1600_extract_bytes(uint64_t *state, unsigned char *data,
6257
#endif /* !MLK_SYS_LITTLE_ENDIAN */
6358
}
6459

65-
void mlk_keccakf1600_xor_bytes(uint64_t *state, const unsigned char *data,
66-
unsigned offset, unsigned length)
60+
void mlk_keccakf1600_extract_bytes(uint64_t *state, unsigned char *data,
61+
unsigned offset, unsigned length)
6762
{
68-
unsigned i;
69-
#if defined(MLK_USE_FIPS202_X1_XOR_BYTES_NATIVE)
70-
if (mlk_keccakf1600_xor_bytes_x1_native(state, data, offset, length) == MLK_NATIVE_FUNC_SUCCESS) {
63+
#if defined(MLK_USE_FIPS202_X1_EXTRACT_BYTES_NATIVE)
64+
if(mlk_keccakf1600_extract_bytes_x1_native(state, data, offset, length) == MLK_NATIVE_FUNC_SUCCESS)
65+
{
7166
return;
7267
}
7368
#endif
69+
mlk_keccakf1600_extract_bytes_c(state, data, offset, length);
70+
}
71+
72+
MLK_STATIC_TESTABLE
73+
void mlk_keccakf1600_xor_bytes_c(uint64_t *state, const unsigned char *data,
74+
unsigned offset, unsigned length)
75+
{
76+
unsigned i;
7477
#if defined(MLK_SYS_LITTLE_ENDIAN)
7578
uint8_t *state_ptr = (uint8_t *)state + offset;
7679
for (i = 0; i < length; i++)
@@ -89,6 +92,17 @@ void mlk_keccakf1600_xor_bytes(uint64_t *state, const unsigned char *data,
8992
#endif /* !MLK_SYS_LITTLE_ENDIAN */
9093
}
9194

95+
void mlk_keccakf1600_xor_bytes(uint64_t *state, const unsigned char *data,
96+
unsigned offset, unsigned length)
97+
{
98+
#if defined(MLK_USE_FIPS202_X1_XOR_BYTES_NATIVE)
99+
if (mlk_keccakf1600_xor_bytes_x1_native(state, data, offset, length) == MLK_NATIVE_FUNC_SUCCESS) {
100+
return;
101+
}
102+
#endif
103+
mlk_keccakf1600_xor_bytes_c(state, data, offset, length);
104+
}
105+
92106
void mlk_keccakf1600x4_extract_bytes(uint64_t *state, unsigned char *data0,
93107
unsigned char *data1, unsigned char *data2,
94108
unsigned char *data3, unsigned offset,

mlkem/src/fips202/native/armv81m/src/state_extract_bytes_x1_mve_asm.S

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,13 @@
3939
.endm
4040

4141
.balign 8
42-
.macro from_bit_interleaving_x1
42+
.macro from_bit_interleaving_x1 tmp
4343
// Input: q0 = [e0, o0, e1, o1]
4444
// Output: q0 = [d0l, d0h, d1l, d1h]
4545
// Clobbers: \tmp (scratch GPR argument), p0, q1, q2, q3, q4
4646
47-
mov r0, #0x0F0F
48-
vmsr p0, r0
47+
mov \tmp, #0x0F0F
48+
vmsr p0, \tmp
4949
// // q0.u16: [e0l, e0h, o0l, o0h, e1l, e1h, o1l, o1h]
5050
vrev32.u16 q1, q0 // q1.u16: [e0h, e0l, o0h, o0l, e1h, e1l, o1h, o1l]
5151
vrev64.u32 q2, q1 // q2.u16: [o0h, o0l, e0h, e0l, o1h, o1l, e1h, e1l]
@@ -58,8 +58,8 @@
5858
deinterleave_even q0, q2
5959
deinterleave_even q1, q2
6060
// Zero garbage bits
61-
mov r0, #0x55
62-
vdup.u8 q2, r0
61+
mov \tmp, #0x55
62+
vdup.u8 q2, \tmp
6363
vand.u32 q0, q0, q2
6464
vand.u32 q1, q1, q2
6565
// Merge vectors
@@ -120,6 +120,11 @@ MLK_ASM_FN_SYMBOL(keccak_f1600_x1_state_extract_bytes_asm)
120120
// length -= n
121121
subs length, length, nB
122122

123+
// Load state for the partial lane
124+
vldrw.u32 qd, [state], #16
125+
// Deinterleave to bytes
126+
from_bit_interleaving_x1 tmp
127+
// Predicated byte store of up to 16 bytes
123128
// calculate the predicates
124129
// mask = (1 << nB) - 1 over 8-bit lanes, then shift by 'off'.
125130
// vctp.8 sets p0[0..nB-1]=1 (others 0). We read it as an integer mask,
@@ -130,11 +135,6 @@ MLK_ASM_FN_SYMBOL(keccak_f1600_x1_state_extract_bytes_asm)
130135
// mask << offset
131136
lsl mask, mask, off
132137
vmsr p0, mask
133-
// Load state for the partial lane
134-
vldrw.u32 qd, [state], #16
135-
// Deinterleave to bytes
136-
from_bit_interleaving_x1
137-
// Predicated byte store of up to 16 bytes
138138
vpst
139139
vstrbt.u8 qd, [dp], #16
140140

@@ -154,7 +154,7 @@ keccak_f1600_x1_state_extract_bytes_asm_main_loop_start:
154154
// Load 16B (two lanes) from state and bump pointer
155155
vldrw.u32 qd, [state], #16
156156
// Deinterleave to bytes
157-
from_bit_interleaving_x1
157+
from_bit_interleaving_x1 tmp
158158
// Store 16B of output bytes (post-increment by 16)
159159
vstrw.u32 qd, [dp], #16
160160

@@ -165,16 +165,16 @@ keccak_f1600_x1_state_extract_bytes_asm_main_loop_end:
165165
// TAIL: if length remaining <16, write it at offset_in_lane=0
166166
// -------------------------------------------------------------------------
167167

168-
// length &= 7
169-
ands length, length, #7
168+
// length &= 15
169+
ands length, length, #15
170170
cmp length, #0
171171
beq keccak_f1600_x1_state_extract_bytes_asm_exit
172172

173-
// Tail via predicated byte stores like prologue, but off=0 (no base adjust)
174-
vctp.8 length
175173
// Load next state lane, deinterleave, store tail
176174
vldrw.u32 qd, [state], #16
177-
from_bit_interleaving_x1
175+
from_bit_interleaving_x1 tmp
176+
// Tail via predicated byte stores like prologue, but off=0 (no base adjust)
177+
vctp.8 length
178178
vpst
179179
vstrbt.u8 qd, [dp], #16
180180

mlkem/src/fips202/native/armv81m/src/state_xor_bytes_x1_mve_asm.S

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,15 @@
2828
.macro interleave_odds t, u
2929
vshl.u8 \u, \t, #2 // u = t[5..0],00
3030
vsri.u8 \t, \u, #1 // t = t[7],u[6..0] => t = t[7],t[5..0],0
31-
vshl.u8 \u, \t, #3 // stage 2 across nibbles
32-
vsri.u8 \t, \u, #2
33-
vshl.u8 \u, \t, #4 // stage 3 across bytes
34-
vsri.u8 \t, \u, #3
35-
vshl.u16 \u, \t, #8 // widen within halfwords
36-
vsri.u8 \t, \u, #4
37-
vshl.u32 \u, \t, #16 // widen within words
38-
vsri.u16 \t, \u, #8 // odd bits compacted; per 32b lane: lo16=bytes0..3, hi16=4..7
31+
vshl.u8 \u, \t, #3 // u = t[3..0],0000
32+
vsri.u8 \t, \u, #2 // t = t[7..6],u[5..0] => t = t[7],t[5],t[3..0],00
33+
vshl.u8 \u, \t, #4 // u = t[1..0],000000
34+
vsri.u8 \t, \u, #3 // t = t[7],t[5],t[3],u[4..0] => t = t[7],t[5],t[3],t[1..0],000
35+
// t16 = t[15],t[13],t[11],t[9..8],000,t[7],t[5],t[3],t[1..0],000
36+
vshl.u16 \u, \t, #8 // u16 = t[7],t[5],t[3],t[1..0],000
37+
vsri.u8 \t, \u, #4 // t16 = t[15,13,11,9,7,5,3,1]
38+
vshl.u32 \u, \t, #16 // u32 = t[15,13,11,9,7,5,3,1]
39+
vsri.u16 \t, \u, #8 // u16 = t[31,29,27,25,23,21,19,17,15,13,11,9,7,5,3,1]
3940
.endm
4041

4142
// interleave_evens: in-place SWAR bit permutation that compacts even-numbered
@@ -55,7 +56,7 @@
5556
.endm
5657

5758
.balign 8
58-
.macro to_bit_interleaving_x1
59+
.macro to_bit_interleaving_x1 tmp
5960
// NOTE: This macro clobbers \tmp (scratch GPR argument), p0, q0, q1, q2, q3
6061
// Inputs on entry:
6162
// q0 = [d0l, d0h, d1l, d1h] (Two complete 64-bit lanes in 32-bit chunks)
@@ -68,10 +69,10 @@
6869
vrev64.u32 q2, q1 // || d0l | d0h | d1l | d1h || e0l | e0h | e1l | e1h || e0h | e0l | e1h | e1l || X | X | X | X ||
6970
vsli.u32 q1, q2, #16 // || d0l | d0h | d1l | d1h || e0 | X | e1 | X || e0h | e0l | e1h | e1l || X | X | X | X ||
7071
interleave_odds q0, q3 // || o0l | o0h | o1l | o1h || e0 | X | e1 | X || e0h | e0l | e1h | e1l || X | X | X | X ||
71-
vrev64.u32 q0, q3 // || o0l | o0h | o1l | o1h || e0 | X | e1 | X || e0h | e0l | e1h | e1l || o0h | o0l | o1h | o1l ||
72+
vrev64.u32 q3, q0 // || o0l | o0h | o1l | o1h || e0 | X | e1 | X || e0h | e0l | e1h | e1l || o0h | o0l | o1h | o1l ||
7273
vsri.u32 q0, q3, #16 // || X | o0 | X | o1 || e0 | X | e1 | X || e0h | e0l | e1h | e1l || o0h | o0l | o1h | o1l ||
73-
mov r0, #0x0F0F
74-
vmsr p0, r0
74+
mov \tmp, #0x0F0F
75+
vmsr p0, \tmp
7576
vpsel q0, q1, q0 // || e0 | o0 | e1 | o1 || e0 | X | e1 | X || e0h | e0l | e1h | e1l || o0h | o0l | o1h | o1l ||
7677
.endm
7778

@@ -145,7 +146,7 @@ MLK_ASM_FN_SYMBOL(keccak_f1600_x1_state_xor_bytes_asm)
145146

146147
// Bit interleave
147148
// NOTE: q2,q3,q4 are dead here and not preserved.
148-
to_bit_interleaving_x1
149+
to_bit_interleaving_x1 tmp
149150

150151
vldrw.u32 qs, [state]
151152
veor qs, qs, qd
@@ -168,7 +169,7 @@ keccak_f1600_x1_state_xor_bytes_asm_main_loop_start:
168169
vldrw.u32 qd, [dp], #16
169170
// Bit interleave
170171
// NOTE: q2,q3,q4 are dead here and not preserved.
171-
to_bit_interleaving_x1
172+
to_bit_interleaving_x1 tmp
172173

173174
// XOR into state (stores post-increment state by 16)
174175
vldrw.u32 qs, [state]
@@ -182,26 +183,20 @@ keccak_f1600_x1_state_xor_bytes_asm_main_loop_end:
182183
// TAIL: if length remaining <16, absorb it at offset_in_lane=0
183184
// -------------------------------------------------------------------------
184185

185-
// length &= 7
186+
// length &= 15
186187
// Placeholder: if r6 == 0, done.
187-
ands length, length, #7
188+
ands length, length, #15
188189
cmp length, #0
189190
beq keccak_f1600_x1_state_xor_bytes_asm_exit
190191

191192
// Tail via predicated byte loads like prologue, but off=0 (no base adjust)
192193
vctp.8 length
193-
vctp.8 nB
194-
vmrs mask, p0
195-
// mask << offset
196-
lsl mask, mask, off
197-
vmsr p0, mask
198-
// now load the partial lanes
199194
vpst
200195
vldrbt.u8 qd, [dp], #16
201196

202197
// Bit interleave
203198
// NOTE: q2,q3,q4 are dead here and not preserved.
204-
to_bit_interleaving_x1
199+
to_bit_interleaving_x1 tmp
205200

206201
vldrw.u32 qs, [state]
207202
veor qs, qs, qd

test/src/test_unit.c

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ void mlk_polyvec_basemul_acc_montgomery_cached_c(
4343
const mlk_polyvec_mulcache *b_cache);
4444
void mlk_poly_mulcache_compute_c(mlk_poly_mulcache *x, const mlk_poly *a);
4545
void mlk_keccakf1600_permute_c(uint64_t *state);
46-
46+
void mlk_keccakf1600_xor_bytes_c(uint64_t *state, const unsigned char *data,
47+
unsigned offset, unsigned length);
48+
void mlk_keccakf1600_extract_bytes_c(uint64_t *state, unsigned char *data,
49+
unsigned offset, unsigned length);
4750
#define CHECK(x) \
4851
do \
4952
{ \
@@ -638,21 +641,43 @@ static int test_native_polyvec_basemul(void)
638641
#endif /* MLK_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */
639642

640643
#ifdef MLK_USE_FIPS202_X1_NATIVE
641-
static int test_keccakf1600_permute(void)
644+
#define MAX_RATE 136
645+
static int test_keccakf1600_xor_permute_extract(void)
642646
{
643-
uint64_t state[MLK_KECCAK_LANES];
644-
uint64_t state_ref[MLK_KECCAK_LANES];
647+
uint64_t input[MLK_KECCAK_LANES];
648+
uint64_t state_native[MLK_KECCAK_LANES];
649+
uint64_t state_c[MLK_KECCAK_LANES];
650+
uint64_t output_native[MLK_KECCAK_LANES];
651+
uint64_t output_c[MLK_KECCAK_LANES];
652+
uint8_t xor_offset, xor_length, ext_offset, ext_length;
645653
int i;
646654

647655
for (i = 0; i < NUM_RANDOM_TESTS; i++)
648656
{
649-
randombytes((uint8_t *)state, sizeof(state));
650-
memcpy(state_ref, state, sizeof(state));
651-
652-
mlk_keccakf1600_permute(state);
653-
mlk_keccakf1600_permute_c(state_ref);
654-
655-
CHECK(compare_u64_arrays(state, state_ref, MLK_KECCAK_LANES,
657+
randombytes(&xor_offset,1);
658+
randombytes(&xor_length,1);
659+
xor_offset = xor_offset % MAX_RATE;
660+
xor_length = (uint8_t)(1 + (xor_length % (MAX_RATE - xor_offset)));
661+
randombytes(&ext_offset, 1);
662+
randombytes(&ext_length, 1);
663+
ext_offset = ext_offset % MAX_RATE;
664+
ext_length = (uint8_t)(1 + (ext_length % (MAX_RATE - ext_offset)));
665+
666+
randombytes((uint8_t *)input, xor_length);
667+
memset(state_native, 0, sizeof(state_native));
668+
memset(output_native, 0, sizeof(output_native));
669+
670+
mlk_keccakf1600_xor_bytes(state_native, (uint8_t *)input, xor_offset, xor_length);
671+
mlk_keccakf1600_permute(state_native);
672+
mlk_keccakf1600_extract_bytes(state_native, (uint8_t *)output_native, ext_offset, ext_length);
673+
674+
memset(state_c, 0, sizeof(state_c));
675+
memset(output_c, 0, sizeof(output_c));
676+
mlk_keccakf1600_xor_bytes_c(state_c, (uint8_t *)input, xor_offset, xor_length);
677+
mlk_keccakf1600_permute_c(state_c);
678+
mlk_keccakf1600_extract_bytes_c(state_c, (uint8_t *)output_c, ext_offset, ext_length);
679+
680+
CHECK(compare_u64_arrays(output_native, output_c, MLK_KECCAK_LANES,
656681
"keccakf1600_permute"));
657682
}
658683

@@ -722,7 +747,7 @@ static int test_backend_units(void)
722747
#endif
723748

724749
#ifdef MLK_USE_FIPS202_X1_NATIVE
725-
CHECK(test_keccakf1600_permute() == 0);
750+
CHECK(test_keccakf1600_xor_permute_extract() == 0);
726751
#endif
727752

728753
#ifdef MLK_USE_FIPS202_X4_NATIVE

0 commit comments

Comments
 (0)