Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions ggml/src/ggml-cuda/fattn-vec.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,10 @@ static __global__ void flash_attn_ext_vec(
}

// Sparse V: skip V dequant if all attention weights for this position are negligible
// Disabled — per-lane branching causes warp divergence that costs more than the
// skipped dequants save (-0.3% to -2.8% on RTX 3090/4090).
// TODO: revisit with warp-level ballot skip.
#if 0
{
bool dominated = true;
#pragma unroll
Expand All @@ -341,6 +345,7 @@ static __global__ void flash_attn_ext_vec(
}
if (dominated) { continue; }
}
#endif

#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
Expand Down Expand Up @@ -373,6 +378,10 @@ static __global__ void flash_attn_ext_vec(
}

// Sparse V: skip V dequant if all attention weights for this position are negligible
// Disabled — per-lane branching causes warp divergence that costs more than the
// skipped dequants save (-0.3% to -2.8% on RTX 3090/4090).
// TODO: revisit with warp-level ballot skip.
#if 0
{
bool dominated = true;
#pragma unroll
Expand All @@ -381,6 +390,7 @@ static __global__ void flash_attn_ext_vec(
}
if (dominated) { continue; }
}
#endif

#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
Expand Down
Loading