2323#include "types.glsl"
2424#include "flash_attn_base.glsl"
25+ #include "flash_attn_dequant.glsl"
2526
2627const uint32_t HSK_per_thread = HSK / D_split;
2728const uint32_t HSV_per_thread = HSV / D_split;
@@ -130,18 +131,20 @@ void main() {
130131
131132 Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals));
132133
133- #if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
134- if (buf_iqs == 0) {
135- Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
136- }
137- #else // Q4_0, Q4_1, Q5_0, Q5_1
138- const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
139- const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
134+ // Q8_0 K only needs (qd, _); the asymmetric Q4_*/Q5_* family also stores
135+ // the row-sum scaled by qd, used in k_dot_correction.
136+ if (FaTypeK == FA_TYPE_Q8_0) {
137+ if (buf_iqs == 0) {
138+ Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
139+ }
140+ } else {
141+ const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
142+ const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
140143
141- if (buf_iqs == 0) {
142- Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
144+ if (buf_iqs == 0) {
145+ Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
146+ }
143147 }
144- #endif
145148#endif
146149 }
147150 barrier();
@@ -179,13 +182,9 @@ void main() {
179182 // mo_offset will point to the tile starting at row i*Br and col 0
180183 uint32_t mo_offset = mo_stride * i;
181184
182- #if BLOCK_SIZE > 1
183- uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
184- uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
185- #else
186- uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
187- uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
188- #endif
185+ // FaBlockBytesK/V == 2 for f16, 16 for f32, ggml block byte size for quants.
186+ uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK;
187+ uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV;
189188 uint32_t m_offset = gqa_iq1*KV;
190189 if (p.nem2 != 1 || p.nem3 != 1) {
191190 m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
@@ -259,21 +258,21 @@ void main() {
259258 if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) {
260259 FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
261260 if (!KV_bounds_check || j * Bc + c < KV) {
262- #if BLOCK_SIZE > 1
263- uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
264- uint ib = coord / BLOCK_SIZE ;
265- uint iqs = (coord % BLOCK_SIZE );
266- K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
267- # else
268- K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
269- #endif
261+ if (USE_DECODE_K) {
262+ uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d;
263+ uint ib = coord / BLOCK_SIZE_K ;
264+ uint iqs = (coord % BLOCK_SIZE_K );
265+ K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
266+ } else {
267+ K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
268+ }
270269 }
271270
272271 kvsh[c * kvsh_stride + d] = K_Tf;
273272 }
274273 }
275274#else // MMQ
276- const uint ints_per_block = 8 / QUANT_R_MMQ ;
275+ const uint ints_per_block = 8u / fa_quant_r_mmq(FaTypeK) ;
277276 const uint quant_iters = Bc * HSK / 32 * ints_per_block;
278277 [[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) {
279278 const uint32_t iqs = (idx + tid) % ints_per_block;
@@ -312,15 +311,13 @@ void main() {
312311 FLOAT_TYPEV4 K_Tf;
313312 if (SHMEM_STAGING != 0) {
314313 K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
315- } else {
316- #if BLOCK_SIZE > 1
317- uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
318- uint ib = coord / BLOCK_SIZE;
319- uint iqs = (coord % BLOCK_SIZE);
314+ } else if (USE_DECODE_K) {
315+ uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid);
316+ uint ib = coord / BLOCK_SIZE_K;
317+ uint iqs = (coord % BLOCK_SIZE_K);
320318 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
321- # else
319+ } else {
322320 K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
323- #endif
324321 }
325322 [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
326323 Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf));
@@ -337,15 +334,13 @@ void main() {
337334 FLOAT_TYPEV4 K_Tf;
338335 if (SHMEM_STAGING != 0) {
339336 K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
340- } else {
341- #if BLOCK_SIZE > 1
342- uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
343- uint ib = coord / BLOCK_SIZE;
344- uint iqs = (coord % BLOCK_SIZE);
337+ } else if (USE_DECODE_K) {
338+ uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid);
339+ uint ib = coord / BLOCK_SIZE_K;
340+ uint iqs = (coord % BLOCK_SIZE_K);
345341 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
346- # else
342+ } else {
347343 K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
348- #endif
349344 }
350345 [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
351346 Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf));
@@ -368,72 +363,47 @@ void main() {
368363 int32_t k_quants[d_per_step];
369364 ACC_TYPEV2 k_dm;
370365
366+ // Q4_*/Q5_* take the block-8 fast path when one step covers a full
367+ // block; Q8_0 always goes through the per-int get_k_qs* helpers
368+ // (its qs is byte-packed, not nibble-packed).
369+ const bool block8_fast = (d_per_step == 8) && (FaTypeK != FA_TYPE_Q8_0);
370+
371371 if (SHMEM_STAGING != 0) {
372372 const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8;
373373 const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx;
374- #if QUANT_AUXF == 1
375- k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0);
376- #else
377374 k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm);
378- #endif
379375
380- #if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
381- if (d_per_step == 8) {
376+ if (block8_fast) {
377+ const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1);
382378 [[unroll]] for (uint32_t d = 0; d < 4; d++) {
383379 uint vui = kblocksh[buf_ib].qs[d];
384380 k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
385381 k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
386- #if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
387- uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
388- uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
389- k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
390- k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
391- #endif
382+ if (has_qh) {
383+ uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
384+ uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
385+ k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
386+ k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
387+ }
392388 }
393- } else
394- #endif
395- {
389+ } else {
396390 [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
397391 k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d);
398392 }
399393 }
400394 } else {
401- const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
402- const uint ib = coord / BLOCK_SIZE ;
403- const uint iqs = (coord % BLOCK_SIZE );
395+ const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
396+ const uint ib = coord / BLOCK_SIZE_K ;
397+ const uint iqs = (coord % BLOCK_SIZE_K );
404398
405- #if QUANT_AUXF == 1
406- k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0);
407- #else
408- k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset));
409- #endif
410- #if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
411- if (d_per_step == 8) {
412- #if defined(DATA_A_Q5_0)
413- uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0],
414- k_packed.k_data_packed16[k_offset + ib].qh[1]));
415- #elif defined(DATA_A_Q5_1)
416- uint qh = k_packed.k_data_packed16[k_offset + ib].qh;
417- #endif
418- [[unroll]] for (uint32_t d = 0; d < 4; d++) {
419- #if defined(A_TYPE_PACKED32)
420- uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d];
421- #else
422- uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0],
423- k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]));
424- #endif
425- k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
426- k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
427- #if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
428- uint qh_lo = (qh >> (d * 4)) & 0xF;
429- uint qh_hi = (qh >> (d * 4 + 16)) & 0xF;
430- k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
431- k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
432- #endif
399+ k_dm = ACC_TYPEV2(get_k_scale(ib, k_offset));
400+
401+ if (block8_fast) {
402+ fa_k_qs_block8 blk = get_k_qs_block8(ib, k_offset);
403+ [[unroll]] for (uint32_t d = 0; d < 8; d++) {
404+ k_quants[d] = blk.qs[d];
433405 }
434- } else
435- #endif
436- {
406+ } else {
437407 [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
438408 k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
439409 }
@@ -518,14 +488,14 @@ void main() {
518488 if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) {
519489 FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
520490 if (!KV_bounds_check || j * Bc + c < KV) {
521- #if BLOCK_SIZE > 1
522- uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
523- uint ib = coord / BLOCK_SIZE ;
524- uint iqs = (coord % BLOCK_SIZE );
525- V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
526- # else
527- V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
528- #endif
491+ if (USE_DECODE_V) {
492+ uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d;
493+ uint ib = coord / BLOCK_SIZE_V ;
494+ uint iqs = (coord % BLOCK_SIZE_V );
495+ V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
496+ } else {
497+ V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
498+ }
529499 }
530500
531501 kvsh[c * kvsh_stride + d] = V_Tf;
@@ -549,15 +519,13 @@ void main() {
549519 FLOAT_TYPEV4 Vf;
550520 if (SHMEM_STAGING != 0) {
551521 Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
552- } else {
553- #if BLOCK_SIZE > 1
554- uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
555- uint ib = coord / BLOCK_SIZE;
556- uint iqs = (coord % BLOCK_SIZE);
522+ } else if (USE_DECODE_V) {
523+ uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE_V + 4 * (d * D_split + d_tid);
524+ uint ib = coord / BLOCK_SIZE_V;
525+ uint iqs = (coord % BLOCK_SIZE_V);
557526 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
558- # else
527+ } else {
559528 Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
560- #endif
561529 }
562530 [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
563531 Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf);