[gemma4_31b][cuda] length-aware bf16 global attention#20506
Conversation
Three CUDA-export memory optimizations: - tq4_sdpa: add BLOCK_N=16 (and a BLOCK_M=32) autotune config. The superset is kept for big-shared-memory GPUs (A100/H100); the Triton autotuner auto-prunes configs that exceed a GPU's shared memory (OutOfResources -> inf), so the same config list also works on the 5090 (Blackwell, ~101 KB SMEM) where the previous smallest config did not fit. - int4_dispatch: chunk the inline _dequant_matmul along N for vocab-sized weights (N>65536, i.e. only the lm_head). Avoids transiently materializing the full ~10 GiB bf16 lm_head when AOTI executes the int4_plain_mm custom op during autotune / cpp_wrapper. The runtime decode path uses the C++ dp4a shim and the M>4 prefill inline path is below the threshold, so this never enters the runtime graph -> zero runtime / accuracy impact. Applied unconditionally (no flag). - cuda_backend / aoti_backend: skip occupying the GPU with the KV-cache buffers during AOTI compile (gated behind low_memory_mode). A new move_program_to_device hook places KV constants on the target device but immediately frees their storage (resize_(0)), so the fake-tensor device check passes while no real KV bytes sit on the GPU during autotune. The emptied buffers are re-synthesized as zeros at the _unlift_graph clone and at serialization, and excluded from constant dedup (resize_(0) gives every KV data_ptr 0, which would otherwise collapse same-shape caches across layers). Result on 2xA100: Gemma4-31B @128k no-TQ export peak 36.3 -> 27.0 GiB; the exported model runs correctly (output "...Paris.").
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20506
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
|
fbe12b9 to
ce442fe
Compare
…nly frees genuinely all-zero kv_cache.* buffers (count_nonzero==0); preserves TQ4 centroids/boundaries/rotation/rotation_T
Summary: Fuse each gemma4_31b MLP's gate_proj|up_proj into a single [2*intermediate, hidden] coalesced-int4 matmul, applied by default in the CUDA export. This issues one activation-quant + one W4A8 matvec per layer instead of two, cutting per-token launch + activation-quant overhead in the launch-bound decode path. Only Q4_K (CudaCoalescedInt4Tensor) gate/up pairs are fused; any other quant type (e.g. Q6_K) is left as two matmuls (guarded, still correct). Builds on the already-landed kv_len-bounded tq4_sdpa kernel + gemma4_31b call-site (kv_len + mask_is_causal), which recovered 128k decode from ~2.8 to ~43 tok/s. With both, ET gemma4_31b 128k+TurboQuant decode beats llama.cpp at every measured context (cuda_graph ON): ctx ET llama 512 44.80 42.77 2K 43.20 41.97 8K 42.23 41.23 32K 41.64 40.27 127K 38.41 35.97 TurboQuant KV compression kept; prefill restored (6-8x) with no regression; output quality preserved. Test Plan: - Fusion numerics: fused vs unfused MLP through the real W4A8 int4_plain_mm kernel = bit-exact (max_abs_diff 0.0, cos 1.000000) for decode (T=1) and prefill (T=4). - Export + run: fused module exported via CudaPartitioner and executed through executor_runner (RC=0, cos 0.999915 vs eager). Full 31B export logs "Fused gate+up on 60 MLP layers". - Decode A/B (gemma4_31b 128k+TQ, cuda_graph ON, 5x median): table above; beats llama.cpp at 512 -> 127K. nsys: tq4_sdpa 91.7% -> 2.9% of decode.
…stic prefill autotune - Global (full-attention) bf16 layers: bound SDPA to a runtime kv_len scalar (CUDA-graph-safe) instead of the full max_seq_len KV buffer -> O(context) decode; restores decode scaling (was flat ~36.5 t/s at all depths -> 46.5@512, 34.9@127K). (sdpa.py kv_len path + cuda_source_transformations.py _lenaware_attention_forward; global layers only, sliding + turbo untouched) - Prefill global full-attention: replace fixed m32/m64 BLOCK_M selection with a head_dim-keyed autotuned _sdpa_fwd_kernel + register-budget prune (BLOCK_M*HEAD_DIM <= 4096*num_warps), fixing acc[64,512] fp32 register spill at head_dim=512. Prefill +24% @8K, +63% @32k, +117% @127k; head_dim-agnostic (no split-D needed for D<=512). (sdpa.py) - Validated: output bitwise-identical to prior kernel (cos=1.0, D=64/128/256/512), no decode regression; non-tq prefill now beats llama.cpp at all 5 cells and turbo TQ4 at 4/5. Op-level autotune profiling (A100) confirms the config set is near-optimal (in-set optimum at every regime; only <=1.3% marginal candidates).
ce442fe to
74c2a9d
Compare
1c371e2 to
4025660
Compare
There was a problem hiding this comment.
Is the right threshold device specific?
This PR needs a
|
|
@claude review this code |
|
Claude finished @metascroy's task in 2m 49s —— View job Code Review
Reviewed both files against A few notes below — none blocking. Device-specific thresholds (answers @metascroy's question)The reviewer asked whether the threshold is device-specific. It is — and there are now three A100-tuned magic constants in this kernel:
These are safe rather than correct-only-on-A100: the prune guarantees a non-empty config set (the If you want to make these portable without a re-tune, consider deriving from Stale comment / heuristic mismatch at the pack-GQA decision (
|
Uh oh!
There was an error while loading. Please reload this page.