Skip to content

Commit 88e22ab

Browse files
jeffbolznv authored and Jcfunk committed
vulkan: Support asymmetric FA in scalar/mmq/coopmat1 paths (ggml-org#22589)
1 parent e38c3b9 commit 88e22ab

7 files changed

Lines changed: 482 additions & 359 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -27,12 +27,12 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
2727
#if __has_include(<spirv/unified1/spirv.hpp>)
2828
# include <spirv/unified1/spirv.hpp>
2929
#elif __has_include(<spirv-headers/spirv.hpp>)
30-
#include <spirv-headers/spirv.hpp>
30+
# include <spirv-headers/spirv.hpp>
3131
#elif __has_include(<spirv.hpp>)
3232
# include <spirv.hpp>
3333
#else
3434
// Fallback to let the compiler throw a standard "file not found" error
35-
#include <spirv/unified1/spirv.hpp>
35+
# include <spirv/unified1/spirv.hpp>
3636
#endif
3737

3838
#include <algorithm>
@@ -4517,12 +4517,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
45174517
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
45184518
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_TURBO3_0], "cpy_f32_turbo3_0", cpy_f32_turbo3_0_rte_len, cpy_f32_turbo3_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
45194519
} else {
4520-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4521-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4522-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4523-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4524-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4525-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4520+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4521+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4522+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4523+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4524+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4525+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
45264526
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
45274527
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_TURBO3_0], "cpy_f32_turbo3_0", cpy_f32_turbo3_0_len, cpy_f32_turbo3_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
45284528
}
@@ -15632,21 +15632,21 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1563215632
auto fa_kv_ok = [coopmat2](ggml_type t) {
1563315633
switch (t) {
1563415634
case GGML_TYPE_F32:
15635-
case GGML_TYPE_F16:
15636-
case GGML_TYPE_Q8_0:
15635+
case GGML_TYPE_F16:
15636+
case GGML_TYPE_Q8_0:
1563715637
case GGML_TYPE_TURBO3_0:
1563815638
// supported in scalar and coopmat2 paths
1563915639
break;
1564015640
case GGML_TYPE_Q5_1:
1564115641
case GGML_TYPE_Q5_0:
15642-
case GGML_TYPE_Q4_1:
15642+
case GGML_TYPE_Q4_1:
1564315643
case GGML_TYPE_Q4_0:
1564415644
return true;
1564515645
case GGML_TYPE_Q1_0:
1564615646
return coopmat2;
15647-
default:
15648-
return false;
15649-
}
15647+
default:
15648+
return false;
15649+
}
1565015650
};
1565115651
if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) {
1565215652
return false;

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 72 additions & 104 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222

2323
#include "types.glsl"
2424
#include "flash_attn_base.glsl"
25+
#include "flash_attn_dequant.glsl"
2526

2627
const uint32_t HSK_per_thread = HSK / D_split;
2728
const uint32_t HSV_per_thread = HSV / D_split;
@@ -130,18 +131,20 @@ void main() {
130131

131132
Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals));
132133

133-
#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
134-
if (buf_iqs == 0) {
135-
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
136-
}
137-
#else // Q4_0, Q4_1, Q5_0, Q5_1
138-
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
139-
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
134+
// Q8_0 K only needs (qd, _); the asymmetric Q4_*/Q5_* family also stores
135+
// the row-sum scaled by qd, used in k_dot_correction.
136+
if (FaTypeK == FA_TYPE_Q8_0) {
137+
if (buf_iqs == 0) {
138+
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
139+
}
140+
} else {
141+
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
142+
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
140143

141-
if (buf_iqs == 0) {
142-
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
144+
if (buf_iqs == 0) {
145+
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
146+
}
143147
}
144-
#endif
145148
#endif
146149
}
147150
barrier();
@@ -179,13 +182,9 @@ void main() {
179182
// mo_offset will point to the tile starting at row i*Br and col 0
180183
uint32_t mo_offset = mo_stride * i;
181184

182-
#if BLOCK_SIZE > 1
183-
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
184-
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
185-
#else
186-
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
187-
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
188-
#endif
185+
// FaBlockBytesK/V == 2 for f16, 16 for f32, ggml block byte size for quants.
186+
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK;
187+
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV;
189188
uint32_t m_offset = gqa_iq1*KV;
190189
if (p.nem2 != 1 || p.nem3 != 1) {
191190
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
@@ -259,21 +258,21 @@ void main() {
259258
if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) {
260259
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
261260
if (!KV_bounds_check || j * Bc + c < KV) {
262-
#if BLOCK_SIZE > 1
263-
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
264-
uint ib = coord / BLOCK_SIZE;
265-
uint iqs = (coord % BLOCK_SIZE);
266-
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
267-
#else
268-
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
269-
#endif
261+
if (USE_DECODE_K) {
262+
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d;
263+
uint ib = coord / BLOCK_SIZE_K;
264+
uint iqs = (coord % BLOCK_SIZE_K);
265+
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
266+
} else {
267+
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
268+
}
270269
}
271270

272271
kvsh[c * kvsh_stride + d] = K_Tf;
273272
}
274273
}
275274
#else // MMQ
276-
const uint ints_per_block = 8 / QUANT_R_MMQ;
275+
const uint ints_per_block = 8u / fa_quant_r_mmq(FaTypeK);
277276
const uint quant_iters = Bc * HSK / 32 * ints_per_block;
278277
[[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) {
279278
const uint32_t iqs = (idx + tid) % ints_per_block;
@@ -312,15 +311,13 @@ void main() {
312311
FLOAT_TYPEV4 K_Tf;
313312
if (SHMEM_STAGING != 0) {
314313
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
315-
} else {
316-
#if BLOCK_SIZE > 1
317-
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
318-
uint ib = coord / BLOCK_SIZE;
319-
uint iqs = (coord % BLOCK_SIZE);
314+
} else if (USE_DECODE_K) {
315+
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid);
316+
uint ib = coord / BLOCK_SIZE_K;
317+
uint iqs = (coord % BLOCK_SIZE_K);
320318
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
321-
#else
319+
} else {
322320
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
323-
#endif
324321
}
325322
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
326323
Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf));
@@ -337,15 +334,13 @@ void main() {
337334
FLOAT_TYPEV4 K_Tf;
338335
if (SHMEM_STAGING != 0) {
339336
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
340-
} else {
341-
#if BLOCK_SIZE > 1
342-
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
343-
uint ib = coord / BLOCK_SIZE;
344-
uint iqs = (coord % BLOCK_SIZE);
337+
} else if (USE_DECODE_K) {
338+
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid);
339+
uint ib = coord / BLOCK_SIZE_K;
340+
uint iqs = (coord % BLOCK_SIZE_K);
345341
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
346-
#else
342+
} else {
347343
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
348-
#endif
349344
}
350345
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
351346
Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf));
@@ -368,72 +363,47 @@ void main() {
368363
int32_t k_quants[d_per_step];
369364
ACC_TYPEV2 k_dm;
370365

366+
// Q4_*/Q5_* take the block-8 fast path when one step covers a full
367+
// block; Q8_0 always goes through the per-int get_k_qs* helpers
368+
// (its qs is byte-packed, not nibble-packed).
369+
const bool block8_fast = (d_per_step == 8) && (FaTypeK != FA_TYPE_Q8_0);
370+
371371
if (SHMEM_STAGING != 0) {
372372
const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8;
373373
const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx;
374-
#if QUANT_AUXF == 1
375-
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0);
376-
#else
377374
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm);
378-
#endif
379375

380-
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
381-
if (d_per_step == 8) {
376+
if (block8_fast) {
377+
const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1);
382378
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
383379
uint vui = kblocksh[buf_ib].qs[d];
384380
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
385381
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
386-
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
387-
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
388-
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
389-
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
390-
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
391-
#endif
382+
if (has_qh) {
383+
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
384+
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
385+
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
386+
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
387+
}
392388
}
393-
} else
394-
#endif
395-
{
389+
} else {
396390
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
397391
k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d);
398392
}
399393
}
400394
} else {
401-
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
402-
const uint ib = coord / BLOCK_SIZE;
403-
const uint iqs = (coord % BLOCK_SIZE);
395+
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
396+
const uint ib = coord / BLOCK_SIZE_K;
397+
const uint iqs = (coord % BLOCK_SIZE_K);
404398

405-
#if QUANT_AUXF == 1
406-
k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0);
407-
#else
408-
k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset));
409-
#endif
410-
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
411-
if (d_per_step == 8) {
412-
#if defined(DATA_A_Q5_0)
413-
uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0],
414-
k_packed.k_data_packed16[k_offset + ib].qh[1]));
415-
#elif defined(DATA_A_Q5_1)
416-
uint qh = k_packed.k_data_packed16[k_offset + ib].qh;
417-
#endif
418-
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
419-
#if defined(A_TYPE_PACKED32)
420-
uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d];
421-
#else
422-
uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0],
423-
k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]));
424-
#endif
425-
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
426-
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
427-
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
428-
uint qh_lo = (qh >> (d * 4)) & 0xF;
429-
uint qh_hi = (qh >> (d * 4 + 16)) & 0xF;
430-
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
431-
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
432-
#endif
399+
k_dm = ACC_TYPEV2(get_k_scale(ib, k_offset));
400+
401+
if (block8_fast) {
402+
fa_k_qs_block8 blk = get_k_qs_block8(ib, k_offset);
403+
[[unroll]] for (uint32_t d = 0; d < 8; d++) {
404+
k_quants[d] = blk.qs[d];
433405
}
434-
} else
435-
#endif
436-
{
406+
} else {
437407
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
438408
k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
439409
}
@@ -518,14 +488,14 @@ void main() {
518488
if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) {
519489
FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
520490
if (!KV_bounds_check || j * Bc + c < KV) {
521-
#if BLOCK_SIZE > 1
522-
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
523-
uint ib = coord / BLOCK_SIZE;
524-
uint iqs = (coord % BLOCK_SIZE);
525-
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
526-
#else
527-
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
528-
#endif
491+
if (USE_DECODE_V) {
492+
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d;
493+
uint ib = coord / BLOCK_SIZE_V;
494+
uint iqs = (coord % BLOCK_SIZE_V);
495+
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
496+
} else {
497+
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
498+
}
529499
}
530500

531501
kvsh[c * kvsh_stride + d] = V_Tf;
@@ -549,15 +519,13 @@ void main() {
549519
FLOAT_TYPEV4 Vf;
550520
if (SHMEM_STAGING != 0) {
551521
Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
552-
} else {
553-
#if BLOCK_SIZE > 1
554-
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
555-
uint ib = coord / BLOCK_SIZE;
556-
uint iqs = (coord % BLOCK_SIZE);
522+
} else if (USE_DECODE_V) {
523+
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE_V + 4 * (d * D_split + d_tid);
524+
uint ib = coord / BLOCK_SIZE_V;
525+
uint iqs = (coord % BLOCK_SIZE_V);
557526
Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
558-
#else
527+
} else {
559528
Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
560-
#endif
561529
}
562530
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
563531
Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf);

0 commit comments

Comments (0)