HIP/turbo3: graph-safe decode + inline-dequant TILE prefill on gfx1201 (RDNA4)#28
Open
KaiFelixBennett wants to merge 1 commit into
Conversation
…1 (RDNA4)
Make TurboQuant (turbo3) KV cache usable on ROCm with HIP graphs on gfx1201,
both prefill and decode.
1) Graph-safe decode. launch_fattn's f16 dequant temp buffers (K_f16/V_f16) used
raw cudaMalloc/cudaFree during graph capture, which is illegal and crashed
decode on the first step ("operation not permitted when stream is capturing").
Allocation is now capture-aware: pool alloc while capturing / for small batches,
raw alloc+free for large eager prefill (keeps VRAM bounded on the no-VMM card).
Decode (Q->ne[1] <= 2) routes to the graph-safe VEC kernel (inline dequant, no
temp buffer). Same class of decode crash fixed canonically upstream in
TheTom#176 (merged, 7985f6b); this adapts it to the newer base.
2) Inline-dequant TILE prefill. The TILE/MMA path hardcoded need_f16_K/V=true and
materialized the whole KV cache to an f16 temp buffer every step. A new TILE path
inline-dequantizes turbo3 K/V during the global->shared tile load (no f16
materialization); turbo3 head_dim=256 multi-row batches (Q->ne[1] >= 3, prefill +
spec-verify) route to it. Prefill 1.39x/1.86x/1.69x faster (pp512/2048/4096),
within ~6-12% of f16; the gain grows with context length.
Correctness: test-backend-ops -o FLASH_ATTN_EXT, turbo3 hsk=256,
nb in {3,4,6,64,128,256}, kv in {512,1024}: 12/12 OK (NMSE within tol vs CPU ref).
Scope: validated for turbo3/turbo3 head_dim=256 (Gemma-4 family); other dims/types
keep the f16 path.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
Makes TurboQuant (turbo3) KV cache fully usable on ROCm / HIP graphs on gfx1201 (RDNA4) — both prefill and decode — and brings turbo3 prefill up to roughly f16 speed.
Measured on a Radeon AI PRO R9700 (gfx1201, 32 GB), Windows 11, HIP SDK 7.1, Gemma-4 Q4_K_M, HIP graphs ON. Nothing extrapolated.
Two coupled changes
1. Graph-safe decode (fixes the crash)
With HIP graphs on, turbo KV crashed on the first decode step:
Cause:
launch_fattn's f16 dequant temp buffers (K_f16/V_f16) used rawcudaMalloc/cudaFreeduring graph capture, which is illegal. Fix: capture-aware allocation — pool alloc while capturing / for small batches, raw alloc+free for large eager prefill so VRAM stays bounded on the no-VMM card — and decode (Q->ne[1] <= 2) routes to the graph-safe VEC kernel (inline dequant, no temp buffer).This is the same class of decode crash we fixed canonically upstream in TheTom#176 (merged,
7985f6b). This PR adapts that fix to this fork's newer base.2. Inline-dequant TILE prefill (makes prefill fast)
The TILE/MMA path hardcoded
need_f16_K/V = trueand materialized the whole KV cache to an f16 temp buffer every step — a per-step O(KV) dequant tax that negates the 3-bit cache, so turbo3 prefill was stuck on the slow sequential VEC kernel. A new TILE path inline-dequantizes turbo3 K/V during the global→shared tile load (no f16 materialization); turbo3 head_dim=256 multi-row batches (Q->ne[1] >= 3, prefill + spec-verify) route to it.turbo3 prefill is now within ~6–12% of f16, and the gain grows with context length — the long-context regime where the 3-bit cache earns its keep.
Correctness
test-backend-ops -o FLASH_ATTN_EXT, turbo3 hsk=256, nb ∈ {3,4,6 (verify), 64,128,256 (prefill)}, kv ∈ {512,1024}: 12/12 OK (NMSE within tolerance vs the CPU reference).Scope / honesty