Skip to content

[cuda backend][gemma4_31b] TQ4 SDPA: no-spill prefill kernel + analytic causal #20512

Merged
Gasoonjia merged 7 commits into
mainfrom
gemma4_31b-tq4-prefill-decode-tuned
Jun 25, 2026
Merged

[cuda backend][gemma4_31b] TQ4 SDPA: no-spill prefill kernel + analytic causal #20512
Gasoonjia merged 7 commits into
mainfrom
gemma4_31b-tq4-prefill-decode-tuned

Conversation

@Gasoonjia

@Gasoonjia Gasoonjia commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

Summary

Speeds up long-context prefill for the TurboQuant (TQ4) KV-cache SDPA path in
gemma4_31b (the 10 global/full-attention layers, head_dim=512), with no decode
regression and no extra memory. At 127K context the prefill gap vs llama.cpp
goes from -37% to -7%; shorter contexts already beat llama.cpp on both prefill
and decode.

Changes

Prefill kernel (backends/cuda/triton/kernels/tq4_sdpa.py)

  • Consolidate the m64/m32 prefill kernels into one no-spill
    _tq4_sdpa_prefill_kernel: cap BLOCK_M<=32 so acc[BLOCK_M, 512] fp32 stays
    in registers instead of spilling to local memory.
  • Absolute-offset analytic causal (offs_n > (kv_len - Lq) + seq_pos); the
    kernel no longer reads a materialized causal mask.
  • Autotune list tuned to the heavy-kv shape; removed BLOCK_N=16 configs (slow
    AND numerically incorrect, cos≈0.02).

Decode split-K (tq4_sdpa.py)

  • Retune the split-K autotune list for the HAS_MASK=False specialization: add
    the profiled optima (BLOCK_N=32/w4/s2, BLOCK_N=64/w8/s3); drop configs that
    are catastrophically slow at HAS_MASK=False (BLOCK_N=64/w2, BLOCK_N=128/w4).

Call site (examples/models/gemma4_31b/cuda_source_transformations.py)

  • Pass attn_mask=None for BOTH prefill and decode so the two exported methods
    emit an identical SDPA call → AOTI dedups the weights blob (avoids a 2× .ptd
    / 52GB blow-up; keeps ~26GB).

Cross-GPU readiness

  • Add correctness-safe autotune configs (warp/stage variants on BLOCK_N∈{32,64})
    that fit smaller-SMEM GPUs (e.g. RTX 5090, ~100KB/SM vs A100 164KB). A100
    optima are retained. NOTE: AOTI bakes the config at export time, so re-export
    on the target GPU; 5090 perf/correctness still to be validated on a 5090.

Results (e2e, A100, prefill t/s | decode t/s | peak; same-tree baseline)

ctx this branch baseline llama.cpp
32K 1554 / 42.3 1243 / 42.1 1279 / 40.1
127K 715 / 34.5 / 25.75GB 543 / 34.1 768.5 / 33.0
  • 127K prefill 1.32× (gap vs llama.cpp -37% → -7%); 512/2K/8K/32K beat llama
    on both prefill and decode.
  • 127K decode 34.5 (≥ baseline 34.1, > llama 33.0) — no regression.
  • Peak memory 25.75GB — zero extra vs baseline.

Test plan

  • CUDA_VISIBLE_DEVICES=0 python -m pytest backends/cuda/tests/test_tq4_sdpa.py -q
    → 36 passed (kernel correctness across MHA/GQA/MQA, causal, decode, HD256,
    all-masked NaN-safety, 128K bottom-right alignment) + AOTI export; 1 gated-skip.
  • e2e: exported gemma4_31b .pte, ran 512/2K/8K/32K/127K (cuda_graph, temp 0,
    ignore_eos, 512 decode); table above. Verified .ptd = 1 weights blob (~26GB);
    baked configs = prefill BM32/BN32/w4, decode BN32/w4/s2.

Known follow-ups (not in this PR)

  • BLOCK_N=16 correctness bug (root cause unfixed; worked around by pruning).
  • Correctness-gated + GPU-adaptive autotune (only partial here); validate on 5090.
  • INT4 MLP W4A8 GEMV dominates decode (~76%) — separate effort.

Three CUDA-export memory optimizations:

- tq4_sdpa: add BLOCK_N=16 (and a BLOCK_M=32) autotune config. The superset
  is kept for big-shared-memory GPUs (A100/H100); the Triton autotuner
  auto-prunes configs that exceed a GPU's shared memory (OutOfResources ->
  inf), so the same config list also works on the 5090 (Blackwell, ~101 KB
  SMEM) where the previous smallest config did not fit.

- int4_dispatch: chunk the inline _dequant_matmul along N for vocab-sized
  weights (N>65536, i.e. only the lm_head). Avoids transiently materializing
  the full ~10 GiB bf16 lm_head when AOTI executes the int4_plain_mm custom
  op during autotune / cpp_wrapper. The runtime decode path uses the C++ dp4a
  shim and the M>4 prefill inline path is below the threshold, so this never
  enters the runtime graph -> zero runtime / accuracy impact. Applied
  unconditionally (no flag).

- cuda_backend / aoti_backend: skip occupying the GPU with the KV-cache
  buffers during AOTI compile (gated behind low_memory_mode). A new
  move_program_to_device hook places KV constants on the target device but
  immediately frees their storage (resize_(0)), so the fake-tensor device
  check passes while no real KV bytes sit on the GPU during autotune. The
  emptied buffers are re-synthesized as zeros at the _unlift_graph clone and
  at serialization, and excluded from constant dedup (resize_(0) gives every
  KV data_ptr 0, which would otherwise collapse same-shape caches across
  layers).

Result on 2xA100: Gemma4-31B @128k no-TQ export peak 36.3 -> 27.0 GiB; the
exported model runs correctly (output "...Paris.").
@pytorch-bot

pytorch-bot Bot commented Jun 25, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20512

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 196 Pending

As of commit 4338b02 with merge base 7e0151e (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 25, 2026
@Gasoonjia Gasoonjia changed the title Gemma4 31b tq4 prefill decode tuned [cuda backend][gemma4_31b] TQ4 SDPA: no-spill prefill kernel + analytic causal Jun 25, 2026
@Gasoonjia Gasoonjia marked this pull request as ready for review June 25, 2026 15:32
Gasoonjia and others added 5 commits June 25, 2026 08:55
…nly frees genuinely all-zero kv_cache.* buffers (count_nonzero==0); preserves TQ4 centroids/boundaries/rotation/rotation_T
Summary:
Fuse each gemma4_31b MLP's gate_proj|up_proj into a single
[2*intermediate, hidden] coalesced-int4 matmul, applied by default in the CUDA
export. This issues one activation-quant + one W4A8 matvec per layer instead of
two, cutting per-token launch + activation-quant overhead in the launch-bound
decode path. Only Q4_K (CudaCoalescedInt4Tensor) gate/up pairs are fused; any
other quant type (e.g. Q6_K) is left as two matmuls (guarded, still correct).

Builds on the already-landed kv_len-bounded tq4_sdpa kernel + gemma4_31b
call-site (kv_len + mask_is_causal), which recovered 128k decode from ~2.8 to
~43 tok/s. With both, ET gemma4_31b 128k+TurboQuant decode beats llama.cpp at
every measured context (cuda_graph ON):

  ctx    ET      llama
  512    44.80   42.77
  2K     43.20   41.97
  8K     42.23   41.23
  32K    41.64   40.27
  127K   38.41   35.97

TurboQuant KV compression kept; prefill restored (6-8x) with no regression;
output quality preserved.

Test Plan:
- Fusion numerics: fused vs unfused MLP through the real W4A8 int4_plain_mm
  kernel = bit-exact (max_abs_diff 0.0, cos 1.000000) for decode (T=1) and
  prefill (T=4).
- Export + run: fused module exported via CudaPartitioner and executed through
  executor_runner (RC=0, cos 0.999915 vs eager). Full 31B export logs
  "Fused gate+up on 60 MLP layers".
- Decode A/B (gemma4_31b 128k+TQ, cuda_graph ON, 5x median): table above; beats
  llama.cpp at 512 -> 127K. nsys: tq4_sdpa 91.7% -> 2.9% of decode.
…stic prefill autotune

- Global (full-attention) bf16 layers: bound SDPA to a runtime kv_len scalar
  (CUDA-graph-safe) instead of the full max_seq_len KV buffer -> O(context)
  decode; restores decode scaling (was flat ~36.5 t/s at all depths ->
  46.5@512, 34.9@127K). (sdpa.py kv_len path + cuda_source_transformations.py
  _lenaware_attention_forward; global layers only, sliding + turbo untouched)
- Prefill global full-attention: replace fixed m32/m64 BLOCK_M selection with a
  head_dim-keyed autotuned _sdpa_fwd_kernel + register-budget prune
  (BLOCK_M*HEAD_DIM <= 4096*num_warps), fixing acc[64,512] fp32 register spill
  at head_dim=512. Prefill +24% @8K, +63% @32k, +117% @127k; head_dim-agnostic
  (no split-D needed for D<=512). (sdpa.py)
- Validated: output bitwise-identical to prior kernel (cos=1.0, D=64/128/256/512),
  no decode regression; non-tq prefill now beats llama.cpp at all 5 cells and
  turbo TQ4 at 4/5. Op-level autotune profiling (A100) confirms the config set is
  near-optimal (in-set optimum at every regime; only <=1.3% marginal candidates).
…tune

Prefill (global head_dim=512 TQ4 path):
- Consolidate m64/m32 into one no-spill _tq4_sdpa_prefill_kernel (BLOCK_M<=32
  so acc[BLOCK_M,512] fp32 stays in registers, no spill).
- Absolute-offset analytic causal (offs_n > (kv_len-Lq)+seq_pos); kernel no
  longer reads a materialized causal mask.
- Prune prefill autotune list to the heavy-kv optima; also removes BLOCK_N=16
  configs that produced incorrect output.

Decode (split-K): retune autotune list for the HAS_MASK=False path - add the
profiled optima (BLOCK_N=32/w4/s2, BLOCK_N=64/w8/s3), drop configs that are
catastrophic at HAS_MASK=False (BLOCK_N=64/w2, BLOCK_N=128/w4).

Call-site passes attn_mask=None for BOTH prefill and decode so the two exported
methods share one AOTI weights blob (avoids 2x .ptd weight duplication).

e2e @127k (A100, vs same-tree baseline): prefill 543->715 t/s (1.32x), decode
34.1->34.5 t/s, peak 25.75 GB (zero extra memory). Smaller contexts beat llama
on both prefill and decode.
@metascroy

Copy link
Copy Markdown
Contributor

@claude review this code

@claude

claude Bot commented Jun 25, 2026

Copy link
Copy Markdown

Claude encountered an error —— View job


I'll analyze this and get back to you.

Base automatically changed from gemma4_31b-cuda-attn-perf-git to main June 25, 2026 22:35
@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@Gasoonjia Gasoonjia merged commit 6021a58 into main Jun 25, 2026
234 of 238 checks passed
@Gasoonjia Gasoonjia deleted the gemma4_31b-tq4-prefill-decode-tuned branch June 25, 2026 22:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants