Skip to content

Commit 885bca1

Browse files
committed
AArch64: Add Neon polyw1_pack to AArch64 native backend
Add AArch64 assembly implementations of polyw1_pack for both GAMMA2 variants using TBL-based byte extraction from 32-bit coefficient lanes. Signed-off-by: Matthias J. Kannwischer <matthias@kannwischer.eu>
1 parent 1bb503b commit 885bca1

File tree

15 files changed

+807
-0
lines changed

15 files changed

+807
-0
lines changed

dev/aarch64_clean/meta.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#define MLD_USE_NATIVE_POLY_CHKNORM
2222
#define MLD_USE_NATIVE_POLYZ_UNPACK_17
2323
#define MLD_USE_NATIVE_POLYZ_UNPACK_19
24+
#define MLD_USE_NATIVE_POLYW1_PACK_32
25+
#define MLD_USE_NATIVE_POLYW1_PACK_88
2426
#define MLD_USE_NATIVE_POINTWISE_MONTGOMERY
2527
#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4
2628
#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5
@@ -198,6 +200,44 @@ static MLD_INLINE int mld_polyz_unpack_19_native(int32_t *r, const uint8_t *buf)
198200
#endif /* MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 65 \
199201
|| MLD_CONFIG_PARAMETER_SET == 87 */
200202

203+
#if defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || \
204+
(MLD_CONFIG_PARAMETER_SET == 65 || MLD_CONFIG_PARAMETER_SET == 87)
205+
MLD_MUST_CHECK_RETURN_VALUE
206+
static MLD_INLINE int mld_polyw1_pack_32_native(uint8_t *r, const int32_t *a)
207+
{
208+
mld_polyw1_pack_32_asm(r, a);
209+
return MLD_NATIVE_FUNC_SUCCESS;
210+
}
211+
#endif /* MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 65 \
212+
|| MLD_CONFIG_PARAMETER_SET == 87 */
213+
214+
#if defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || MLD_CONFIG_PARAMETER_SET == 44
215+
/* Table of constants for polyw1_pack_88_asm:
216+
* [0:15] v_shifts: USHL shift amounts {0, 6, 12, 18} as .4s
217+
* [16:31] v_tbl0: TBL indices for out0 from {v16, v17}
218+
* [32:47] v_tbl1: TBL indices for out1 from {v17, v18}
219+
* [48:63] v_tbl2: TBL indices for out2 from {v18, v19} */
220+
/* clang-format off */
221+
MLD_ALIGN static const uint8_t mld_polyw1_pack_88_consts[] = {
222+
/* v_shifts: {0, 6, 12, 18} as uint32_t little-endian */
223+
0, 0, 0, 0, 6, 0, 0, 0, 12, 0, 0, 0, 18, 0, 0, 0,
224+
/* v_tbl0: {0,1,2, 4,5,6, 8,9,10, 12,13,14, 16,17,18, 20} */
225+
0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, 16, 17, 18, 20,
226+
/* v_tbl1: {5,6, 8,9,10, 12,13,14, 16,17,18, 20,21,22, 24,25} */
227+
5, 6, 8, 9, 10, 12, 13, 14, 16, 17, 18, 20, 21, 22, 24, 25,
228+
/* v_tbl2: {10, 12,13,14, 16,17,18, 20,21,22, 24,25,26, 28,29,30} */
229+
10, 12, 13, 14, 16, 17, 18, 20, 21, 22, 24, 25, 26, 28, 29, 30,
230+
};
231+
/* clang-format on */
232+
MLD_MUST_CHECK_RETURN_VALUE
233+
static MLD_INLINE int mld_polyw1_pack_88_native(uint8_t *r, const int32_t *a)
234+
{
235+
mld_polyw1_pack_88_asm(r, a, mld_polyw1_pack_88_consts);
236+
return MLD_NATIVE_FUNC_SUCCESS;
237+
}
238+
#endif /* MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 44 \
239+
*/
240+
201241
MLD_MUST_CHECK_RETURN_VALUE
202242
static MLD_INLINE int mld_poly_pointwise_montgomery_native(
203243
int32_t out[MLDSA_N], const int32_t in0[MLDSA_N],

dev/aarch64_clean/src/arith_native_aarch64.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ void mld_polyz_unpack_17_asm(int32_t *r, const uint8_t *buf,
105105
void mld_polyz_unpack_19_asm(int32_t *r, const uint8_t *buf,
106106
const uint8_t *indices);
107107

108+
#define mld_polyw1_pack_32_asm MLD_NAMESPACE(polyw1_pack_32_asm)
109+
void mld_polyw1_pack_32_asm(uint8_t *r, const int32_t *a);
110+
111+
#define mld_polyw1_pack_88_asm MLD_NAMESPACE(polyw1_pack_88_asm)
112+
void mld_polyw1_pack_88_asm(uint8_t *r, const int32_t *a, const uint8_t *table);
113+
108114
#define mld_poly_pointwise_montgomery_asm \
109115
MLD_NAMESPACE(poly_pointwise_montgomery_asm)
110116
void mld_poly_pointwise_montgomery_asm(int32_t *, const int32_t *,
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* Copyright (c) The mldsa-native project authors
3+
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
4+
*/
5+
6+
#include "../../../common.h"
7+
#if defined(MLD_ARITH_BACKEND_AARCH64) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) && \
8+
(defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || \
9+
MLD_CONFIG_PARAMETER_SET == 65 || MLD_CONFIG_PARAMETER_SET == 87)
10+
/* simpasm: header-end */
11+
12+
/*
13+
* polyw1_pack_32: Pack w1 polynomial for GAMMA2 = (Q-1)/32.
14+
*
15+
* Each coefficient is in [0, 15] (4 bits) stored in a 32-bit word.
16+
* Pack 2 coefficients per byte: r[i] = a[2i] | (a[2i+1] << 4)
17+
* 256 coefficients -> 128 output bytes.
18+
*
19+
* UZP1 narrowing chain (32->16->8 bit) extracts the low byte from
20+
* each coefficient; UZP1/UZP2 separate even/odd coefficients;
21+
* SLI shifts and inserts the odd nibbles.
22+
*
23+
* 4x unrolled, 2 iterations for 256 coefficients.
24+
*/
25+
26+
output .req x0
27+
input .req x1
28+
count .req x2
29+
30+
.text
31+
.global MLD_ASM_NAMESPACE(polyw1_pack_32_asm)
32+
.balign 4
33+
MLD_ASM_FN_SYMBOL(polyw1_pack_32_asm)
34+
35+
mov count, #(256 / (32 * 4))
36+
37+
polyw1_pack_32_loop:
38+
39+
/* Block 0: coefficients 0-31 */
40+
ldp q0, q1, [input], #512
41+
ldp q2, q3, [input, #(32 - 512)]
42+
ldp q4, q5, [input, #(64 - 512)]
43+
ldp q6, q7, [input, #(96 - 512)]
44+
uzp1 v0.8h, v0.8h, v1.8h
45+
uzp1 v2.8h, v2.8h, v3.8h
46+
uzp1 v4.8h, v4.8h, v5.8h
47+
uzp1 v6.8h, v6.8h, v7.8h
48+
uzp1 v0.16b, v0.16b, v2.16b
49+
uzp1 v4.16b, v4.16b, v6.16b
50+
uzp1 v16.16b, v0.16b, v4.16b
51+
uzp2 v0.16b, v0.16b, v4.16b
52+
sli v16.16b, v0.16b, #4
53+
54+
/* Block 1: coefficients 32-63 */
55+
ldp q0, q1, [input, #(128 - 512)]
56+
ldp q2, q3, [input, #(160 - 512)]
57+
ldp q4, q5, [input, #(192 - 512)]
58+
ldp q6, q7, [input, #(224 - 512)]
59+
uzp1 v0.8h, v0.8h, v1.8h
60+
uzp1 v2.8h, v2.8h, v3.8h
61+
uzp1 v4.8h, v4.8h, v5.8h
62+
uzp1 v6.8h, v6.8h, v7.8h
63+
uzp1 v0.16b, v0.16b, v2.16b
64+
uzp1 v4.16b, v4.16b, v6.16b
65+
uzp1 v17.16b, v0.16b, v4.16b
66+
uzp2 v0.16b, v0.16b, v4.16b
67+
sli v17.16b, v0.16b, #4
68+
69+
/* Block 2: coefficients 64-95 */
70+
ldp q0, q1, [input, #(256 - 512)]
71+
ldp q2, q3, [input, #(288 - 512)]
72+
ldp q4, q5, [input, #(320 - 512)]
73+
ldp q6, q7, [input, #(352 - 512)]
74+
uzp1 v0.8h, v0.8h, v1.8h
75+
uzp1 v2.8h, v2.8h, v3.8h
76+
uzp1 v4.8h, v4.8h, v5.8h
77+
uzp1 v6.8h, v6.8h, v7.8h
78+
uzp1 v0.16b, v0.16b, v2.16b
79+
uzp1 v4.16b, v4.16b, v6.16b
80+
uzp1 v18.16b, v0.16b, v4.16b
81+
uzp2 v0.16b, v0.16b, v4.16b
82+
sli v18.16b, v0.16b, #4
83+
84+
/* Block 3: coefficients 96-127 */
85+
ldp q0, q1, [input, #(384 - 512)]
86+
ldp q2, q3, [input, #(416 - 512)]
87+
ldp q4, q5, [input, #(448 - 512)]
88+
ldp q6, q7, [input, #(480 - 512)]
89+
uzp1 v0.8h, v0.8h, v1.8h
90+
uzp1 v2.8h, v2.8h, v3.8h
91+
uzp1 v4.8h, v4.8h, v5.8h
92+
uzp1 v6.8h, v6.8h, v7.8h
93+
uzp1 v0.16b, v0.16b, v2.16b
94+
uzp1 v4.16b, v4.16b, v6.16b
95+
uzp1 v19.16b, v0.16b, v4.16b
96+
uzp2 v0.16b, v0.16b, v4.16b
97+
sli v19.16b, v0.16b, #4
98+
99+
st1 {v16.16b, v17.16b, v18.16b, v19.16b}, [output], #64
100+
101+
subs count, count, #1
102+
bne polyw1_pack_32_loop
103+
104+
ret
105+
106+
.unreq output
107+
.unreq input
108+
.unreq count
109+
/* simpasm: footer-start */
110+
#endif /* MLD_ARITH_BACKEND_AARCH64 && !MLD_CONFIG_MULTILEVEL_NO_SHARED && \
111+
(MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 65 \
112+
|| MLD_CONFIG_PARAMETER_SET == 87) */
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/*
2+
* Copyright (c) The mldsa-native project authors
3+
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
4+
*/
5+
6+
#include "../../../common.h"
7+
#if defined(MLD_ARITH_BACKEND_AARCH64) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) && \
8+
(defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || MLD_CONFIG_PARAMETER_SET == 44)
9+
/* simpasm: header-end */
10+
11+
/*
12+
* polyw1_pack_88: Pack w1 polynomial for GAMMA2 = (Q-1)/88.
13+
*
14+
* Each coefficient is in [0, 43] (6 bits) stored in a 32-bit word.
15+
* Pack 4 coefficients into 3 bytes:
16+
* r[3i+0] = a[4i+0] | (a[4i+1] << 6)
17+
* r[3i+1] = (a[4i+1] >> 2) | (a[4i+2] << 4)
18+
* r[3i+2] = (a[4i+2] >> 4) | (a[4i+3] << 2)
19+
* 256 coefficients -> 192 output bytes.
20+
*
21+
* Each group of 4 coefficients in a .4s vector is shifted to its
22+
* bit position using USHL, then reduced with ADDP to form one
23+
* 24-bit packed value per 32-bit lane.
24+
*
25+
* Three 2-register TBL instructions then extract the useful 3 bytes
26+
* from each 32-bit lane across pairs of adjacent result vectors,
27+
* producing 3 contiguous 16-byte output vectors (48 bytes total).
28+
*
29+
* 4x unrolled, 4 iterations for 256 coefficients.
30+
*/
31+
32+
output .req x0
33+
input .req x1
34+
table .req x2
35+
count .req x3
36+
37+
v_shifts .req v24
38+
v_tbl0 .req v25
39+
v_tbl1 .req v26
40+
v_tbl2 .req v27
41+
42+
.text
43+
.global MLD_ASM_NAMESPACE(polyw1_pack_88_asm)
44+
.balign 4
45+
MLD_ASM_FN_SYMBOL(polyw1_pack_88_asm)
46+
47+
/* Load constants from table pointer (x2):
48+
* [0:15] = v_shifts.4s = {0, 6, 12, 18}
49+
* [16:31] = v_tbl0: TBL indices for out0 from {v16, v17}
50+
* [32:47] = v_tbl1: TBL indices for out1 from {v17, v18}
51+
* [48:63] = v_tbl2: TBL indices for out2 from {v18, v19} */
52+
ldp q24, q25, [table]
53+
ldp q26, q27, [table, #32]
54+
55+
mov count, #(256 / (16 * 4))
56+
57+
polyw1_pack_88_loop:
58+
59+
/* Block 0: coefficients 0-15 */
60+
ldp q0, q1, [input], #256
61+
ldp q2, q3, [input, #(32 - 256)]
62+
ushl v0.4s, v0.4s, v_shifts.4s
63+
ushl v1.4s, v1.4s, v_shifts.4s
64+
ushl v2.4s, v2.4s, v_shifts.4s
65+
ushl v3.4s, v3.4s, v_shifts.4s
66+
addp v0.4s, v0.4s, v1.4s
67+
addp v2.4s, v2.4s, v3.4s
68+
addp v16.4s, v0.4s, v2.4s
69+
70+
/* Block 1: coefficients 16-31 */
71+
ldp q0, q1, [input, #(64 - 256)]
72+
ldp q2, q3, [input, #(96 - 256)]
73+
ushl v0.4s, v0.4s, v_shifts.4s
74+
ushl v1.4s, v1.4s, v_shifts.4s
75+
ushl v2.4s, v2.4s, v_shifts.4s
76+
ushl v3.4s, v3.4s, v_shifts.4s
77+
addp v0.4s, v0.4s, v1.4s
78+
addp v2.4s, v2.4s, v3.4s
79+
addp v17.4s, v0.4s, v2.4s
80+
81+
/* Block 2: coefficients 32-47 */
82+
ldp q0, q1, [input, #(128 - 256)]
83+
ldp q2, q3, [input, #(160 - 256)]
84+
ushl v0.4s, v0.4s, v_shifts.4s
85+
ushl v1.4s, v1.4s, v_shifts.4s
86+
ushl v2.4s, v2.4s, v_shifts.4s
87+
ushl v3.4s, v3.4s, v_shifts.4s
88+
addp v0.4s, v0.4s, v1.4s
89+
addp v2.4s, v2.4s, v3.4s
90+
addp v18.4s, v0.4s, v2.4s
91+
92+
/* Block 3: coefficients 48-63 */
93+
ldp q0, q1, [input, #(192 - 256)]
94+
ldp q2, q3, [input, #(224 - 256)]
95+
ushl v0.4s, v0.4s, v_shifts.4s
96+
ushl v1.4s, v1.4s, v_shifts.4s
97+
ushl v2.4s, v2.4s, v_shifts.4s
98+
ushl v3.4s, v3.4s, v_shifts.4s
99+
addp v0.4s, v0.4s, v1.4s
100+
addp v2.4s, v2.4s, v3.4s
101+
addp v19.4s, v0.4s, v2.4s
102+
103+
/* Compact + splice into 3 output vectors */
104+
tbl v20.16b, {v16.16b, v17.16b}, v_tbl0.16b
105+
tbl v21.16b, {v17.16b, v18.16b}, v_tbl1.16b
106+
tbl v22.16b, {v18.16b, v19.16b}, v_tbl2.16b
107+
108+
st1 {v20.16b, v21.16b, v22.16b}, [output], #48
109+
110+
subs count, count, #1
111+
bne polyw1_pack_88_loop
112+
113+
ret
114+
115+
.unreq output
116+
.unreq input
117+
.unreq table
118+
.unreq count
119+
.unreq v_shifts
120+
.unreq v_tbl0
121+
.unreq v_tbl1
122+
.unreq v_tbl2
123+
/* simpasm: footer-start */
124+
#endif /* MLD_ARITH_BACKEND_AARCH64 && !MLD_CONFIG_MULTILEVEL_NO_SHARED && \
125+
(MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 44) \
126+
*/

dev/aarch64_opt/meta.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#define MLD_USE_NATIVE_POLY_CHKNORM
2222
#define MLD_USE_NATIVE_POLYZ_UNPACK_17
2323
#define MLD_USE_NATIVE_POLYZ_UNPACK_19
24+
#define MLD_USE_NATIVE_POLYW1_PACK_32
25+
#define MLD_USE_NATIVE_POLYW1_PACK_88
2426
#define MLD_USE_NATIVE_POINTWISE_MONTGOMERY
2527
#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4
2628
#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5
@@ -198,6 +200,44 @@ static MLD_INLINE int mld_polyz_unpack_19_native(int32_t *r, const uint8_t *buf)
198200
#endif /* MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 65 \
199201
|| MLD_CONFIG_PARAMETER_SET == 87 */
200202

203+
#if defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || \
204+
(MLD_CONFIG_PARAMETER_SET == 65 || MLD_CONFIG_PARAMETER_SET == 87)
205+
MLD_MUST_CHECK_RETURN_VALUE
206+
static MLD_INLINE int mld_polyw1_pack_32_native(uint8_t *r, const int32_t *a)
207+
{
208+
mld_polyw1_pack_32_asm(r, a);
209+
return MLD_NATIVE_FUNC_SUCCESS;
210+
}
211+
#endif /* MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 65 \
212+
|| MLD_CONFIG_PARAMETER_SET == 87 */
213+
214+
#if defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || MLD_CONFIG_PARAMETER_SET == 44
215+
/* Table of constants for polyw1_pack_88_asm:
216+
* [0:15] v_shifts: USHL shift amounts {0, 6, 12, 18} as .4s
217+
* [16:31] v_tbl0: TBL indices for out0 from {v16, v17}
218+
* [32:47] v_tbl1: TBL indices for out1 from {v17, v18}
219+
* [48:63] v_tbl2: TBL indices for out2 from {v18, v19} */
220+
/* clang-format off */
221+
MLD_ALIGN static const uint8_t mld_polyw1_pack_88_consts[] = {
222+
/* v_shifts: {0, 6, 12, 18} as uint32_t little-endian */
223+
0, 0, 0, 0, 6, 0, 0, 0, 12, 0, 0, 0, 18, 0, 0, 0,
224+
/* v_tbl0: {0,1,2, 4,5,6, 8,9,10, 12,13,14, 16,17,18, 20} */
225+
0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, 16, 17, 18, 20,
226+
/* v_tbl1: {5,6, 8,9,10, 12,13,14, 16,17,18, 20,21,22, 24,25} */
227+
5, 6, 8, 9, 10, 12, 13, 14, 16, 17, 18, 20, 21, 22, 24, 25,
228+
/* v_tbl2: {10, 12,13,14, 16,17,18, 20,21,22, 24,25,26, 28,29,30} */
229+
10, 12, 13, 14, 16, 17, 18, 20, 21, 22, 24, 25, 26, 28, 29, 30,
230+
};
231+
/* clang-format on */
232+
MLD_MUST_CHECK_RETURN_VALUE
233+
static MLD_INLINE int mld_polyw1_pack_88_native(uint8_t *r, const int32_t *a)
234+
{
235+
mld_polyw1_pack_88_asm(r, a, mld_polyw1_pack_88_consts);
236+
return MLD_NATIVE_FUNC_SUCCESS;
237+
}
238+
#endif /* MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 44 \
239+
*/
240+
201241
MLD_MUST_CHECK_RETURN_VALUE
202242
static MLD_INLINE int mld_poly_pointwise_montgomery_native(
203243
int32_t out[MLDSA_N], const int32_t in0[MLDSA_N],

dev/aarch64_opt/src/arith_native_aarch64.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ void mld_polyz_unpack_17_asm(int32_t *r, const uint8_t *buf,
105105
void mld_polyz_unpack_19_asm(int32_t *r, const uint8_t *buf,
106106
const uint8_t *indices);
107107

108+
#define mld_polyw1_pack_32_asm MLD_NAMESPACE(polyw1_pack_32_asm)
109+
void mld_polyw1_pack_32_asm(uint8_t *r, const int32_t *a);
110+
111+
#define mld_polyw1_pack_88_asm MLD_NAMESPACE(polyw1_pack_88_asm)
112+
void mld_polyw1_pack_88_asm(uint8_t *r, const int32_t *a, const uint8_t *table);
113+
108114
#define mld_poly_pointwise_montgomery_asm \
109115
MLD_NAMESPACE(poly_pointwise_montgomery_asm)
110116
void mld_poly_pointwise_montgomery_asm(int32_t *, const int32_t *,

0 commit comments

Comments
 (0)