File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -333,6 +333,10 @@ static __global__ void flash_attn_ext_vec(
333333 }
334334
335335 // Sparse V: skip V dequant if all attention weights for this position are negligible
336+ // Disabled — per-lane branching causes warp divergence that costs more than the
337+ // skipped dequants save (-0.3% to -2.8% on RTX 3090/4090).
338+ // TODO: revisit with warp-level ballot skip.
339+ #if 0
336340 {
337341 bool dominated = true;
338342#pragma unroll
@@ -341,6 +345,7 @@ static __global__ void flash_attn_ext_vec(
341345 }
342346 if (dominated) { continue; }
343347 }
348+ #endif
344349
345350#pragma unroll
346351 for (int i_VKQ_0 = 0 ; i_VKQ_0 < D/2 ; i_VKQ_0 += nthreads_V*V_rows_per_thread/2 ) {
@@ -373,6 +378,10 @@ static __global__ void flash_attn_ext_vec(
373378 }
374379
375380 // Sparse V: skip V dequant if all attention weights for this position are negligible
381+ // Disabled — per-lane branching causes warp divergence that costs more than the
382+ // skipped dequants save (-0.3% to -2.8% on RTX 3090/4090).
383+ // TODO: revisit with warp-level ballot skip.
384+ #if 0
376385 {
377386 bool dominated = true;
378387#pragma unroll
@@ -381,6 +390,7 @@ static __global__ void flash_attn_ext_vec(
381390 }
382391 if (dominated) { continue; }
383392 }
393+ #endif
384394
385395#pragma unroll
386396 for (int i_VKQ_0 = 0 ; i_VKQ_0 < D/2 ; i_VKQ_0 += nthreads_V*V_rows_per_thread/2 ) {
You can’t perform that action at this time.
0 commit comments