Commit 7e0151e
authored
[gemma4_31b][cuda] length-aware bf16 global attention (#20506)
- Global (full-attention) bf16 layers: bound SDPA to a runtime kv_len
scalar (CUDA-graph-safe) instead of the full max_seq_len KV buffer,
reduce complexity of decode from O(kvcache_length) -> O(context);
restores decode scaling (was flat ~36.5 t/s at all depths -> 46.5@512,
34.9@127K).
- Prefill global full-attention: replace fixed m32/m64 BLOCK_M selection
with a head_dim-keyed autotuned _sdpa_fwd_kernel + register-budget prune
(BLOCK_M*HEAD_DIM <= 4096*num_warps), fixing acc[64,512] fp32 register
spill at head_dim=512. Prefill +24% @8K, +63% @32k, +117% @127k;
head_dim-agnostic (no split-D needed for D<=512). (sdpa.py)
- Validated: output bitwise-identical to prior kernel (cos=1.0,
D=64/128/256/512), no decode regression; non-tq prefill now beats
llama.cpp at all 5 cells and turbo TQ4 at 4/5. Op-level autotune
profiling (A100) confirms the config set is near-optimal (in-set optimum
at every regime; only <=1.3% marginal candidates).1 parent efc7560 commit 7e0151e
2 files changed
Lines changed: 289 additions & 114 deletions
File tree
- backends/cuda/triton/kernels
- examples/models/gemma4_31b
0 commit comments