|
1 | 1 | #version 450 |
2 | 2 |
|
| 3 | +#extension GL_KHR_shader_subgroup_arithmetic : enable |
| 4 | +#extension GL_KHR_shader_subgroup_ballot : enable |
| 5 | +#extension GL_KHR_shader_subgroup_shuffle : enable |
3 | 6 | #include "rte.glsl" |
4 | 7 | #include "types.glsl" |
5 | 8 |
|
6 | | -#if defined(SET_ROWS) && QUANT_K == 1 |
| 9 | +#if defined(SET_ROWS) && defined(DATA_A_TURBO3_0) |
| 10 | +layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in; |
| 11 | +const uint BLOCK_SIZE = 128; |
| 12 | +#elif defined(SET_ROWS) && QUANT_K == 1 |
7 | 13 | layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; |
8 | 14 | const uint BLOCK_SIZE = 512; |
9 | 15 | #else |
@@ -185,62 +191,67 @@ void quantize(uint dst_idx, uint src_idx) |
185 | 191 | #endif |
186 | 192 |
|
187 | 193 | #if defined(DATA_A_TURBO3_0) |
188 | | -void quantize(uint dst_idx, uint src_idx) |
189 | | -{ |
190 | | - const float centroids[8] = float[8]( |
191 | | - -0.190685, -0.117832, -0.065717, -0.021460, |
192 | | - 0.021460, 0.065717, 0.117832, 0.190685 |
193 | | - ); |
194 | | - const float midpoints[7] = float[7]( |
195 | | - -0.154259, -0.091775, -0.043589, 0.0, 0.043589, 0.091775, 0.154259 |
196 | | - ); |
197 | | - |
198 | | - // Compute L2 norm |
199 | | - float norm_sq = 0.0; |
200 | | - [[unroll]] for (int j = 0; j < 32; ++j) { |
201 | | - float v = data_s[src_idx + j]; |
202 | | - norm_sq += v * v; |
203 | | - } |
204 | | - float norm = sqrt(norm_sq); |
205 | | - float inv_norm = (norm > 1e-10) ? (1.0 / norm) : 0.0; |
206 | | - |
207 | | - // Clear output |
208 | | - [[unroll]] for (int j = 0; j < 8; ++j) data_q[dst_idx].qs[j] = uint8_t(0); |
209 | | - [[unroll]] for (int j = 0; j < 4; ++j) data_q[dst_idx].signs[j] = uint8_t(0); |
210 | | - |
211 | | - // Accumulate centroid reconstruction norm for correction |
212 | | - float recon_norm_sq = 0.0; |
213 | | - |
214 | | - // Quantize each element |
215 | | - [[unroll]] for (int j = 0; j < 32; ++j) { |
216 | | - float val = data_s[src_idx + j] * inv_norm; |
217 | | - |
218 | | - // Find nearest centroid via midpoint comparison |
219 | | - uint idx = 0; |
220 | | - if (val < midpoints[0]) idx = 0; |
221 | | - else if (val < midpoints[1]) idx = 1; |
222 | | - else if (val < midpoints[2]) idx = 2; |
223 | | - else if (val < midpoints[3]) idx = 3; |
224 | | - else if (val < midpoints[4]) idx = 4; |
225 | | - else if (val < midpoints[5]) idx = 5; |
226 | | - else if (val < midpoints[6]) idx = 6; |
227 | | - else idx = 7; |
228 | | - |
229 | | - recon_norm_sq += centroids[idx] * centroids[idx]; |
230 | | - |
231 | | - // Pack: low 2 bits to qs, high 1 bit to signs |
232 | | - uint low2 = idx & 0x3; |
233 | | - uint hi1 = (idx >> 2) & 0x1; |
234 | | - data_q[dst_idx].qs[j / 4] |= uint8_t(low2 << ((j % 4) * 2)); |
235 | | - data_q[dst_idx].signs[j / 8] |= uint8_t(hi1 << (j % 8)); |
236 | | - } |
| 194 | +const float TS1[128] = float[128]( |
| 195 | + -1, 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, 1, 1, 1, |
| 196 | + 1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, -1, |
| 197 | + -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, |
| 198 | + 1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, 1, 1, 1, -1, 1, |
| 199 | + -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, 1, |
| 200 | + 1, -1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, 1, -1, |
| 201 | + -1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, 1, -1, 1, -1, 1, |
| 202 | + 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, 1, 1, -1, 1 |
| 203 | +); |
| 204 | + |
| 205 | +const float TS2[128] = float[128]( |
| 206 | + 1, 1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 1, -1, -1, -1, |
| 207 | + 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, 1, 1, |
| 208 | + 1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, |
| 209 | + 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, |
| 210 | + 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, 1, |
| 211 | + -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, |
| 212 | + 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, |
| 213 | + -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1 |
| 214 | +); |
| 215 | + |
| 216 | +const float TINV = 0.08838834764831845; // 1 / sqrt(128) |
| 217 | + |
| 218 | +const float TC[8] = float[8]( |
| 219 | + -0.190685, -0.117832, -0.065717, -0.021460, |
| 220 | + 0.021460, 0.065717, 0.117832, 0.190685 |
| 221 | +); |
| 222 | + |
| 223 | +const float TM[7] = float[7]( |
| 224 | + -0.154259, -0.091775, -0.043589, |
| 225 | + 0.0, |
| 226 | + 0.043589, 0.091775, 0.154259 |
| 227 | +); |
237 | 228 |
|
238 | | - // Norm correction: scale so reconstruction matches original norm |
239 | | - float recon_norm = sqrt(recon_norm_sq); |
240 | | - float corrected_norm = (recon_norm > 1e-10) ? (norm / recon_norm) : norm; |
241 | | - data_q[dst_idx].norm = float16_t(corrected_norm); |
| 229 | +#if defined(SET_ROWS) |
| 230 | + |
| 231 | +shared float wht[128]; |
| 232 | +shared float sg_acc[16]; |
| 233 | +shared float gnrm; |
| 234 | + |
| 235 | +void quantize_block(uint b, uint o) { |
| 236 | + [[unroll]] for (int j = 0; j < 32; ++j) data_q[b].qs[j] = uint8_t(0); |
| 237 | + [[unroll]] for (int j = 0; j < 16; ++j) data_q[b].signs[j] = uint8_t(0); |
| 238 | + float rs = 0.0; |
| 239 | + [[unroll]] for (int j = 0; j < 128; ++j) { |
| 240 | + float v = wht[o + j]; |
| 241 | + uint i = v < TM[0] ? 0 : v < TM[1] ? 1 : v < TM[2] ? 2 : v < TM[3] ? 3 : |
| 242 | + v < TM[4] ? 4 : v < TM[5] ? 5 : v < TM[6] ? 6 : 7; |
| 243 | + rs += TC[i] * TC[i]; |
| 244 | + uint low2 = i & 0x3; |
| 245 | + uint hi1 = (i >> 2) & 0x1; |
| 246 | + data_q[b].qs[j / 4] |= uint8_t(low2 << ((j % 4) * 2)); |
| 247 | + data_q[b].signs[j / 8] |= uint8_t(hi1 << (j % 8)); |
| 248 | + } |
| 249 | + float rn = sqrt(rs); |
| 250 | + data_q[b].norm = float16_t((rn > 1e-10) ? (gnrm / rn) : gnrm); |
242 | 251 | } |
243 | | -#endif |
| 252 | + |
| 253 | +#endif // defined(SET_ROWS) |
| 254 | +#endif // defined(DATA_A_TURBO3_0) |
244 | 255 |
|
245 | 256 | #if defined(DATA_A_IQ4_NL) |
246 | 257 | uint best_index(float x) { |
@@ -304,7 +315,97 @@ void quantize(uint dst_idx, uint src_idx) |
304 | 315 | } |
305 | 316 | #endif |
306 | 317 |
|
307 | | -#if defined(SET_ROWS) |
| 318 | +#if defined(SET_ROWS) && defined(DATA_A_TURBO3_0) |
| 319 | +void main() { |
| 320 | + const uint t = gl_LocalInvocationID.x; |
| 321 | + const uint g = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; |
| 322 | + const uint gpr = p.ne00 / 128; |
| 323 | + |
| 324 | + if (gpr == 0) return; |
| 325 | + if (g >= p.ne / 128) return; |
| 326 | + |
| 327 | + uint tmp = g; |
| 328 | + const uint ig = tmp % gpr; tmp /= gpr; |
| 329 | + const uint i01 = tmp % p.ne01; tmp /= p.ne01; |
| 330 | + const uint i02 = tmp % p.ne12; |
| 331 | + const uint i03 = tmp / p.ne12; |
| 332 | + |
| 333 | + const uint sb = src0_idx(ig * 128, i01, i02, i03) + get_aoffset(); |
| 334 | + const uint i1 = data_i[src1_idx(i01, fastmod(i02, p.ne11), fastmod(i03, p.ne12), 0) + get_boffset()] DATA_I_SWIZZLE; |
| 335 | + const uint db = dst_idx(ig, i1, i02, i03) + get_doffset(); |
| 336 | + |
| 337 | + // Step 1: load into shared memory |
| 338 | + wht[t] = data_s[sb + t]; |
| 339 | + barrier(); |
| 340 | + |
| 341 | + // Step 2: L2 norm via subgroup reduction |
| 342 | + float v2 = wht[t] * wht[t]; |
| 343 | + v2 = subgroupAdd(v2); |
| 344 | + if (gl_SubgroupInvocationID == 0) sg_acc[gl_SubgroupID] = v2; |
| 345 | + barrier(); |
| 346 | + if (t == 0) { |
| 347 | + float total = 0.0; |
| 348 | + for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc[w]; |
| 349 | + gnrm = sqrt(total); |
| 350 | + } |
| 351 | + barrier(); |
| 352 | + |
| 353 | + // Step 3: normalize, then apply forward WHT: signs1 -> butterfly -> signs2 |
| 354 | + wht[t] *= (gnrm > 1e-10) ? (1.0 / gnrm) : 0.0; |
| 355 | + barrier(); |
| 356 | + |
| 357 | + wht[t] *= TS1[t]; |
| 358 | + barrier(); |
| 359 | + |
| 360 | + [[unroll]] for (uint h = 1; h < 128; h *= 2) { |
| 361 | + if ((t % (2 * h)) < h) { |
| 362 | + float a = wht[t]; |
| 363 | + float b = wht[t + h]; |
| 364 | + wht[t] = a + b; |
| 365 | + wht[t + h] = a - b; |
| 366 | + } |
| 367 | + barrier(); |
| 368 | + } |
| 369 | + |
| 370 | + // Step 5: apply signs2 + scaling |
| 371 | + float rv = wht[t] * TINV * TS2[t]; |
| 372 | + |
| 373 | + // Step 6: quantize -- all 128 threads participate |
| 374 | + uint idx = rv < TM[0] ? 0u : rv < TM[1] ? 1u : rv < TM[2] ? 2u : rv < TM[3] ? 3u : |
| 375 | + rv < TM[4] ? 4u : rv < TM[5] ? 5u : rv < TM[6] ? 6u : 7u; |
| 376 | + |
| 377 | + // Pack qs: 4 elements per byte via subgroup shuffle |
| 378 | + uint sg_lane = gl_SubgroupInvocationID; |
| 379 | + uint my_low2 = idx & 0x3u; |
| 380 | + uint qs_byte = 0u; |
| 381 | + [[unroll]] for (uint k = 0; k < 4; k++) { |
| 382 | + uint contrib = subgroupShuffle(my_low2, (sg_lane & ~3u) + k); |
| 383 | + qs_byte |= contrib << (k * 2u); |
| 384 | + } |
| 385 | + if (sg_lane % 4u == 0u) { |
| 386 | + data_q[db].qs[t / 4u] = uint8_t(qs_byte); |
| 387 | + } |
| 388 | + |
| 389 | + // Pack signs: 8 elements per byte via subgroup ballot |
| 390 | + uvec4 ballot = subgroupBallot(((idx >> 2u) & 1u) != 0u); |
| 391 | + if (sg_lane % 8u == 0u) { |
| 392 | + uint local_byte = sg_lane / 8u; |
| 393 | + data_q[db].signs[t / 8u] = uint8_t((ballot.x >> (local_byte * 8u)) & 0xFFu); |
| 394 | + } |
| 395 | + |
| 396 | + // Step 7: reconstruction norm via subgroup reduction |
| 397 | + float rc = TC[idx] * TC[idx]; |
| 398 | + rc = subgroupAdd(rc); |
| 399 | + if (sg_lane == 0u) sg_acc[gl_SubgroupID] = rc; |
| 400 | + barrier(); |
| 401 | + if (t == 0u) { |
| 402 | + float total = 0.0; |
| 403 | + for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc[w]; |
| 404 | + float rn = sqrt(total); |
| 405 | + data_q[db].norm = float16_t((rn > 1e-10) ? (gnrm / rn) : gnrm); |
| 406 | + } |
| 407 | +} |
| 408 | +#elif defined(SET_ROWS) |
308 | 409 |
|
309 | 410 | void main() { |
310 | 411 | #ifdef NEEDS_INIT_IQ_SHMEM |
|
0 commit comments