Skip to content

[executorch][cuda] fuse gate/up MLP projections #20482

Merged
Gasoonjia merged 3 commits into
mainfrom
gemma4_31b-cuda-decode-speedup
Jun 25, 2026
Merged

[executorch][cuda] fuse gate/up MLP projections #20482
Gasoonjia merged 3 commits into
mainfrom
gemma4_31b-cuda-decode-speedup

Conversation

@Gasoonjia

Copy link
Copy Markdown
Contributor

Summary:
Fuse each gemma4_31b MLP's gate_proj|up_proj into a single [2*intermediate, hidden] coalesced-int4 matmul, applied by default in the CUDA export. This issues one activation-quant + one W4A8 matvec per layer instead of two, cutting per-token launch + activation-quant overhead in the launch-bound decode path. Only Q4_K (CudaCoalescedInt4Tensor) gate/up pairs are fused; any other quant type (e.g. Q6_K) is left as two matmuls (guarded, still correct).

decode length main branch current branch
512 42.2 44.80
2K 40.8 43.20
8K 40.0 42.23
32K 39.4 41.64
127K 35.5 38.41

Next Step: we will upsteam this kind of operator fusion into gemma4-31b model level when loading gguf. #20481 is the draft PR

Three CUDA-export memory optimizations:

- tq4_sdpa: add BLOCK_N=16 (and a BLOCK_M=32) autotune config. The superset
  is kept for big-shared-memory GPUs (A100/H100); the Triton autotuner
  auto-prunes configs that exceed a GPU's shared memory (OutOfResources ->
  inf), so the same config list also works on the 5090 (Blackwell, ~101 KB
  SMEM) where the previous smallest config did not fit.

- int4_dispatch: chunk the inline _dequant_matmul along N for vocab-sized
  weights (N>65536, i.e. only the lm_head). Avoids transiently materializing
  the full ~10 GiB bf16 lm_head when AOTI executes the int4_plain_mm custom
  op during autotune / cpp_wrapper. The runtime decode path uses the C++ dp4a
  shim and the M>4 prefill inline path is below the threshold, so this never
  enters the runtime graph -> zero runtime / accuracy impact. Applied
  unconditionally (no flag).

- cuda_backend / aoti_backend: skip occupying the GPU with the KV-cache
  buffers during AOTI compile (gated behind low_memory_mode). A new
  move_program_to_device hook places KV constants on the target device but
  immediately frees their storage (resize_(0)), so the fake-tensor device
  check passes while no real KV bytes sit on the GPU during autotune. The
  emptied buffers are re-synthesized as zeros at the _unlift_graph clone and
  at serialization, and excluded from constant dedup (resize_(0) gives every
  KV data_ptr 0, which would otherwise collapse same-shape caches across
  layers).

Result on 2xA100: Gemma4-31B @128k no-TQ export peak 36.3 -> 27.0 GiB; the
exported model runs correctly (output "...Paris.").
@pytorch-bot

pytorch-bot Bot commented Jun 24, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20482

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 1 Pending, 4 Unrelated Failures, 1 Unclassified Failure

As of commit 4025660 with merge base 1b726b2 (image):

NEW FAILURES - The following jobs have failed:

UNCLASSIFIED FAILURE - DrCI could not classify the following job because the workflow did not run on the merge base. The failure may be pre-existing on trunk or introduced by this PR:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 24, 2026
@linux-foundation-easycla

linux-foundation-easycla Bot commented Jun 24, 2026

Copy link
Copy Markdown

CLA Signed
The committers listed above are authorized under a signed CLA.

  • ✅ login: Gasoonjia / name: Songhao Jia (4025660)

@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@Gasoonjia Gasoonjia force-pushed the gemma4_31b-cuda-decode-speedup branch from 8b145b5 to 1c371e2 Compare June 24, 2026 10:00
@Gasoonjia Gasoonjia changed the base branch from main to gemma4_31b_export_under_32gb June 24, 2026 10:01
…nly frees genuinely all-zero kv_cache.* buffers (count_nonzero==0); preserves TQ4 centroids/boundaries/rotation/rotation_T
Summary:
Fuse each gemma4_31b MLP's gate_proj|up_proj into a single
[2*intermediate, hidden] coalesced-int4 matmul, applied by default in the CUDA
export. This issues one activation-quant + one W4A8 matvec per layer instead of
two, cutting per-token launch + activation-quant overhead in the launch-bound
decode path. Only Q4_K (CudaCoalescedInt4Tensor) gate/up pairs are fused; any
other quant type (e.g. Q6_K) is left as two matmuls (guarded, still correct).

Builds on the already-landed kv_len-bounded tq4_sdpa kernel + gemma4_31b
call-site (kv_len + mask_is_causal), which recovered 128k decode from ~2.8 to
~43 tok/s. With both, ET gemma4_31b 128k+TurboQuant decode beats llama.cpp at
every measured context (cuda_graph ON):

  ctx    ET      llama
  512    44.80   42.77
  2K     43.20   41.97
  8K     42.23   41.23
  32K    41.64   40.27
  127K   38.41   35.97

TurboQuant KV compression kept; prefill restored (6-8x) with no regression;
output quality preserved.

Test Plan:
- Fusion numerics: fused vs unfused MLP through the real W4A8 int4_plain_mm
  kernel = bit-exact (max_abs_diff 0.0, cos 1.000000) for decode (T=1) and
  prefill (T=4).
- Export + run: fused module exported via CudaPartitioner and executed through
  executor_runner (RC=0, cos 0.999915 vs eager). Full 31B export logs
  "Fused gate+up on 60 MLP layers".
- Decode A/B (gemma4_31b 128k+TQ, cuda_graph ON, 5x median): table above; beats
  llama.cpp at 512 -> 127K. nsys: tq4_sdpa 91.7% -> 2.9% of decode.
@Gasoonjia Gasoonjia force-pushed the gemma4_31b-cuda-decode-speedup branch from 1c371e2 to 4025660 Compare June 25, 2026 17:23
Base automatically changed from gemma4_31b_export_under_32gb to main June 25, 2026 22:29
@Gasoonjia Gasoonjia merged commit efc7560 into main Jun 25, 2026
248 of 415 checks passed
@Gasoonjia Gasoonjia deleted the gemma4_31b-cuda-decode-speedup branch June 25, 2026 22:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants