|
5 | 5 | /* High-level operations for the ML-DSA key generation function. */ |
6 | 6 |
|
7 | 7 | .globl sample_s |
| 8 | +.globl compute_t |
8 | 9 |
|
9 | 10 | .text |
10 | 11 |
|
@@ -87,3 +88,176 @@ sample_s: |
87 | 88 | /* End of loop */ |
88 | 89 |
|
89 | 90 | ret |
| 91 | + |
| 92 | +/** |
| 93 | + * Compute the T vector. |
| 94 | + * |
| 95 | + * This routine computes T = INTT(A * NTT(S1)) + S2 which is a 8x7 |
| 96 | + * matrix-vector multiplication followed by vector addition. The individual |
| 97 | + * polynomials of A, S1 and S2 are generated and decoded on-the-fly through |
| 98 | + * `expand_a` and `decode_s` respectively. `expand_a` requires a 34-byte seed |
| 99 | + * RHO (in a 64-byte region). |
| 100 | + * |
| 101 | + * The secret vectors S1 and S2 are assumed to be provided Boolean shares in |
| 102 | + * encoded form (2 * 672 bytes, 2 * 768 bytes). The resulting vector is T |
| 103 | + * is returned in two arithmetic shares (2 * 8192 bytes). |
| 104 | + * |
| 105 | + * Three polynomial slots are required for the storage of intermediate results. |
| 106 | + * |
| 107 | + * @param[in] x2: DMEM address of the seed RHO. |
| 108 | + * @param[in] x3: DMEM address of the first Boolean share of the encoded S1. |
| 109 | + * @param[in] x4: DMEM address of the second Boolean share of the encoded S1. |
| 110 | + * @param[in] x5: DMEM address of the first Boolean share of the encoded S2. |
| 111 | + * @param[in] x6: DMEM address of the second Boolean share of the encoded S2. |
| 112 | + * @param[in] x7: DMEM address of the first arithmetic share of T. |
| 113 | + * @param[in] x8: DMEM address of the first arithmetic share of T. |
| 114 | + * @param[in] x9: DMEM address of polynomial slot 0 (1024 bytes). |
| 115 | + * @param[in] x10: DMEM address of polynomial slot 1 (1024 bytes). |
| 116 | + * @param[in] x11: DMEM address of polynomial slot 2 (1024 bytes). |
| 117 | + */ |
| 118 | +compute_t: |
| 119 | + /* Prepare DMEM address registers. */ |
| 120 | + addi x12, x2, 0 /* RHO */ |
| 121 | + addi x13, x3, 0 /* S1_0_enc (share 0) */ |
| 122 | + addi x14, x4, 0 /* S1_1_enc (share 1) */ |
| 123 | + addi x15, x5, 0 /* S2_0_enc (share 0) */ |
| 124 | + addi x16, x6, 0 /* S2_1_enc (share 1) */ |
| 125 | + |
| 126 | + /* Loop indices for `expand_a`. */ |
| 127 | + addi x17, x0, 0 /* r */ |
| 128 | + addi x18, x0, 0 /* s */ |
| 129 | + |
| 130 | + /* Zeroize the vector slots. */ |
| 131 | + addi x20, x7, 0 |
| 132 | + addi x21, x0, 256 |
| 133 | + jal x1, zeroize |
| 134 | + |
| 135 | + addi x20, x8, 0 |
| 136 | + addi x21, x0, 256 |
| 137 | + jal x1, zeroize |
| 138 | + |
| 139 | + /* |
| 140 | + * The matrix-vector multiplication proceeds in column-major order: |
| 141 | + * |
| 142 | + * for s in [0, 6]: |
| 143 | + * S1_0, S1_1 = decode_s(S1_0_enc[s], S1_1_enc[s]) |
| 144 | + * X0, X1 = NTT(S1_0), NTT(S1_1) |
| 145 | + * for r in [0, 7]: |
| 146 | + * A = expand_a(RHO, r, s) |
| 147 | + * T_0[r] += A * X0 |
| 148 | + * T_1[r] += A * X1 |
| 149 | + * end for |
| 150 | + * end for |
| 151 | + */ |
| 152 | + |
| 153 | + loopi 7, 38 |
| 154 | + /* X0, X1 = decode_s(S1_0_enc[s], S1_1_enc[s]) (poly slots 0, 1). */ |
| 155 | + addi x2, x13, 0 |
| 156 | + addi x3, x14, 0 |
| 157 | + addi x4, x9, 0 |
| 158 | + addi x5, x10, 0 |
| 159 | + jal x1, decode_s |
| 160 | + |
| 161 | + /* X0 = NTT(S1_0). */ |
| 162 | + addi x2, x9, 0 |
| 163 | + addi x3, x9, 0 |
| 164 | + jal x1, ntt |
| 165 | + |
| 166 | + /* X1 = NTT(S1_1). */ |
| 167 | + addi x2, x10, 0 |
| 168 | + addi x3, x10, 0 |
| 169 | + jal x1, ntt |
| 170 | + |
| 171 | + loopi 8, 18 |
| 172 | + /* A = expand_a(RHO, r, s) (poly slot 2). */ |
| 173 | + addi x2, x11, 0 |
| 174 | + addi x3, x12, 0 |
| 175 | + addi x4, x17, 0 |
| 176 | + addi x5, x18, 0 |
| 177 | + jal x1, expand_a |
| 178 | + |
| 179 | + /* T_0[r] += A * X0 = A * NTT(S1_0). */ |
| 180 | + addi x2, x11, 0 |
| 181 | + addi x3, x9, 0 |
| 182 | + addi x4, x7, 0 |
| 183 | + addi x5, x7, 0 |
| 184 | + jal x1, poly_mul_add |
| 185 | + |
| 186 | + /* T_1[r] += A * X1 = A * NTT(S1_1). */ |
| 187 | + addi x2, x11, 0 |
| 188 | + addi x3, x10, 0 |
| 189 | + addi x4, x8, 0 |
| 190 | + addi x5, x8, 0 |
| 191 | + jal x1, poly_mul_add |
| 192 | + |
| 193 | + /* Increment r and advance output addresses. */ |
| 194 | + addi x7, x7, 1024 |
| 195 | + addi x8, x8, 1024 |
| 196 | + addi x17, x17, 1 |
| 197 | + /* End of loop */ |
| 198 | + |
| 199 | + /* Reset r and increment s. */ |
| 200 | + addi x17, x0, 0 |
| 201 | + addi x18, x18, 1 |
| 202 | + |
| 203 | + /* Reset the output addresses, i.e., subtract 8192. */ |
| 204 | + addi x20, x0, 1024 |
| 205 | + slli x20, x20, 3 |
| 206 | + sub x7, x7, x20 |
| 207 | + sub x8, x8, x20 |
| 208 | + |
| 209 | + /* Advance S1_0_enc and S1_1_enc pointers. */ |
| 210 | + addi x13, x13, 96 |
| 211 | + addi x14, x14, 96 |
| 212 | + /* End of loop */ |
| 213 | + |
| 214 | + /* |
| 215 | + * Vector-vector addition: |
| 216 | + * |
| 217 | + * for r in [0, 7]: |
| 218 | + * T_0[r] = INTT(T_0[r]) |
| 219 | + * T_1[r] = INTT(T_1[r]) |
| 220 | + * X0, X1 = decode_s(S2_0_enc[r], S2_1_enc[r]) |
| 221 | + * T_0[r] += X0 |
| 222 | + * T_1[r] += X1 |
| 223 | + * end for |
| 224 | + */ |
| 225 | + |
| 226 | + loopi 8, 23 |
| 227 | + /* T_0[r] = INTT(T_0[r]). */ |
| 228 | + addi x2, x7, 0 |
| 229 | + addi x3, x7, 0 |
| 230 | + jal x1, intt |
| 231 | + |
| 232 | + /* T_1[r] = INTT(T_1[r]). */ |
| 233 | + addi x2, x8, 0 |
| 234 | + addi x3, x8, 0 |
| 235 | + jal x1, intt |
| 236 | + |
| 237 | + /* X0, X1 = decode_s(S2_0_enc[r], S2_1_enc[r]) (poly slots 0, 1). */ |
| 238 | + addi x2, x15, 0 |
| 239 | + addi x3, x16, 0 |
| 240 | + addi x4, x9, 0 |
| 241 | + addi x5, x10, 0 |
| 242 | + jal x1, decode_s |
| 243 | + |
| 244 | + /* T_0[r] += X0. */ |
| 245 | + addi x2, x7, 0 |
| 246 | + addi x3, x9, 0 |
| 247 | + addi x4, x7, 0 |
| 248 | + jal x1, poly_add |
| 249 | + |
| 250 | + /* T_1[r] += X1. */ |
| 251 | + addi x2, x8, 0 |
| 252 | + addi x3, x10, 0 |
| 253 | + addi x4, x8, 0 |
| 254 | + jal x1, poly_add |
| 255 | + |
| 256 | + /* Advance output and S2_enc pointers. */ |
| 257 | + addi x7, x7, 1024 |
| 258 | + addi x8, x8, 1024 |
| 259 | + addi x15, x15, 96 |
| 260 | + addi x16, x16, 96 |
| 261 | + /* End of loop */ |
| 262 | + |
| 263 | + ret |
0 commit comments