47 changes: 47 additions & 0 deletions ggml/src/ggml-cuda/fattn-tile.cuh
@@ -2,6 +2,14 @@
#include "fattn-common.cuh"
#include "fattn-wmma-f16.cuh"

// Tile-level sparse V skip (TurboQuant). Off by default; opt in by defining
// GGML_CUDA_TURBO_SPARSE_V_TILE at build time. Threshold defaults to 0.001
// (matches vllm-project/vllm#41422 — bit-identical PPL + NIAH all-pass on
// Qwen3-8B at 32K). Override with -DGGML_CUDA_TURBO_SPARSE_V_THRESHOLD=<val>.
#if defined(GGML_CUDA_TURBO_SPARSE_V_TILE) && !defined(GGML_CUDA_TURBO_SPARSE_V_THRESHOLD)
#define GGML_CUDA_TURBO_SPARSE_V_THRESHOLD 0.001f
#endif

// nbatch_fa == number of KQ rows to process per iteration
// nbatch_K == number of K columns to load in parallel for KQ calculation

@@ -642,6 +650,19 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
}

// Tile-level sparse V skip (TurboQuant): cheaply track the max softmax
// probability seen across this FA tile so we can skip the whole V matmul
// block when no position contributes meaningfully. Opt-in via build flag
// -DGGML_CUDA_TURBO_SPARSE_V_TILE -DGGML_CUDA_TURBO_SPARSE_V_THRESHOLD=0.001f.
// The decision is block-uniform (all threads take the same branch via
// shared-memory reduction), avoiding the per-lane warp divergence that
// motivated the earlier April 24 revert (commit f2dc968).
// Mirror of vllm-project/vllm#41422 which validated +7.13% decode @ 32K
// on AMD MI300X with PPL bit-identical and NIAH all-pass.
#ifdef GGML_CUDA_TURBO_SPARSE_V_TILE
float thread_max_val = 0.0f;
#endif

// Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
#pragma unroll
for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
@@ -665,6 +686,9 @@
expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
KQ_sum_add += val;
tmp[i0/(np*warp_size)][jc1] = val;
#ifdef GGML_CUDA_TURBO_SPARSE_V_TILE
thread_max_val = fmaxf(thread_max_val, val);
#endif
}
KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;

@@ -693,12 +717,32 @@
}
}

#ifdef GGML_CUDA_TURBO_SPARSE_V_TILE
// Block-level reduction of thread_max_val. Per-warp shfl reduce, write to
// shared, syncthreads, fan back out — leaves block_max_val identical and
// warp-uniform across all threads. Cost: one __syncthreads + cheap fp ops.
thread_max_val = warp_reduce_max<warp_size>(thread_max_val);
__shared__ float sparse_v_warp_max[nwarps];
if (threadIdx.x == 0) sparse_v_warp_max[threadIdx.y] = thread_max_val;
__syncthreads();
float block_max_val = 0.0f;
#pragma unroll
for (int w = 0; w < nwarps; ++w) {
block_max_val = fmaxf(block_max_val, sparse_v_warp_max[w]);
}
constexpr float sparse_v_threshold = GGML_CUDA_TURBO_SPARSE_V_THRESHOLD;
const bool skip_v_matmul = block_max_val < sparse_v_threshold;
#endif

// VKQ = V @ KQ matrix multiplication:
static_assert(DV <= DKQ, "bad DV");
static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K");
constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K.
static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V");
static_assert(nbatch_V % np == 0, "bad nbatch_V");
#ifdef GGML_CUDA_TURBO_SPARSE_V_TILE
if (!skip_v_matmul) {
#endif
#pragma unroll
for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
@@ -769,6 +813,9 @@

__syncthreads();
}
#ifdef GGML_CUDA_TURBO_SPARSE_V_TILE
} // close if (!skip_v_matmul)
#endif
}

template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap> // D == head size
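Note: the tile-path skip above reduces to the standalone pattern sketched below. This is an illustrative sketch only, not code from this PR; the kernel name, launch geometry (blockDim.x == WARP_SIZE, blockDim.y == NWARPS), and the stand-in for the V matmul are assumptions made for the example.

// Minimal sketch of the block-uniform skip decision used in the tile path.
// All names here are illustrative.
#include <cuda_runtime.h>

#define WARP_SIZE 32
#define NWARPS    4

__device__ float warp_max_f32(float x) {
    // Butterfly max over the full warp; every lane ends up with the warp maximum.
    #pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xFFFFFFFF, x, offset));
    }
    return x;
}

__global__ void sparse_v_skip_sketch(const float * p, float * out, const float threshold) {
    // One softmax probability per thread; blockDim assumed to be (WARP_SIZE, NWARPS).
    const int tid = blockIdx.x*blockDim.x*blockDim.y + threadIdx.y*blockDim.x + threadIdx.x;
    float thread_max = p[tid];

    // 1) per-warp shuffle reduction, 2) one shared-memory slot per warp,
    // 3) every thread reads all slots -> identical block_max in every lane.
    thread_max = warp_max_f32(thread_max);
    __shared__ float warp_max_smem[NWARPS];
    if (threadIdx.x == 0) warp_max_smem[threadIdx.y] = thread_max;
    __syncthreads();

    float block_max = 0.0f;
    #pragma unroll
    for (int w = 0; w < NWARPS; ++w) {
        block_max = fmaxf(block_max, warp_max_smem[w]);
    }

    // Block-uniform branch: all threads agree, so skipping costs no warp divergence.
    if (block_max < threshold) {
        return; // the whole tile's V contribution is treated as negligible
    }
    out[tid] = thread_max; // stand-in for the real V matmul accumulation
}

Having every thread read all NWARPS slots (instead of broadcasting via a second reduction) keeps the decision warp-uniform at the cost of one __syncthreads plus a handful of fmaxf per thread, which matches the cost note in the tile-path comment.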
58 changes: 50 additions & 8 deletions ggml/src/ggml-cuda/fattn-vec.cuh
@@ -1,6 +1,15 @@
#include "common.cuh"
#include "fattn-common.cuh"

// Tile-uniform sparse V skip in the VEC kernel (TurboQuant). Off by default;
// opt in by defining GGML_CUDA_TURBO_SPARSE_V_VEC at build time. Threshold
// defaults to 0.001 (matches vllm-project/vllm#41422 — bit-identical PPL +
// NIAH all-pass on Qwen3-8B at 32K). Override with
// -DGGML_CUDA_TURBO_SPARSE_V_VEC_THRESHOLD=<val>.
#if defined(GGML_CUDA_TURBO_SPARSE_V_VEC) && !defined(GGML_CUDA_TURBO_SPARSE_V_VEC_THRESHOLD)
#define GGML_CUDA_TURBO_SPARSE_V_VEC_THRESHOLD 0.001f
#endif

static int ggml_cuda_fattn_vec_get_nthreads_host(const int cc) {
return 128;
GGML_UNUSED(cc);
@@ -412,12 +421,33 @@ static __global__ void flash_attn_ext_vec(
}
}

// Sparse V: skip V dequant if all attention weights for this position are negligible.
// For turbo types, the check is compiled out: at typical decode context lengths
// (< ~4K tokens) with threshold 1e-6, no positions are ever skipped, so the
// per-position branch is pure overhead (misprediction + comparison cost). This
// also dodges the warp-divergence regression on turbo paths that motivated the
// April 24 revert (commit f2dc968).
// Sparse V skip — two strategies:
//
// GGML_CUDA_TURBO_SPARSE_V_VEC (opt-in, default off):
// Warp-uniform skip via warp_reduce_max — all lanes branch on
// the same value so no warp divergence. Works on every V type
// including turbo. Threshold defaults to 0.001 (matches the
// vllm-project/vllm#41422 design that validated +7.13% decode
// at 32K on AMD MI300X with PPL bit-identical and NIAH all-
// pass).
//
// default (signalnine PR #115):
// Per-lane skip with `if constexpr (!V_is_turbo)` compile-time
// gate. Compiled out for turbo to dodge the warp-divergence
// regression that motivated the April 24 revert (commit
// f2dc968). Kept as the default while the warp-uniform variant
// is bench-validated cross-platform.
#ifdef GGML_CUDA_TURBO_SPARSE_V_VEC
{
float my_kq_max = 0.0f;
#pragma unroll
for (int j = 0; j < ncols; ++j) {
my_kq_max = fmaxf(my_kq_max, __half2float(__low2half(KQ_k[j])));
}
const float warp_max = warp_reduce_max(my_kq_max);
if (warp_max < (float) GGML_CUDA_TURBO_SPARSE_V_VEC_THRESHOLD) continue;
}
#else
if constexpr (!V_is_turbo) {
bool dominated = true;
#pragma unroll
@@ -426,6 +456,7 @@
}
if (dominated) { continue; }
}
#endif

#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
@@ -461,8 +492,18 @@
}
}

// Sparse V: skip V dequant if all attention weights for this position are negligible.
// Compiled out for turbo types — see half2 path comment above.
// Sparse V skip — see half2-path comment above. Same two strategies.
#ifdef GGML_CUDA_TURBO_SPARSE_V_VEC
{
float my_kq_max = 0.0f;
#pragma unroll
for (int j = 0; j < ncols; ++j) {
my_kq_max = fmaxf(my_kq_max, KQ_k[j]);
}
const float warp_max = warp_reduce_max(my_kq_max);
if (warp_max < (float) GGML_CUDA_TURBO_SPARSE_V_VEC_THRESHOLD) continue;
}
#else
if constexpr (!V_is_turbo) {
bool dominated = true;
#pragma unroll
@@ -471,6 +512,7 @@
}
if (dominated) { continue; }
}
#endif

// Turbo V path: precompute scaled centroids once per block to eliminate
// per-element norm multiply. centroid[idx]*norm is computed 8/4/16 times
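For the vec path, the sketch below isolates the two skip strategies described in the comment (per-lane vs. warp-uniform). It is an illustrative sketch only, not code from this PR; the function names, the assumed KQ layout (one weight slice per lane per position), and the private vkq accumulator are assumptions made for the example.

// Contrast of the two skip strategies in the vec path (illustrative names).
__device__ float warp_reduce_max_f32(float x) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xFFFFFFFF, x, offset));
    }
    return x;
}

template <int ncols>
__device__ void vec_skip_sketch(
        const float * kq,      // assumed layout: [n_positions][32 lanes][ncols] attention weights
        float       * vkq,     // this thread's private accumulator, ncols entries (registers)
        const int     n_positions,
        const float   threshold) {
    const int lane = threadIdx.x % 32;

    for (int k = 0; k < n_positions; ++k) {
        // Per-lane maximum over the weights this thread holds for position k.
        float my_max = 0.0f;
        #pragma unroll
        for (int j = 0; j < ncols; ++j) {
            my_max = fmaxf(my_max, kq[(k*32 + lane)*ncols + j]);
        }

        // Warp-uniform variant (the opt-in GGML_CUDA_TURBO_SPARSE_V_VEC path):
        // reduce first, then branch on a value that is identical in every lane,
        // so the `continue` never splits the warp.
        const float warp_max = warp_reduce_max_f32(my_max);
        if (warp_max < threshold) {
            continue;
        }

        // Per-lane variant (the default path) would instead branch on my_max,
        // letting individual lanes diverge around the work below.

        // Stand-in for V dequant + accumulation for position k.
        #pragma unroll
        for (int j = 0; j < ncols; ++j) {
            vkq[j] += kq[(k*32 + lane)*ncols + j];
        }
    }
}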