Skip to content

Commit 11a241d

Browse files
authored
Merge pull request #105 from TheTom/fix/disable-sparse-v-cuda
cuda: disable sparse V skip (warp divergence regression)
2 parents 67559e5 + f2dc968 commit 11a241d

1 file changed

Lines changed: 10 additions & 0 deletions

File tree

ggml/src/ggml-cuda/fattn-vec.cuh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,10 @@ static __global__ void flash_attn_ext_vec(
333333
}
334334

335335
// Sparse V: skip V dequant if all attention weights for this position are negligible
336+
// Disabled — per-lane branching causes warp divergence that costs more than the
337+
// skipped dequants save (-0.3% to -2.8% on RTX 3090/4090).
338+
// TODO: revisit with warp-level ballot skip.
339+
#if 0
336340
{
337341
bool dominated = true;
338342
#pragma unroll
@@ -341,6 +345,7 @@ static __global__ void flash_attn_ext_vec(
341345
}
342346
if (dominated) { continue; }
343347
}
348+
#endif
344349

345350
#pragma unroll
346351
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
@@ -373,6 +378,10 @@ static __global__ void flash_attn_ext_vec(
373378
}
374379

375380
// Sparse V: skip V dequant if all attention weights for this position are negligible
381+
// Disabled — per-lane branching causes warp divergence that costs more than the
382+
// skipped dequants save (-0.3% to -2.8% on RTX 3090/4090).
383+
// TODO: revisit with warp-level ballot skip.
384+
#if 0
376385
{
377386
bool dominated = true;
378387
#pragma unroll
@@ -381,6 +390,7 @@ static __global__ void flash_attn_ext_vec(
381390
}
382391
if (dominated) { continue; }
383392
}
393+
#endif
384394

385395
#pragma unroll
386396
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {

0 commit comments

Comments (0)