Skip to content

HIP/turbo3: graph-safe decode + inline-dequant TILE prefill on gfx1201 (RDNA4)#28

Open
KaiFelixBennett wants to merge 1 commit into
AtomicBot-ai:feature/turboquant-kv-cachefrom
KaiFelixBennett:feat/turbo3-rocm-graphsafe-and-inline-tile-prefill
Open

HIP/turbo3: graph-safe decode + inline-dequant TILE prefill on gfx1201 (RDNA4)#28
KaiFelixBennett wants to merge 1 commit into
AtomicBot-ai:feature/turboquant-kv-cachefrom
KaiFelixBennett:feat/turbo3-rocm-graphsafe-and-inline-tile-prefill

Conversation

@KaiFelixBennett

Copy link
Copy Markdown

What

Makes TurboQuant (turbo3) KV cache fully usable on ROCm / HIP graphs on gfx1201 (RDNA4) — both prefill and decode — and brings turbo3 prefill up to roughly f16 speed.

Measured on a Radeon AI PRO R9700 (gfx1201, 32 GB), Windows 11, HIP SDK 7.1, Gemma-4 Q4_K_M, HIP graphs ON. Nothing extrapolated.

Two coupled changes

1. Graph-safe decode (fixes the crash)

With HIP graphs on, turbo KV crashed on the first decode step:

FLASH_ATTN_EXT failed: operation not permitted when stream is capturing

Cause: launch_fattn's f16 dequant temp buffers (K_f16/V_f16) used raw cudaMalloc/cudaFree during graph capture, which is illegal. Fix: capture-aware allocation — pool alloc while capturing / for small batches, raw alloc+free for large eager prefill so VRAM stays bounded on the no-VMM card — and decode (Q->ne[1] <= 2) routes to the graph-safe VEC kernel (inline dequant, no temp buffer).

This is the same class of decode crash we fixed canonically upstream in TheTom#176 (merged, 7985f6b). This PR adapts that fix to this fork's newer base.

2. Inline-dequant TILE prefill (makes prefill fast)

The TILE/MMA path hardcoded need_f16_K/V = true and materialized the whole KV cache to an f16 temp buffer every step — a per-step O(KV) dequant tax that negates the 3-bit cache, so turbo3 prefill was stuck on the slow sequential VEC kernel. A new TILE path inline-dequantizes turbo3 K/V during the global→shared tile load (no f16 materialization); turbo3 head_dim=256 multi-row batches (Q->ne[1] >= 3, prefill + spec-verify) route to it.

prefill turbo3 VEC (stock) turbo3 TILE (this) speedup f16 ref
pp512 1570 t/s 2187 1.39× 2174
pp2048 1038 t/s 1929 1.86× 2049
pp4096 1039 t/s 1752 1.69× 1984

turbo3 prefill is now within ~6–12% of f16, and the gain grows with context length — the long-context regime where the 3-bit cache earns its keep.

Correctness

test-backend-ops -o FLASH_ATTN_EXT, turbo3 hsk=256, nb ∈ {3,4,6 (verify), 64,128,256 (prefill)}, kv ∈ {512,1024}: 12/12 OK (NMSE within tolerance vs the CPU reference).

Scope / honesty

  • Validated for turbo3/turbo3 at head_dim=256 (Gemma-4 family). Other dims/types keep the f16 path.
  • FlashAttention path only. It does not make self-speculative MTP beat baseline decode on RDNA4 (that wall is GEMM weight-load amortization, not attention).
  • Full methodology, raw data, correctness logs and a one-command gfx1201 build: https://github.com/KaiFelixBennett/gemma4-turboquant-rdna4

…1 (RDNA4)

Make TurboQuant (turbo3) KV cache usable on ROCm with HIP graphs on gfx1201,
both prefill and decode.

1) Graph-safe decode. launch_fattn's f16 dequant temp buffers (K_f16/V_f16) used
   raw cudaMalloc/cudaFree during graph capture, which is illegal and crashed
   decode on the first step ("operation not permitted when stream is capturing").
   Allocation is now capture-aware: pool alloc while capturing / for small batches,
   raw alloc+free for large eager prefill (keeps VRAM bounded on the no-VMM card).
   Decode (Q->ne[1] <= 2) routes to the graph-safe VEC kernel (inline dequant, no
   temp buffer). Same class of decode crash fixed canonically upstream in
   TheTom#176 (merged, 7985f6b); this adapts it to the newer base.

2) Inline-dequant TILE prefill. The TILE/MMA path hardcoded need_f16_K/V=true and
   materialized the whole KV cache to an f16 temp buffer every step. A new TILE path
   inline-dequantizes turbo3 K/V during the global->shared tile load (no f16
   materialization); turbo3 head_dim=256 multi-row batches (Q->ne[1] >= 3, prefill +
   spec-verify) route to it. Prefill 1.39x/1.86x/1.69x faster (pp512/2048/4096),
   within ~6-12% of f16; the gain grows with context length.

Correctness: test-backend-ops -o FLASH_ATTN_EXT, turbo3 hsk=256,
nb in {3,4,6,64,128,256}, kv in {512,1024}: 12/12 OK (NMSE within tol vs CPU ref).
Scope: validated for turbo3/turbo3 head_dim=256 (Gemma-4 family); other dims/types
keep the f16 path.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants