Commit 6021a58
authored
[cuda backend][gemma4_31b] TQ4 SDPA: no-spill prefill kernel + analytic causal (#20512)
## Summary
Speeds up long-context prefill for the TurboQuant (TQ4) KV-cache SDPA
path in
gemma4_31b (the 10 global/full-attention layers, head_dim=512), with no
decode
regression and no extra memory. At 127K context the prefill gap vs
llama.cpp
goes from -37% to -7%; shorter contexts already beat llama.cpp on both
prefill
and decode.
### Changes
**Prefill kernel (`backends/cuda/triton/kernels/tq4_sdpa.py`)**
- Consolidate the `m64`/`m32` prefill kernels into one no-spill
`_tq4_sdpa_prefill_kernel`: cap `BLOCK_M<=32` so `acc[BLOCK_M, 512]`
fp32 stays
in registers instead of spilling to local memory.
- Absolute-offset analytic causal (`offs_n > (kv_len - Lq) + seq_pos`);
the
kernel no longer reads a materialized causal mask.
- Autotune list tuned to the heavy-kv shape; removed `BLOCK_N=16`
configs (slow
AND numerically incorrect, cos≈0.02).
**Decode split-K (`tq4_sdpa.py`)**
- Retune the split-K autotune list for the `HAS_MASK=False`
specialization: add
the profiled optima (`BLOCK_N=32/w4/s2`, `BLOCK_N=64/w8/s3`); drop
configs that
are catastrophically slow at `HAS_MASK=False` (`BLOCK_N=64/w2`,
`BLOCK_N=128/w4`).
**Call site
(`examples/models/gemma4_31b/cuda_source_transformations.py`)**
- Pass `attn_mask=None` for BOTH prefill and decode so the two exported
methods
emit an identical SDPA call → AOTI dedups the weights blob (avoids a 2×
`.ptd`
/ 52GB blow-up; keeps ~26GB).
**Cross-GPU readiness**
- Add correctness-safe autotune configs (warp/stage variants on
`BLOCK_N∈{32,64}`)
that fit smaller-SMEM GPUs (e.g. RTX 5090, ~100KB/SM vs A100 164KB).
A100
optima are retained. NOTE: AOTI bakes the config at export time, so
re-export
on the target GPU; 5090 perf/correctness still to be validated on a
5090.
### Results (e2e, A100, prefill t/s | decode t/s | peak; same-tree
baseline)
| ctx | this branch | baseline | llama.cpp |
|------|--------------------|---------------|---------------|
| 32K | 1554 / 42.3 | 1243 / 42.1 | 1279 / 40.1 |
| 127K | **715 / 34.5 / 25.75GB** | 543 / 34.1 | 768.5 / 33.0 |
- 127K prefill **1.32×** (gap vs llama.cpp -37% → -7%); 512/2K/8K/32K
beat llama
on both prefill and decode.
- 127K decode **34.5** (≥ baseline 34.1, > llama 33.0) — no regression.
- Peak memory **25.75GB** — zero extra vs baseline.
## Test plan
- `CUDA_VISIBLE_DEVICES=0 python -m pytest
backends/cuda/tests/test_tq4_sdpa.py -q`
→ 36 passed (kernel correctness across MHA/GQA/MQA, causal, decode,
HD256,
all-masked NaN-safety, 128K bottom-right alignment) + AOTI export; 1
gated-skip.
- e2e: exported gemma4_31b .pte, ran 512/2K/8K/32K/127K (cuda_graph,
temp 0,
ignore_eos, 512 decode); table above. Verified `.ptd` = 1 weights blob
(~26GB);
baked configs = prefill `BM32/BN32/w4`, decode `BN32/w4/s2`.
### Known follow-ups (not in this PR)
- `BLOCK_N=16` correctness bug (root cause unfixed; worked around by
pruning).
- Correctness-gated + GPU-adaptive autotune (only partial here);
validate on 5090.
- INT4 MLP W4A8 GEMV dominates decode (~76%) — separate effort.1 parent 7e0151e commit 6021a58
2 files changed
Lines changed: 45 additions & 156 deletions
File tree
- backends/cuda/triton/kernels
- examples/models/gemma4_31b
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
194 | 194 | | |
195 | 195 | | |
196 | 196 | | |
197 | | - | |
| 197 | + | |
198 | 198 | | |
199 | 199 | | |
200 | 200 | | |
| |||
227 | 227 | | |
228 | 228 | | |
229 | 229 | | |
230 | | - | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
231 | 236 | | |
232 | 237 | | |
233 | 238 | | |
| |||
283 | 288 | | |
284 | 289 | | |
285 | 290 | | |
286 | | - | |
| 291 | + | |
287 | 292 | | |
288 | 293 | | |
289 | 294 | | |
290 | 295 | | |
291 | 296 | | |
292 | | - | |
293 | | - | |
294 | | - | |
295 | | - | |
296 | | - | |
297 | | - | |
298 | | - | |
299 | | - | |
300 | | - | |
301 | | - | |
302 | | - | |
303 | | - | |
304 | | - | |
305 | | - | |
306 | | - | |
307 | | - | |
308 | | - | |
309 | | - | |
310 | | - | |
311 | | - | |
312 | | - | |
313 | | - | |
314 | | - | |
315 | | - | |
316 | | - | |
317 | | - | |
318 | | - | |
319 | | - | |
320 | | - | |
321 | | - | |
322 | | - | |
323 | | - | |
324 | | - | |
325 | | - | |
326 | | - | |
327 | | - | |
328 | | - | |
329 | | - | |
330 | | - | |
331 | | - | |
332 | | - | |
333 | | - | |
334 | | - | |
335 | | - | |
336 | | - | |
337 | | - | |
338 | | - | |
339 | | - | |
340 | | - | |
341 | | - | |
342 | | - | |
343 | | - | |
344 | | - | |
345 | | - | |
346 | | - | |
347 | | - | |
348 | | - | |
349 | | - | |
350 | | - | |
351 | | - | |
352 | | - | |
353 | | - | |
354 | | - | |
355 | | - | |
356 | | - | |
357 | | - | |
358 | | - | |
359 | | - | |
360 | | - | |
361 | | - | |
362 | | - | |
363 | | - | |
364 | | - | |
365 | | - | |
366 | | - | |
367 | | - | |
368 | | - | |
369 | | - | |
370 | | - | |
371 | | - | |
372 | | - | |
373 | | - | |
374 | | - | |
375 | | - | |
376 | | - | |
377 | | - | |
378 | | - | |
379 | | - | |
380 | | - | |
381 | | - | |
382 | | - | |
383 | | - | |
384 | | - | |
385 | | - | |
386 | | - | |
387 | | - | |
388 | | - | |
389 | | - | |
390 | | - | |
391 | | - | |
392 | | - | |
393 | | - | |
394 | | - | |
395 | | - | |
396 | | - | |
397 | | - | |
398 | | - | |
399 | | - | |
400 | | - | |
401 | | - | |
402 | | - | |
403 | | - | |
404 | | - | |
405 | | - | |
406 | | - | |
407 | | - | |
408 | | - | |
409 | | - | |
410 | | - | |
411 | | - | |
412 | | - | |
413 | | - | |
414 | | - | |
415 | | - | |
| 297 | + | |
| 298 | + | |
416 | 299 | | |
417 | | - | |
| 300 | + | |
| 301 | + | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
418 | 305 | | |
419 | 306 | | |
420 | 307 | | |
421 | 308 | | |
422 | | - | |
| 309 | + | |
423 | 310 | | |
424 | 311 | | |
425 | 312 | | |
| |||
570 | 457 | | |
571 | 458 | | |
572 | 459 | | |
573 | | - | |
574 | | - | |
575 | | - | |
576 | | - | |
577 | | - | |
578 | | - | |
579 | | - | |
580 | | - | |
581 | | - | |
| 460 | + | |
582 | 461 | | |
583 | 462 | | |
584 | 463 | | |
| |||
845 | 724 | | |
846 | 725 | | |
847 | 726 | | |
| 727 | + | |
| 728 | + | |
| 729 | + | |
| 730 | + | |
| 731 | + | |
| 732 | + | |
| 733 | + | |
| 734 | + | |
| 735 | + | |
| 736 | + | |
| 737 | + | |
| 738 | + | |
| 739 | + | |
848 | 740 | | |
849 | 741 | | |
850 | 742 | | |
| |||
863 | 755 | | |
864 | 756 | | |
865 | 757 | | |
866 | | - | |
| 758 | + | |
867 | 759 | | |
868 | | - | |
| 760 | + | |
869 | 761 | | |
870 | 762 | | |
871 | 763 | | |
872 | | - | |
| 764 | + | |
873 | 765 | | |
874 | 766 | | |
875 | 767 | | |
| |||
889 | 781 | | |
890 | 782 | | |
891 | 783 | | |
892 | | - | |
893 | | - | |
894 | | - | |
895 | | - | |
896 | | - | |
897 | | - | |
898 | | - | |
899 | | - | |
| 784 | + | |
| 785 | + | |
900 | 786 | | |
901 | | - | |
902 | | - | |
| 787 | + | |
| 788 | + | |
| 789 | + | |
| 790 | + | |
| 791 | + | |
903 | 792 | | |
904 | 793 | | |
905 | 794 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
52 | 52 | | |
53 | 53 | | |
54 | 54 | | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
55 | 58 | | |
56 | 59 | | |
57 | 60 | | |
| |||
94 | 97 | | |
95 | 98 | | |
96 | 99 | | |
97 | | - | |
98 | | - | |
99 | | - | |
100 | 100 | | |
101 | 101 | | |
102 | 102 | | |
| |||
105 | 105 | | |
106 | 106 | | |
107 | 107 | | |
108 | | - | |
109 | | - | |
| 108 | + | |
| 109 | + | |
110 | 110 | | |
111 | 111 | | |
112 | 112 | | |
| |||
0 commit comments