Skip to content

Commit 6021a58

Browse files
authored
[cuda backend][gemma4_31b] TQ4 SDPA: no-spill prefill kernel + analytic causal (#20512)
## Summary Speeds up long-context prefill for the TurboQuant (TQ4) KV-cache SDPA path in gemma4_31b (the 10 global/full-attention layers, head_dim=512), with no decode regression and no extra memory. At 127K context the prefill gap vs llama.cpp goes from -37% to -7%; shorter contexts already beat llama.cpp on both prefill and decode. ### Changes **Prefill kernel (`backends/cuda/triton/kernels/tq4_sdpa.py`)** - Consolidate the `m64`/`m32` prefill kernels into one no-spill `_tq4_sdpa_prefill_kernel`: cap `BLOCK_M<=32` so `acc[BLOCK_M, 512]` fp32 stays in registers instead of spilling to local memory. - Absolute-offset analytic causal (`offs_n > (kv_len - Lq) + seq_pos`); the kernel no longer reads a materialized causal mask. - Autotune list tuned to the heavy-kv shape; removed `BLOCK_N=16` configs (slow AND numerically incorrect, cos≈0.02). **Decode split-K (`tq4_sdpa.py`)** - Retune the split-K autotune list for the `HAS_MASK=False` specialization: add the profiled optima (`BLOCK_N=32/w4/s2`, `BLOCK_N=64/w8/s3`); drop configs that are catastrophically slow at `HAS_MASK=False` (`BLOCK_N=64/w2`, `BLOCK_N=128/w4`). **Call site (`examples/models/gemma4_31b/cuda_source_transformations.py`)** - Pass `attn_mask=None` for BOTH prefill and decode so the two exported methods emit an identical SDPA call → AOTI dedups the weights blob (avoids a 2× `.ptd` / 52GB blow-up; keeps ~26GB). **Cross-GPU readiness** - Add correctness-safe autotune configs (warp/stage variants on `BLOCK_N∈{32,64}`) that fit smaller-SMEM GPUs (e.g. RTX 5090, ~100KB/SM vs A100 164KB). A100 optima are retained. NOTE: AOTI bakes the config at export time, so re-export on the target GPU; 5090 perf/correctness still to be validated on a 5090. ### Results (e2e, A100, prefill t/s | decode t/s | peak; same-tree baseline) | ctx | this branch | baseline | llama.cpp | |------|--------------------|---------------|---------------| | 32K | 1554 / 42.3 | 1243 / 42.1 | 1279 / 40.1 | | 127K | **715 / 34.5 / 25.75GB** | 543 / 34.1 | 768.5 / 33.0 | - 127K prefill **1.32×** (gap vs llama.cpp -37% → -7%); 512/2K/8K/32K beat llama on both prefill and decode. - 127K decode **34.5** (≥ baseline 34.1, > llama 33.0) — no regression. - Peak memory **25.75GB** — zero extra vs baseline. ## Test plan - `CUDA_VISIBLE_DEVICES=0 python -m pytest backends/cuda/tests/test_tq4_sdpa.py -q` → 36 passed (kernel correctness across MHA/GQA/MQA, causal, decode, HD256, all-masked NaN-safety, 128K bottom-right alignment) + AOTI export; 1 gated-skip. - e2e: exported gemma4_31b .pte, ran 512/2K/8K/32K/127K (cuda_graph, temp 0, ignore_eos, 512 decode); table above. Verified `.ptd` = 1 weights blob (~26GB); baked configs = prefill `BM32/BN32/w4`, decode `BN32/w4/s2`. ### Known follow-ups (not in this PR) - `BLOCK_N=16` correctness bug (root cause unfixed; worked around by pruning). - Correctness-gated + GPU-adaptive autotune (only partial here); validate on 5090. - INT4 MLP W4A8 GEMV dominates decode (~76%) — separate effort.
1 parent 7e0151e commit 6021a58

2 files changed

Lines changed: 45 additions & 156 deletions

File tree

backends/cuda/triton/kernels/tq4_sdpa.py

Lines changed: 40 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def _tq4_sdpa_fwd_kernel_body(
194194
# causal mask); otherwise the full kv_len bound is kept, which is safe for an
195195
# arbitrary mask.
196196
loop_end = kv_len
197-
if MASK_IS_CAUSAL:
197+
if MASK_IS_CAUSAL or IS_CAUSAL:
198198
max_q_pos = (kv_len - Lq) + tl.max(seq_pos)
199199
loop_end = tl.minimum(kv_len, max_q_pos + 1)
200200

@@ -227,7 +227,12 @@ def _tq4_sdpa_fwd_kernel_body(
227227
qk = tl.where(mask_block, qk, float("-inf"))
228228

229229
if IS_CAUSAL:
230-
causal = offs_n[None, :] > seq_pos[:, None]
230+
# Absolute causal-offset: a query row's KV position is
231+
# (kv_len - Lq) + seq_pos, correct for chunked prefill (Lq < kv_len).
232+
# For the square is_causal case (kv_len == Lq) it reduces to
233+
# offs_n > seq_pos. This lets a caller that guarantees a standard
234+
# causal mask skip the materialized mask read entirely.
235+
causal = offs_n[None, :] > (kv_len - Lq) + seq_pos[:, None]
231236
qk = tl.where(causal, float("-inf"), qk)
232237

233238
qk = tl.where(kv_valid[None, :], qk, float("-inf"))
@@ -283,143 +288,25 @@ def _tq4_sdpa_fwd_kernel_body(
283288

284289

285290
# ---------------------------------------------------------------------------
286-
# Autotuned kernel wrappers (M64 and M32)
291+
# Autotuned prefill kernel (single, no-spill)
287292
# ---------------------------------------------------------------------------
288293

289294

290295
@triton.autotune(
291296
configs=[
292-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2),
293-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=3),
294-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
295-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3),
296-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2),
297-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=2),
298-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=3),
299-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=2),
300-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=3),
301-
],
302-
key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"],
303-
)
304-
@triton.jit
305-
def _tq4_sdpa_fwd_kernel_m64(
306-
Q_ptr,
307-
KP_ptr,
308-
KN_ptr,
309-
VP_ptr,
310-
VN_ptr,
311-
LUT_hi_ptr,
312-
LUT_lo_ptr,
313-
Mask_ptr,
314-
O_ptr,
315-
KV_LEN_ptr,
316-
B,
317-
H_grid,
318-
Lq,
319-
Lk,
320-
stride_qb,
321-
stride_qh,
322-
stride_qm,
323-
stride_qd,
324-
stride_kpb,
325-
stride_kph,
326-
stride_kpn,
327-
stride_kpd,
328-
stride_knb,
329-
stride_knh,
330-
stride_knn,
331-
stride_vpb,
332-
stride_vph,
333-
stride_vpn,
334-
stride_vpd,
335-
stride_vnb,
336-
stride_vnh,
337-
stride_vnn,
338-
stride_ob,
339-
stride_oh,
340-
stride_om,
341-
stride_od,
342-
stride_mb,
343-
stride_mq,
344-
stride_mk,
345-
sm_scale: tl.float32,
346-
HAS_MASK: tl.constexpr,
347-
IS_CAUSAL: tl.constexpr,
348-
HAS_KV_LEN: tl.constexpr,
349-
MASK_IS_CAUSAL: tl.constexpr,
350-
HEAD_DIM: tl.constexpr,
351-
HALF_D: tl.constexpr,
352-
NUM_GROUPS: tl.constexpr,
353-
PACK_GQA: tl.constexpr,
354-
BLOCK_M: tl.constexpr,
355-
BLOCK_N: tl.constexpr,
356-
):
357-
_tq4_sdpa_fwd_kernel_body(
358-
Q_ptr,
359-
KP_ptr,
360-
KN_ptr,
361-
VP_ptr,
362-
VN_ptr,
363-
LUT_hi_ptr,
364-
LUT_lo_ptr,
365-
Mask_ptr,
366-
O_ptr,
367-
KV_LEN_ptr,
368-
B,
369-
H_grid,
370-
Lq,
371-
Lk,
372-
stride_qb,
373-
stride_qh,
374-
stride_qm,
375-
stride_qd,
376-
stride_kpb,
377-
stride_kph,
378-
stride_kpn,
379-
stride_kpd,
380-
stride_knb,
381-
stride_knh,
382-
stride_knn,
383-
stride_vpb,
384-
stride_vph,
385-
stride_vpn,
386-
stride_vpd,
387-
stride_vnb,
388-
stride_vnh,
389-
stride_vnn,
390-
stride_ob,
391-
stride_oh,
392-
stride_om,
393-
stride_od,
394-
stride_mb,
395-
stride_mq,
396-
stride_mk,
397-
sm_scale,
398-
HAS_MASK=HAS_MASK,
399-
IS_CAUSAL=IS_CAUSAL,
400-
HAS_KV_LEN=HAS_KV_LEN,
401-
MASK_IS_CAUSAL=MASK_IS_CAUSAL,
402-
BLOCK_M=BLOCK_M,
403-
BLOCK_N=BLOCK_N,
404-
HEAD_DIM=HEAD_DIM,
405-
HALF_D=HALF_D,
406-
NUM_GROUPS=NUM_GROUPS,
407-
PACK_GQA=PACK_GQA,
408-
)
409-
410-
411-
@triton.autotune(
412-
configs=[
413-
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2),
414-
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2),
415-
triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2),
297+
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=8, num_stages=2),
298+
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=3),
416299
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2),
417-
triton.Config({"BLOCK_M": 32, "BLOCK_N": 16}, num_warps=4, num_stages=3),
300+
# Extra BLOCK_N in {32,64} configs for smaller-SMEM GPUs (e.g. RTX 5090);
301+
# correctness-safe (cos~1.0), never BLOCK_N=16 (numerically wrong).
302+
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=8, num_stages=2),
303+
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=4),
304+
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=8, num_stages=3),
418305
],
419306
key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"],
420307
)
421308
@triton.jit
422-
def _tq4_sdpa_fwd_kernel_m32(
309+
def _tq4_sdpa_prefill_kernel(
423310
Q_ptr,
424311
KP_ptr,
425312
KN_ptr,
@@ -570,15 +457,7 @@ def _launch_tq4_kernel(
570457
def grid(meta):
571458
return (triton.cdiv(Lq_packed, meta["BLOCK_M"]), B * H_grid)
572459

573-
total_ctas_m64 = ((Lq_packed + 63) // 64) * (B * H_grid)
574-
threshold = 4 * 84
575-
kernel = (
576-
_tq4_sdpa_fwd_kernel_m32
577-
if total_ctas_m64 < threshold
578-
else _tq4_sdpa_fwd_kernel_m64
579-
)
580-
581-
wrap_triton(kernel)[grid](
460+
wrap_triton(_tq4_sdpa_prefill_kernel)[grid](
582461
q_rot,
583462
k_packed,
584463
k_norms,
@@ -845,6 +724,19 @@ def tq4_sdpa(
845724
pack_gqa,
846725
)
847726
else:
727+
# Prefill path (N_Q > 1, plus the rare N_Q==1 && N_KV<256 fallthrough).
728+
# When the caller guarantees a standard causal mask AND kv_len is known
729+
# (MASK_IS_CAUSAL), use the kernel's analytic absolute causal-offset and
730+
# skip loading the materialized mask — numerically identical, no mask HBM
731+
# traffic. Causal is then applied via IS_CAUSAL (which also drives the
732+
# per-tile loop-end clamp), so MASK_IS_CAUSAL is passed False to the
733+
# launcher. Otherwise honor the explicit mask / is_causal as-is.
734+
if MASK_IS_CAUSAL:
735+
prefill_has_mask = False
736+
prefill_is_causal = True
737+
else:
738+
prefill_has_mask = HAS_MASK
739+
prefill_is_causal = is_causal
848740
_launch_tq4_kernel(
849741
q_rot,
850742
k_packed,
@@ -863,13 +755,13 @@ def tq4_sdpa(
863755
N_KV,
864756
D,
865757
sm_scale,
866-
HAS_MASK,
758+
prefill_has_mask,
867759
HAS_KV_LEN,
868-
MASK_IS_CAUSAL,
760+
False,
869761
stride_mb,
870762
stride_mq,
871763
stride_mk,
872-
is_causal,
764+
prefill_is_causal,
873765
num_groups,
874766
pack_gqa,
875767
)
@@ -889,17 +781,14 @@ def tq4_sdpa(
889781

890782
@triton.autotune(
891783
configs=[
892-
triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1),
893-
triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1),
894-
triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1),
895-
triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1),
896-
triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=2),
897-
triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1),
898-
triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=2),
899-
triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=3),
784+
triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=2),
785+
triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=3),
900786
triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=2),
901-
triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=2),
902-
triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=2),
787+
# Extra BLOCK_N in {32,64} configs for smaller-SMEM GPUs (e.g. RTX 5090);
788+
# correctness-safe (cos~1.0), never BLOCK_N=16 (numerically wrong).
789+
triton.Config({"BLOCK_N": 32}, num_warps=8, num_stages=2),
790+
triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=2),
791+
triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=3),
903792
],
904793
key=["Lk", "HEAD_DIM", "NUM_GROUPS", "HAS_MASK", "PACK_GQA"],
905794
)

examples/models/gemma4_31b/cuda_source_transformations.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ def _turboquant_attention_forward(
5252
5353
Mirrors the default forward up to (and including) RoPE; only the
5454
cache update and SDPA call differ.
55+
56+
NOTE: ``attn_mask`` is unused here and will be reconstucted in
57+
the kernel to save data transfer, but is passed to the default forward
5558
"""
5659
B, T, _ = x.shape
5760

@@ -94,9 +97,6 @@ def _turboquant_attention_forward(
9497
# step (catastrophic at 128k: ~2.7 tok/s decode vs ~37+ when bounded).
9598
kv_len = input_pos[0] + input_pos.shape[0]
9699

97-
# ``scale=self.scaling`` (= 1.0 for Gemma 4) — overrides tq4_sdpa's
98-
# default ``1/sqrt(D)`` because Gemma's QK-norm has absorbed the
99-
# 1/sqrt(d) factor into trained weights.
100100
y = torch.ops.triton.tq4_sdpa(
101101
q,
102102
k_packed,
@@ -105,8 +105,8 @@ def _turboquant_attention_forward(
105105
v_norms,
106106
self.kv_cache.centroids,
107107
self.kv_cache.rotation,
108-
attn_mask,
109-
False, # is_causal: attn_mask already encodes causal masking
108+
None, # reconstuct attention mask in the kernel to save data transfer
109+
False, # is_causal: needs L_q==L_kv; causal comes from mask_is_causal
110110
self.scaling,
111111
kv_len,
112112
True, # mask_is_causal: Gemma full-attention mask is standard causal

0 commit comments

Comments
 (0)