Skip to content

Commit a07ab09

Browse files
TheTomclaude authored and committed
sparse V VEC: warp-uniform skip via warp_reduce_max (opt-in)
Adds the vllm-project/vllm#41422 design pattern to llama.cpp's CUDA fattn-vec kernel: replace the per-lane sparse V skip (which had warp divergence on turbo paths and was compile-time gated off for turbo via PR #115's `if constexpr (!V_is_turbo)`) with a warp-uniform skip via `warp_reduce_max`. All lanes branch on the same value so there's no warp divergence regardless of V type. Off by default. Opt in at build time: cmake -DCMAKE_CXX_FLAGS=-DGGML_CUDA_TURBO_SPARSE_V_VEC Threshold defaults to 0.001f (matches vLLM PR #41422). Override with -DGGML_CUDA_TURBO_SPARSE_V_VEC_THRESHOLD=<val>. Default-off path is byte-identical (verified on M5 Max Metal: Qwen2.5-7B Q8_0 sym turbo3 PPL 6.6594, exact match with PR #115 baseline). Pairs with the prior commit's tile-kernel sparse V skip (off-by-default opt-in via GGML_CUDA_TURBO_SPARSE_V_TILE) — that one targets fattn-tile prefill, this one targets fattn-vec decode (the actual hot path on single-token generation where the vLLM win was measured). NO MERGE — testing branch only. AMD MI300X HIP validation pending. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 289b9dc commit a07ab09

1 file changed

Lines changed: 50 additions & 8 deletions

File tree

ggml/src/ggml-cuda/fattn-vec.cuh

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
#include "common.cuh"
22
#include "fattn-common.cuh"
33

4+
// Tile-uniform sparse V skip in the VEC kernel (TurboQuant). Off by default;
5+
// opt in by defining GGML_CUDA_TURBO_SPARSE_V_VEC at build time. Threshold
6+
// defaults to 0.001 (matches vllm-project/vllm#41422 — bit-identical PPL +
7+
// NIAH all-pass on Qwen3-8B at 32K). Override with
8+
// -DGGML_CUDA_TURBO_SPARSE_V_VEC_THRESHOLD=<val>.
9+
#if defined(GGML_CUDA_TURBO_SPARSE_V_VEC) && !defined(GGML_CUDA_TURBO_SPARSE_V_VEC_THRESHOLD)
10+
#define GGML_CUDA_TURBO_SPARSE_V_VEC_THRESHOLD 0.001f
11+
#endif
12+
413
static int ggml_cuda_fattn_vec_get_nthreads_host(const int cc) {
514
return 128;
615
GGML_UNUSED(cc);
@@ -412,12 +421,33 @@ static __global__ void flash_attn_ext_vec(
412421
}
413422
}
414423

415-
// Sparse V: skip V dequant if all attention weights for this position are negligible.
416-
// For turbo types, the check is compiled out: at typical decode context lengths
417-
// (< ~4K tokens) with threshold 1e-6, no positions are ever skipped, so the
418-
// per-position branch is pure overhead (misprediction + comparison cost). This
419-
// also dodges the warp-divergence regression on turbo paths that motivated the
420-
// April 24 revert (commit f2dc968).
424+
// Sparse V skip — two strategies:
425+
//
426+
// GGML_CUDA_TURBO_SPARSE_V_VEC (opt-in, default off):
427+
// Warp-uniform skip via warp_reduce_max — all lanes branch on
428+
// the same value so no warp divergence. Works on every V type
429+
// including turbo. Threshold defaults to 0.001 (matches the
430+
// vllm-project/vllm#41422 design that validated +7.13% decode
431+
// at 32K on AMD MI300X with PPL bit-identical and NIAH all-
432+
// pass).
433+
//
434+
// default (signalnine PR #115):
435+
// Per-lane skip with `if constexpr (!V_is_turbo)` compile-time
436+
// gate. Compiled out for turbo to dodge the warp-divergence
437+
// regression that motivated the April 24 revert (commit
438+
// f2dc968). Kept as the default while the warp-uniform variant
439+
// is bench-validated cross-platform.
440+
#ifdef GGML_CUDA_TURBO_SPARSE_V_VEC
441+
{
442+
float my_kq_max = 0.0f;
443+
#pragma unroll
444+
for (int j = 0; j < ncols; ++j) {
445+
my_kq_max = fmaxf(my_kq_max, __half2float(__low2half(KQ_k[j])));
446+
}
447+
const float warp_max = warp_reduce_max(my_kq_max);
448+
if (warp_max < (float) GGML_CUDA_TURBO_SPARSE_V_VEC_THRESHOLD) continue;
449+
}
450+
#else
421451
if constexpr (!V_is_turbo) {
422452
bool dominated = true;
423453
#pragma unroll
@@ -426,6 +456,7 @@ static __global__ void flash_attn_ext_vec(
426456
}
427457
if (dominated) { continue; }
428458
}
459+
#endif
429460

430461
#pragma unroll
431462
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
@@ -461,8 +492,18 @@ static __global__ void flash_attn_ext_vec(
461492
}
462493
}
463494

464-
// Sparse V: skip V dequant if all attention weights for this position are negligible.
465-
// Compiled out for turbo types — see half2 path comment above.
495+
// Sparse V skip — see half2-path comment above. Same two strategies.
496+
#ifdef GGML_CUDA_TURBO_SPARSE_V_VEC
497+
{
498+
float my_kq_max = 0.0f;
499+
#pragma unroll
500+
for (int j = 0; j < ncols; ++j) {
501+
my_kq_max = fmaxf(my_kq_max, KQ_k[j]);
502+
}
503+
const float warp_max = warp_reduce_max(my_kq_max);
504+
if (warp_max < (float) GGML_CUDA_TURBO_SPARSE_V_VEC_THRESHOLD) continue;
505+
}
506+
#else
466507
if constexpr (!V_is_turbo) {
467508
bool dominated = true;
468509
#pragma unroll
@@ -471,6 +512,7 @@ static __global__ void flash_attn_ext_vec(
471512
}
472513
if (dominated) { continue; }
473514
}
515+
#endif
474516

475517
// Turbo V path: precompute scaled centroids once per block to eliminate
476518
// per-element norm multiply. centroid[idx]*norm is computed 8/4/16 times

0 commit comments

Comments (0)