Skip to content

[gemma4_31b][cuda] length-aware bf16 global attention#20506

Merged
Gasoonjia merged 5 commits into
mainfrom
gemma4_31b-cuda-attn-perf-git
Jun 25, 2026
Merged

[gemma4_31b][cuda] length-aware bf16 global attention#20506
Gasoonjia merged 5 commits into
mainfrom
gemma4_31b-cuda-attn-perf-git

Conversation

@Gasoonjia

@Gasoonjia Gasoonjia commented Jun 25, 2026

Copy link
Copy Markdown
Contributor
  • Global (full-attention) bf16 layers: bound SDPA to a runtime kv_len scalar (CUDA-graph-safe) instead of the full max_seq_len KV buffer, reduce complexity of decode from O(kvcache_length) -> O(context); restores decode scaling (was flat ~36.5 t/s at all depths -> 46.5@512, 34.9@127K).
  • Prefill global full-attention: replace fixed m32/m64 BLOCK_M selection with a head_dim-keyed autotuned _sdpa_fwd_kernel + register-budget prune (BLOCK_MHEAD_DIM <= 4096num_warps), fixing acc[64,512] fp32 register spill at head_dim=512. Prefill +24% @8K, +63% @32k, +117% @127k; head_dim-agnostic (no split-D needed for D<=512). (sdpa.py)
  • Validated: output bitwise-identical to prior kernel (cos=1.0, D=64/128/256/512), no decode regression; non-tq prefill now beats llama.cpp at all 5 cells and turbo TQ4 at 4/5. Op-level autotune profiling (A100) confirms the config set is near-optimal (in-set optimum at every regime; only <=1.3% marginal candidates).

Three CUDA-export memory optimizations:

- tq4_sdpa: add BLOCK_N=16 (and a BLOCK_M=32) autotune config. The superset
  is kept for big-shared-memory GPUs (A100/H100); the Triton autotuner
  auto-prunes configs that exceed a GPU's shared memory (OutOfResources ->
  inf), so the same config list also works on the 5090 (Blackwell, ~101 KB
  SMEM) where the previous smallest config did not fit.

- int4_dispatch: chunk the inline _dequant_matmul along N for vocab-sized
  weights (N>65536, i.e. only the lm_head). Avoids transiently materializing
  the full ~10 GiB bf16 lm_head when AOTI executes the int4_plain_mm custom
  op during autotune / cpp_wrapper. The runtime decode path uses the C++ dp4a
  shim and the M>4 prefill inline path is below the threshold, so this never
  enters the runtime graph -> zero runtime / accuracy impact. Applied
  unconditionally (no flag).

- cuda_backend / aoti_backend: skip occupying the GPU with the KV-cache
  buffers during AOTI compile (gated behind low_memory_mode). A new
  move_program_to_device hook places KV constants on the target device but
  immediately frees their storage (resize_(0)), so the fake-tensor device
  check passes while no real KV bytes sit on the GPU during autotune. The
  emptied buffers are re-synthesized as zeros at the _unlift_graph clone and
  at serialization, and excluded from constant dedup (resize_(0) gives every
  KV data_ptr 0, which would otherwise collapse same-shape caches across
  layers).

Result on 2xA100: Gemma4-31B @128k no-TQ export peak 36.3 -> 27.0 GiB; the
exported model runs correctly (output "...Paris.").
@pytorch-bot

pytorch-bot Bot commented Jun 25, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20506

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla

linux-foundation-easycla Bot commented Jun 25, 2026

Copy link
Copy Markdown

CLA Signed
The committers listed above are authorized under a signed CLA.

  • ✅ login: Gasoonjia / name: Songhao Jia (74c2a9d)

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 25, 2026
@Gasoonjia Gasoonjia force-pushed the gemma4_31b-cuda-attn-perf-git branch from fbe12b9 to ce442fe Compare June 25, 2026 06:07
@Gasoonjia Gasoonjia changed the title [gemma4_31b][cuda] length-aware bf16 global attention + head_dim-agno… [gemma4_31b][cuda] length-aware bf16 global attention Jun 25, 2026
@Gasoonjia Gasoonjia marked this pull request as ready for review June 25, 2026 14:58
…nly frees genuinely all-zero kv_cache.* buffers (count_nonzero==0); preserves TQ4 centroids/boundaries/rotation/rotation_T
Summary:
Fuse each gemma4_31b MLP's gate_proj|up_proj into a single
[2*intermediate, hidden] coalesced-int4 matmul, applied by default in the CUDA
export. This issues one activation-quant + one W4A8 matvec per layer instead of
two, cutting per-token launch + activation-quant overhead in the launch-bound
decode path. Only Q4_K (CudaCoalescedInt4Tensor) gate/up pairs are fused; any
other quant type (e.g. Q6_K) is left as two matmuls (guarded, still correct).

Builds on the already-landed kv_len-bounded tq4_sdpa kernel + gemma4_31b
call-site (kv_len + mask_is_causal), which recovered 128k decode from ~2.8 to
~43 tok/s. With both, ET gemma4_31b 128k+TurboQuant decode beats llama.cpp at
every measured context (cuda_graph ON):

  ctx    ET      llama
  512    44.80   42.77
  2K     43.20   41.97
  8K     42.23   41.23
  32K    41.64   40.27
  127K   38.41   35.97

TurboQuant KV compression kept; prefill restored (6-8x) with no regression;
output quality preserved.

Test Plan:
- Fusion numerics: fused vs unfused MLP through the real W4A8 int4_plain_mm
  kernel = bit-exact (max_abs_diff 0.0, cos 1.000000) for decode (T=1) and
  prefill (T=4).
- Export + run: fused module exported via CudaPartitioner and executed through
  executor_runner (RC=0, cos 0.999915 vs eager). Full 31B export logs
  "Fused gate+up on 60 MLP layers".
- Decode A/B (gemma4_31b 128k+TQ, cuda_graph ON, 5x median): table above; beats
  llama.cpp at 512 -> 127K. nsys: tq4_sdpa 91.7% -> 2.9% of decode.
…stic prefill autotune

- Global (full-attention) bf16 layers: bound SDPA to a runtime kv_len scalar
  (CUDA-graph-safe) instead of the full max_seq_len KV buffer -> O(context)
  decode; restores decode scaling (was flat ~36.5 t/s at all depths ->
  46.5@512, 34.9@127K). (sdpa.py kv_len path + cuda_source_transformations.py
  _lenaware_attention_forward; global layers only, sliding + turbo untouched)
- Prefill global full-attention: replace fixed m32/m64 BLOCK_M selection with a
  head_dim-keyed autotuned _sdpa_fwd_kernel + register-budget prune
  (BLOCK_M*HEAD_DIM <= 4096*num_warps), fixing acc[64,512] fp32 register spill
  at head_dim=512. Prefill +24% @8K, +63% @32k, +117% @127k; head_dim-agnostic
  (no split-D needed for D<=512). (sdpa.py)
- Validated: output bitwise-identical to prior kernel (cos=1.0, D=64/128/256/512),
  no decode regression; non-tq prefill now beats llama.cpp at all 5 cells and
  turbo TQ4 at 4/5. Op-level autotune profiling (A100) confirms the config set is
  near-optimal (in-set optimum at every regime; only <=1.3% marginal candidates).
@Gasoonjia Gasoonjia force-pushed the gemma4_31b-cuda-attn-perf-git branch from ce442fe to 74c2a9d Compare June 25, 2026 17:23
@Gasoonjia Gasoonjia force-pushed the gemma4_31b-cuda-decode-speedup branch from 1c371e2 to 4025660 Compare June 25, 2026 17:23
Base automatically changed from gemma4_31b-cuda-decode-speedup to main June 25, 2026 22:29

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the right threshold device specific?

@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@metascroy

Copy link
Copy Markdown
Contributor

@claude review this code

@claude

claude Bot commented Jun 25, 2026

Copy link
Copy Markdown

Claude finished @metascroy's task in 2m 49s —— View job


Code Review

  • Read the changed files
  • Review sdpa.py (kv_len bound, split-K dispatch, autotune/prune)
  • Review cuda_source_transformations.py (length-aware global attention)
  • Post review feedback

Reviewed both files against origin/main. This is clean, well-documented work — the kv_len bound is correctly plumbed through both the standard and split-K paths, kept on-device (no .item()) so it survives CUDA-graph replay, and clamped to L_kv for OOB safety. The HAS_KV_LEN=False fallback preserves the original behavior exactly. Output bitwise-identity claim is plausible given the loop bound only skips positions that were empty/future anyway.

A few notes below — none blocking.


Device-specific thresholds (answers @metascroy's question)

The reviewer asked whether the threshold is device-specific. It is — and there are now three A100-tuned magic constants in this kernel:

  • sdpa.py:582BLOCK_M * HEAD_DIM <= 4096 * num_warps, explicitly documented as "the measured A100 no-spill boundary". The register-file budget per SM differs across archs (e.g. consumer/Ada vs. A100/H100), so this prune boundary is genuinely device-specific.
  • sdpa.py _SPLITK_LKV_THRESHOLD = 256 — the split-K cutover point.
  • sdpa.py:1047threshold = 4 * 84 (84 = A100 SM count) for the pack-GQA block_m heuristic.

These are safe rather than correct-only-on-A100: the prune guarantees a non-empty config set (the BLOCK_M=16 configs satisfy the budget for any HEAD_DIM<=1024, plus the explicit fallback), and the autotuner benchmarks the survivors, so a suboptimal-but-valid config is always chosen on other GPUs. Worst case is leaving perf on the table off-A100, not incorrectness.

If you want to make these portable without a re-tune, consider deriving from torch.cuda.get_device_properties(dev).multi_processor_count (replacing the hardcoded 84) and the arch's regs/SM. Not required for this PR, but worth a TODO so the A100 assumption is discoverable.

Stale comment / heuristic mismatch at the pack-GQA decision (sdpa.py:1039-1047)

# Use the actual BLOCK_M that the launched kernel will use:
# - non-pow2 path always uses BLOCK_M=32
# - pow2 path selects M32 or M64 based on CTA occupancy
...
block_m = 32 if total_ctas_m64 < 4 * 84 else 64

This comment and block_m guess are now stale. _sdpa_fwd_kernel is autotuned over BLOCK_M ∈ {16..128}, so the kernel's actual BLOCK_M is no longer the 32/64 value computed here. block_m only feeds the _should_pack_gqa heuristic (which is purely a perf decision, not correctness), so this isn't a bug — but the comment's "actual BLOCK_M that the launched kernel will use" is no longer true and is misleading to the next reader. Suggest updating the comment to note it's an approximation for the pack-GQA heuristic only. Fix this →

_sdpa_prefill_prune HEAD_DIM lookup robustness (sdpa.py:582)

The prune falls back to returning all configs if HEAD_DIM is found in neither kwargs nor nargs. In that (unexpected) case the BLOCK_M=128 config could be selected for HEAD_DIM=512 — the exact acc[128,512] fp32 spill this prune exists to prevent. In practice HEAD_DIM is always passed as a constexpr kwarg to the launch, so this is defensive-only, but since the whole point of the prune is spill-avoidance, consider asserting HEAD_DIM is not None rather than silently returning the unpruned (potentially-spilling) set.

Minor / confirmed-correct

  • Split-K empty splits (sdpa.py:1199-1208): splits with start_n >= kv_len get end_n <= start_n, skip the loop, and store the in-kernel zero-initialized acc/l_i. Since all H_q = H_kv * NUM_GROUPS slots are written, the reduce sees valid zeros — correct. O_partial being torch.empty (vs. zeros) is fine because every (split, head) slot is unconditionally stored.
  • kv_len = input_pos[0] + input_pos.shape[0] (cuda_source_transformations.py): correct only if input_pos is a contiguous range; that's the invariant for decode (T=1) and prefill chunks here, so fine — just flagging the assumption.
  • The new import ...kernels.sdpa # noqa: F401 for op registration mirrors the existing tq4_sdpa import — consistent.

Nice work — the O(context) decode scaling restoration and the head_dim-agnostic prefill autotune are well-motivated and the validation (cos=1.0 across D, no decode regression) is reassuring.
• branch gemma4_31b-cuda-attn-perf-git

@Gasoonjia Gasoonjia merged commit 7e0151e into main Jun 25, 2026
233 of 238 checks passed
@Gasoonjia Gasoonjia deleted the gemma4_31b-cuda-attn-perf-git branch June 25, 2026 22:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants