[cuda backend][gemma4_31b] TQ4 SDPA: no-spill prefill kernel + analytic causal #20512
Three CUDA-export memory optimizations:
- tq4_sdpa: add BLOCK_N=16 (and a BLOCK_M=32) autotune config. The superset is kept for big-shared-memory GPUs (A100/H100); the Triton autotuner auto-prunes configs that exceed a GPU's shared memory (OutOfResources -> inf), so the same config list also works on the 5090 (Blackwell, ~101 KB SMEM) where the previous smallest config did not fit.
- int4_dispatch: chunk the inline _dequant_matmul along N for vocab-sized weights (N>65536, i.e. only the lm_head). Avoids transiently materializing the full ~10 GiB bf16 lm_head when AOTI executes the int4_plain_mm custom op during autotune / cpp_wrapper. The runtime decode path uses the C++ dp4a shim and the M>4 prefill inline path is below the threshold, so this never enters the runtime graph -> zero runtime / accuracy impact. Applied unconditionally (no flag).
- cuda_backend / aoti_backend: skip occupying the GPU with the KV-cache buffers during AOTI compile (gated behind low_memory_mode). A new move_program_to_device hook places KV constants on the target device but immediately frees their storage (resize_(0)), so the fake-tensor device check passes while no real KV bytes sit on the GPU during autotune. The emptied buffers are re-synthesized as zeros at the _unlift_graph clone and at serialization, and excluded from constant dedup (resize_(0) gives every KV data_ptr 0, which would otherwise collapse same-shape caches across layers).

Result on 2xA100: Gemma4-31B @128k no-TQ export peak 36.3 -> 27.0 GiB; the exported model runs correctly (output "...Paris.").
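For reviewers, the N-chunking idea in the int4_dispatch item can be sketched in plain Python. This is only an illustration under stated assumptions: the helper names (dequant_rows, matvec_rows, chunked_dequant_matmul) and the toy per-row-scale quantization are hypothetical and are not the PR's _dequant_matmul or its int4 packing. It shows that dequantizing and multiplying one slice of output rows at a time gives the same result while only one dequantized slice is alive at once.

```python
# Hypothetical sketch: chunk a dequantize-then-matmul along the output (N) axis.
# The quantization scheme (one scale per output row) is a toy stand-in.

def dequant_rows(q_rows, scales):
    """Dequantize a block of weight rows: w[n][k] = q[n][k] * scale[n]."""
    return [[q * s for q in row] for row, s in zip(q_rows, scales)]


def matvec_rows(x, w_rows):
    """Compute y[n] = sum_k x[k] * w[n][k] for the given rows."""
    return [sum(xk * wk for xk, wk in zip(x, row)) for row in w_rows]


def chunked_dequant_matmul(x, q_weight, scales, chunk_n):
    """Same result as dequantizing everything, but only chunk_n rows at a time."""
    out = []
    for start in range(0, len(q_weight), chunk_n):
        stop = start + chunk_n
        # Only this slice of the weight is ever held in dequantized form.
        w_chunk = dequant_rows(q_weight[start:stop], scales[start:stop])
        out.extend(matvec_rows(x, w_chunk))
    return out


x = [1.0, 2.0, -1.0]
q_weight = [[1, 0, -2], [3, 3, 3], [-4, 1, 0], [7, -7, 2], [0, 5, 5]]
scales = [0.5, 0.25, 1.0, 0.125, 2.0]

full = matvec_rows(x, dequant_rows(q_weight, scales))
chunked = chunked_dequant_matmul(x, q_weight, scales, chunk_n=2)
assert chunked == full
```

Because each output row is computed with the same operations in the same order, chunking along N does not change the numerics, which is consistent with the "zero accuracy impact" claim above.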
…nly frees genuinely all-zero kv_cache.* buffers (count_nonzero==0); preserves TQ4 centroids/boundaries/rotation/rotation_T
Summary: Fuse each gemma4_31b MLP's gate_proj|up_proj into a single [2*intermediate, hidden] coalesced-int4 matmul, applied by default in the CUDA export. This issues one activation-quant + one W4A8 matvec per layer instead of two, cutting per-token launch + activation-quant overhead in the launch-bound decode path. Only Q4_K (CudaCoalescedInt4Tensor) gate/up pairs are fused; any other quant type (e.g. Q6_K) is left as two matmuls (guarded, still correct).

Builds on the already-landed kv_len-bounded tq4_sdpa kernel + gemma4_31b call-site (kv_len + mask_is_causal), which recovered 128k decode from ~2.8 to ~43 tok/s. With both, ET gemma4_31b 128k+TurboQuant decode beats llama.cpp at every measured context (cuda_graph ON):

| ctx | ET (tok/s) | llama.cpp (tok/s) |
|------|-------|-------|
| 512 | 44.80 | 42.77 |
| 2K | 43.20 | 41.97 |
| 8K | 42.23 | 41.23 |
| 32K | 41.64 | 40.27 |
| 127K | 38.41 | 35.97 |

TurboQuant KV compression kept; prefill restored (6-8x) with no regression; output quality preserved.

Test Plan:
- Fusion numerics: fused vs unfused MLP through the real W4A8 int4_plain_mm kernel = bit-exact (max_abs_diff 0.0, cos 1.000000) for decode (T=1) and prefill (T=4).
- Export + run: fused module exported via CudaPartitioner and executed through executor_runner (RC=0, cos 0.999915 vs eager). Full 31B export logs "Fused gate+up on 60 MLP layers".
- Decode A/B (gemma4_31b 128k+TQ, cuda_graph ON, 5x median): table above; beats llama.cpp at 512 -> 127K. nsys: tq4_sdpa 91.7% -> 2.9% of decode.
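A minimal sketch of why the gate/up fusion is numerically equivalent, using plain float weights instead of the int4 W4A8 path. The function names (matvec, mlp_unfused, mlp_fused) are hypothetical and chosen for illustration only. Stacking the gate_proj and up_proj rows into one [2*intermediate, hidden] weight lets a single matvec produce both outputs, which are then split back apart.

```python
# Hypothetical sketch of gate/up fusion with float weights (no int4 packing,
# no activation quantization); names are illustrative, not the PR's code.

def matvec(weight, x):
    """y[n] = sum_k weight[n][k] * x[k]."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def mlp_unfused(gate_w, up_w, x):
    """Two separate matvecs, one per projection."""
    return matvec(gate_w, x), matvec(up_w, x)


def mlp_fused(fused_w, x, intermediate):
    """One matvec over the stacked weight, then split into gate and up."""
    y = matvec(fused_w, x)
    return y[:intermediate], y[intermediate:]


gate_w = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]  # [intermediate=3, hidden=2]
up_w = [[-1.0, 1.0], [2.0, 2.0], [0.0, 4.0]]    # [intermediate=3, hidden=2]
x = [2.0, 3.0]

fused_w = gate_w + up_w                          # [2*intermediate, hidden]
assert mlp_fused(fused_w, x, 3) == mlp_unfused(gate_w, up_w, x)
```

Each output row still sees the same weights and the same input, so the result is expected to match row for row; the saving described above comes from issuing one launch and one activation-quant instead of two.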
…stic prefill autotune
- Global (full-attention) bf16 layers: bound SDPA to a runtime kv_len scalar (CUDA-graph-safe) instead of the full max_seq_len KV buffer -> O(context) decode; restores decode scaling (was flat ~36.5 t/s at all depths -> 46.5@512, 34.9@127K). (sdpa.py kv_len path + cuda_source_transformations.py _lenaware_attention_forward; global layers only, sliding + turbo untouched)
- Prefill global full-attention: replace fixed m32/m64 BLOCK_M selection with a head_dim-keyed autotuned _sdpa_fwd_kernel + register-budget prune (BLOCK_M*HEAD_DIM <= 4096*num_warps), fixing acc[64,512] fp32 register spill at head_dim=512. Prefill +24% @8K, +63% @32K, +117% @127K; head_dim-agnostic (no split-D needed for D<=512). (sdpa.py)
- Validated: output bitwise-identical to prior kernel (cos=1.0, D=64/128/256/512), no decode regression; non-tq prefill now beats llama.cpp at all 5 cells and turbo TQ4 at 4/5. Op-level autotune profiling (A100) confirms the config set is near-optimal (in-set optimum at every regime; only <=1.3% marginal candidates).
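The register-budget rule quoted above (BLOCK_M*HEAD_DIM <= 4096*num_warps) can be illustrated with a small stand-alone filter. This is a sketch, not the code in sdpa.py: the Config class, the candidate list, and the helper names are hypothetical, and the per-thread estimate assumes 32 threads per warp with the fp32 accumulator spread evenly across threads.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Config:
    """Hypothetical stand-in for an autotune config."""
    block_m: int
    block_n: int
    num_warps: int


def fits_register_budget(cfg, head_dim):
    # Rule from the commit message: BLOCK_M * HEAD_DIM <= 4096 * num_warps.
    return cfg.block_m * head_dim <= 4096 * cfg.num_warps


def acc_regs_per_thread(cfg, head_dim):
    # fp32 accumulator elements per thread (assumes 32 threads per warp,
    # even distribution of acc[BLOCK_M, HEAD_DIM] across the CTA).
    return cfg.block_m * head_dim // (32 * cfg.num_warps)


def prune(configs, head_dim):
    """Keep only configs whose accumulator fits the register budget."""
    return [c for c in configs if fits_register_budget(c, head_dim)]


candidates = [
    Config(block_m=64, block_n=64, num_warps=4),
    Config(block_m=64, block_n=32, num_warps=8),
    Config(block_m=32, block_n=32, num_warps=4),
    Config(block_m=16, block_n=64, num_warps=4),
]
kept = prune(candidates, head_dim=512)
assert Config(64, 64, 4) not in kept
```

Under these assumptions the budget works out to at most 128 accumulator values per thread. The rejected acc[64,512] config at 4 warps would need 256 per thread for the accumulator alone, which is more than the 255 registers a CUDA thread can hold, so it has to spill. At head_dim=64 the same rule keeps every candidate in this list.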
…tune

Prefill (global head_dim=512 TQ4 path):
- Consolidate m64/m32 into one no-spill _tq4_sdpa_prefill_kernel (BLOCK_M<=32 so acc[BLOCK_M,512] fp32 stays in registers, no spill).
- Absolute-offset analytic causal (offs_n > (kv_len-Lq)+seq_pos); kernel no longer reads a materialized causal mask.
- Prune prefill autotune list to the heavy-kv optima; also removes BLOCK_N=16 configs that produced incorrect output.

Decode (split-K): retune autotune list for the HAS_MASK=False path:
- add the profiled optima (BLOCK_N=32/w4/s2, BLOCK_N=64/w8/s3);
- drop configs that are catastrophic at HAS_MASK=False (BLOCK_N=64/w2, BLOCK_N=128/w4).

Call-site passes attn_mask=None for BOTH prefill and decode so the two exported methods share one AOTI weights blob (avoids 2x .ptd weight duplication).

e2e @127k (A100, vs same-tree baseline): prefill 543->715 t/s (1.32x), decode 34.1->34.5 t/s, peak 25.75 GB (zero extra memory). Smaller contexts beat llama on both prefill and decode.
Summary
Speeds up long-context prefill for the TurboQuant (TQ4) KV-cache SDPA path in
gemma4_31b (the 10 global/full-attention layers, head_dim=512), with no decode
regression and no extra memory. At 127K context the prefill gap vs llama.cpp
goes from -37% to -7%; shorter contexts already beat llama.cpp on both prefill
and decode.
Changes
Prefill kernel (backends/cuda/triton/kernels/tq4_sdpa.py)
- Consolidate the m64/m32 prefill kernels into one no-spill _tq4_sdpa_prefill_kernel: cap BLOCK_M<=32 so acc[BLOCK_M, 512] fp32 stays in registers instead of spilling to local memory.
- Absolute-offset analytic causal (offs_n > (kv_len - Lq) + seq_pos); the kernel no longer reads a materialized causal mask.
- Prune the prefill autotune list to the heavy-kv optima; this also removes the BLOCK_N=16 configs (slow AND numerically incorrect, cos≈0.02).
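To make the analytic causal predicate concrete, here is a small pure-Python check that it matches a materialized bottom-right-aligned causal mask. It is a sketch under assumptions, not the Triton kernel: seq_pos is taken to be the query's index within the Lq new tokens, offs_n the key index into the KV cache, and the helper names are hypothetical.

```python
# Hypothetical sketch: the analytic predicate vs. a materialized mask.
# True means "this key is masked out for this query".

def is_masked(offs_n, seq_pos, lq, kv_len):
    """Analytic causal test from the PR: offs_n > (kv_len - Lq) + seq_pos."""
    return offs_n > (kv_len - lq) + seq_pos


def analytic_mask(lq, kv_len):
    """Evaluate the predicate for every (query, key) pair."""
    return [[is_masked(n, q, lq, kv_len) for n in range(kv_len)]
            for q in range(lq)]


def materialized_mask(lq, kv_len):
    """Build a square causal mask, then keep the last Lq rows
    (bottom-right alignment: the new queries are the last Lq positions)."""
    square = [[k > q for k in range(kv_len)] for q in range(kv_len)]
    return square[kv_len - lq:]


for lq, kv_len in [(1, 1), (1, 8), (4, 4), (3, 10)]:
    assert analytic_mask(lq, kv_len) == materialized_mask(lq, kv_len)
```

With Lq=1 (decode) the predicate masks nothing, and with Lq=kv_len it reduces to the usual square lower-triangular mask. Since the test depends only on indices, no mask tensor needs to be read.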
Decode split-K (tq4_sdpa.py)
- Retune the autotune list for the HAS_MASK=False specialization: add the profiled optima (BLOCK_N=32/w4/s2, BLOCK_N=64/w8/s3); drop configs that are catastrophically slow at HAS_MASK=False (BLOCK_N=64/w2, BLOCK_N=128/w4).
Call site (examples/models/gemma4_31b/cuda_source_transformations.py)
- Pass attn_mask=None for BOTH prefill and decode so the two exported methods emit an identical SDPA call → AOTI dedups the weights blob (avoids a 2× .ptd / 52GB blow-up; keeps ~26GB).
Cross-GPU readiness
BLOCK_N∈{32,64}) that fit smaller-SMEM GPUs (e.g. RTX 5090, ~100KB/SM vs A100 164KB). A100 optima are retained. NOTE: AOTI bakes the config at export time, so re-export on the target GPU; 5090 perf/correctness still to be validated on a 5090.
Results (e2e, A100, prefill t/s | decode t/s | peak; same-tree baseline)
Smaller contexts beat llama.cpp on both prefill and decode.
Test plan
- CUDA_VISIBLE_DEVICES=0 python -m pytest backends/cuda/tests/test_tq4_sdpa.py -q → 36 passed (kernel correctness across MHA/GQA/MQA, causal, decode, HD256, all-masked NaN-safety, 128K bottom-right alignment) + AOTI export; 1 gated-skip.
- ignore_eos, 512 decode); table above. Verified .ptd = 1 weights blob (~26GB); baked configs = prefill BM32/BN32/w4, decode BN32/w4/s2.
Known follow-ups (not in this PR)
- BLOCK_N=16 correctness bug (root cause unfixed; worked around by pruning).