@@ -1069,6 +1069,65 @@ def gen_aarch64_rej_uniform_table():
10691069 update_file ("dev/aarch64_clean/src/rej_uniform_table.c" , "\n " .join (gen ()))
10701070
10711071
1072+ def gen_aarch64_polyz_unpack_indices (bit_width ):
1073+ """Generate Neon TBL index bytes for polyz_unpack.
1074+
1075+ Each loop iteration loads bit_width * 16 / 8 packed bytes into
1076+ registers v0, v1, v2 and unpacks 16 coefficients. Coefficients
1077+ are unpacked in 4 groups of 4, one group per TBL/TBL2.
1078+
1079+ Each coefficient is extracted from 3 consecutive bytes.
1080+ Groups 0,1 index from v0 (and v1); groups 2,3 index from v1 (and v2),
1081+ so their byte offsets are shifted by -16.
1082+ """
1083+ for group in range (4 ):
1084+ base_coeff = group * 4
1085+ reg_offset = 0 if group < 2 else 16
1086+ for coeff in range (4 ):
1087+ i = base_coeff + coeff
1088+ byte_start = (i * bit_width ) // 8 - reg_offset
1089+ yield byte_start
1090+ yield byte_start + 1
1091+ yield byte_start + 2
1092+ yield 255
1093+
1094+
1095+ def gen_aarch64_polyz_unpack_table ():
1096+ def format_row (vals ):
1097+ return ", " .join (f"{ v :>3} " for v in vals ) + ","
1098+
1099+ def gen ():
1100+ yield from gen_header ()
1101+ yield '#include "../../../common.h"'
1102+ yield ""
1103+ yield "#if defined(MLD_ARITH_BACKEND_AARCH64) && \\ "
1104+ yield " !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)"
1105+ yield ""
1106+ yield "#include <stdint.h>"
1107+ yield '#include "arith_native_aarch64.h"'
1108+ yield ""
1109+ yield "/* Table of indices used for tbl instructions in polyz_unpack_{17,19}."
1110+ yield " * See autogen for details. */"
1111+ yield ""
1112+ for gamma1_bits in [17 , 19 ]:
1113+ bit_width = gamma1_bits + 1
1114+ indices = list (gen_aarch64_polyz_unpack_indices (bit_width ))
1115+ yield f"MLD_ALIGN const uint8_t mld_polyz_unpack_{ gamma1_bits } _indices[] = {{"
1116+ for row_start in range (0 , len (indices ), 16 ):
1117+ yield " " + format_row (indices [row_start : row_start + 16 ])
1118+ yield "};"
1119+ yield ""
1120+ yield "#else /* MLD_ARITH_BACKEND_AARCH64 && !MLD_CONFIG_MULTILEVEL_NO_SHARED */"
1121+ yield ""
1122+ yield "MLD_EMPTY_CU(aarch64_polyz_unpack_table)"
1123+ yield ""
1124+ yield "#endif /* !(MLD_ARITH_BACKEND_AARCH64 && !MLD_CONFIG_MULTILEVEL_NO_SHARED) */"
1125+ yield ""
1126+
1127+ update_file ("dev/aarch64_opt/src/polyz_unpack_table.c" , "\n " .join (gen ()))
1128+ update_file ("dev/aarch64_clean/src/polyz_unpack_table.c" , "\n " .join (gen ()))
1129+
1130+
10721131def gen_avx2_rej_uniform_table_rows ():
10731132 # The index into the lookup table is an 8-bit bitmap, i.e. a number 0..255.
10741133 # Conceptually, the table entry at index i is a vector of 8 16-bit values, of
@@ -3189,6 +3248,7 @@ def _main():
31893248 gen_aarch64_zeta_file ()
31903249 gen_aarch64_rej_uniform_table ()
31913250 gen_aarch64_rej_uniform_eta_table ()
3251+ gen_aarch64_polyz_unpack_table ()
31923252 gen_avx2_hol_light_zeta_file ()
31933253 gen_avx2_zeta_file ()
31943254 gen_avx2_rej_uniform_table ()
0 commit comments