Skip to content

Commit 41371d0

Browse files
committed
gemma4_31b TQ4 SDPA: no-spill prefill + analytic causal + autotune retune
Prefill (global head_dim=512 TQ4 path):
- Consolidate m64/m32 into one no-spill _tq4_sdpa_prefill_kernel (BLOCK_M<=32 so the fp32 acc[BLOCK_M,512] stays in registers, no spill).
- Absolute-offset analytic causal (offs_n > (kv_len-Lq)+seq_pos); the kernel no longer reads a materialized causal mask.
- Prune the prefill autotune list to the heavy-kv optima; this also removes the BLOCK_N=16 configs that produced incorrect output.

Decode (split-K): retune the autotune list for the HAS_MASK=False path - add the profiled optima (BLOCK_N=32/w4/s2, BLOCK_N=64/w8/s3) and drop the configs that are catastrophic at HAS_MASK=False (BLOCK_N=64/w2, BLOCK_N=128/w4).

Call-site passes attn_mask=None for BOTH prefill and decode so the two exported methods share one AOTI weights blob (avoids 2x .ptd weight duplication).

e2e @127k (A100, vs same-tree baseline): prefill 543->715 t/s (1.32x), decode 34.1->34.5 t/s, peak 25.75 GB (zero extra memory). Smaller contexts beat llama on both prefill and decode.
1 parent 74c2a9d commit 41371d0

2 files changed

Lines changed: 56 additions & 156 deletions

File tree

backends/cuda/triton/kernels/tq4_sdpa.py

Lines changed: 45 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def _tq4_sdpa_fwd_kernel_body(
194194
# causal mask); otherwise the full kv_len bound is kept, which is safe for an
195195
# arbitrary mask.
196196
loop_end = kv_len
197-
if MASK_IS_CAUSAL:
197+
if MASK_IS_CAUSAL or IS_CAUSAL:
198198
max_q_pos = (kv_len - Lq) + tl.max(seq_pos)
199199
loop_end = tl.minimum(kv_len, max_q_pos + 1)
200200

@@ -227,7 +227,12 @@ def _tq4_sdpa_fwd_kernel_body(
227227
qk = tl.where(mask_block, qk, float("-inf"))
228228

229229
if IS_CAUSAL:
230-
causal = offs_n[None, :] > seq_pos[:, None]
230+
# Absolute causal-offset: a query row's KV position is
231+
# (kv_len - Lq) + seq_pos, correct for chunked prefill (Lq < kv_len).
232+
# For the square is_causal case (kv_len == Lq) it reduces to
233+
# offs_n > seq_pos. This lets a caller that guarantees a standard
234+
# causal mask skip the materialized mask read entirely.
235+
causal = offs_n[None, :] > (kv_len - Lq) + seq_pos[:, None]
231236
qk = tl.where(causal, float("-inf"), qk)
232237

233238
qk = tl.where(kv_valid[None, :], qk, float("-inf"))
@@ -283,143 +288,27 @@ def _tq4_sdpa_fwd_kernel_body(
283288

284289

285290
# ---------------------------------------------------------------------------
286-
# Autotuned kernel wrappers (M64 and M32)
291+
# Autotuned prefill kernel (single, no-spill)
287292
# ---------------------------------------------------------------------------
288293

289294

290295
@triton.autotune(
291296
configs=[
292-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2),
293-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=3),
294-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
295-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3),
296-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2),
297-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=2),
298-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=3),
299-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=2),
300-
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=3),
301-
],
302-
key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"],
303-
)
304-
@triton.jit
305-
def _tq4_sdpa_fwd_kernel_m64(
306-
Q_ptr,
307-
KP_ptr,
308-
KN_ptr,
309-
VP_ptr,
310-
VN_ptr,
311-
LUT_hi_ptr,
312-
LUT_lo_ptr,
313-
Mask_ptr,
314-
O_ptr,
315-
KV_LEN_ptr,
316-
B,
317-
H_grid,
318-
Lq,
319-
Lk,
320-
stride_qb,
321-
stride_qh,
322-
stride_qm,
323-
stride_qd,
324-
stride_kpb,
325-
stride_kph,
326-
stride_kpn,
327-
stride_kpd,
328-
stride_knb,
329-
stride_knh,
330-
stride_knn,
331-
stride_vpb,
332-
stride_vph,
333-
stride_vpn,
334-
stride_vpd,
335-
stride_vnb,
336-
stride_vnh,
337-
stride_vnn,
338-
stride_ob,
339-
stride_oh,
340-
stride_om,
341-
stride_od,
342-
stride_mb,
343-
stride_mq,
344-
stride_mk,
345-
sm_scale: tl.float32,
346-
HAS_MASK: tl.constexpr,
347-
IS_CAUSAL: tl.constexpr,
348-
HAS_KV_LEN: tl.constexpr,
349-
MASK_IS_CAUSAL: tl.constexpr,
350-
HEAD_DIM: tl.constexpr,
351-
HALF_D: tl.constexpr,
352-
NUM_GROUPS: tl.constexpr,
353-
PACK_GQA: tl.constexpr,
354-
BLOCK_M: tl.constexpr,
355-
BLOCK_N: tl.constexpr,
356-
):
357-
_tq4_sdpa_fwd_kernel_body(
358-
Q_ptr,
359-
KP_ptr,
360-
KN_ptr,
361-
VP_ptr,
362-
VN_ptr,
363-
LUT_hi_ptr,
364-
LUT_lo_ptr,
365-
Mask_ptr,
366-
O_ptr,
367-
KV_LEN_ptr,
368-
B,
369-
H_grid,
370-
Lq,
371-
Lk,
372-
stride_qb,
373-
stride_qh,
374-
stride_qm,
375-
stride_qd,
376-
stride_kpb,
377-
stride_kph,
378-
stride_kpn,
379-
stride_kpd,
380-
stride_knb,
381-
stride_knh,
382-
stride_knn,
383-
stride_vpb,
384-
stride_vph,
385-
stride_vpn,
386-
stride_vpd,
387-
stride_vnb,
388-
stride_vnh,
389-
stride_vnn,
390-
stride_ob,
391-
stride_oh,
392-
stride_om,
393-
stride_od,
394-
stride_mb,
395-
stride_mq,
396-
stride_mk,
397-
sm_scale,
398-
HAS_MASK=HAS_MASK,
399-
IS_CAUSAL=IS_CAUSAL,
400-
HAS_KV_LEN=HAS_KV_LEN,
401-
MASK_IS_CAUSAL=MASK_IS_CAUSAL,
402-
BLOCK_M=BLOCK_M,
403-
BLOCK_N=BLOCK_N,
404-
HEAD_DIM=HEAD_DIM,
405-
HALF_D=HALF_D,
406-
NUM_GROUPS=NUM_GROUPS,
407-
PACK_GQA=PACK_GQA,
408-
)
409-
410-
411-
@triton.autotune(
412-
configs=[
413-
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2),
414-
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2),
415-
triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2),
297+
# No-spill prefill configs, pruned to the profiled-optimal set for the
298+
# gemma4 global shape (heavy-shape optimum = BLOCK_M=32/BLOCK_N=32/w4/s2).
299+
# BLOCK_M=32 keeps the fp32 acc[BLOCK_M, HEAD_DIM] in registers (at BLOCK_M=64
300+
# and HEAD_DIM=512 the accumulator is 128 KB/CTA and spills to local memory);
301+
# BLOCK_N<=64 keeps the staged decompressed K/V tile within the A100 SMEM budget.
302+
# BLOCK_M=16 / BLOCK_N=16 configs were pruned (slower; BLOCK_N=16 also
303+
# measured low cosine ~0.79-0.93 at this shape).
304+
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=8, num_stages=2),
305+
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=3),
416306
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2),
417-
triton.Config({"BLOCK_M": 32, "BLOCK_N": 16}, num_warps=4, num_stages=3),
418307
],
419308
key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"],
420309
)
421310
@triton.jit
422-
def _tq4_sdpa_fwd_kernel_m32(
311+
def _tq4_sdpa_prefill_kernel(
423312
Q_ptr,
424313
KP_ptr,
425314
KN_ptr,
@@ -570,15 +459,7 @@ def _launch_tq4_kernel(
570459
def grid(meta):
571460
return (triton.cdiv(Lq_packed, meta["BLOCK_M"]), B * H_grid)
572461

573-
total_ctas_m64 = ((Lq_packed + 63) // 64) * (B * H_grid)
574-
threshold = 4 * 84
575-
kernel = (
576-
_tq4_sdpa_fwd_kernel_m32
577-
if total_ctas_m64 < threshold
578-
else _tq4_sdpa_fwd_kernel_m64
579-
)
580-
581-
wrap_triton(kernel)[grid](
462+
wrap_triton(_tq4_sdpa_prefill_kernel)[grid](
582463
q_rot,
583464
k_packed,
584465
k_norms,
@@ -845,6 +726,19 @@ def tq4_sdpa(
845726
pack_gqa,
846727
)
847728
else:
729+
# Prefill path (N_Q > 1, plus the rare N_Q==1 && N_KV<256 fallthrough).
730+
# When the caller guarantees a standard causal mask AND kv_len is known
731+
# (MASK_IS_CAUSAL), use the kernel's analytic absolute causal-offset and
732+
# skip loading the materialized mask — numerically identical, no mask HBM
733+
# traffic. Causal is then applied via IS_CAUSAL (which also drives the
734+
# per-tile loop-end clamp), so MASK_IS_CAUSAL is passed False to the
735+
# launcher. Otherwise honor the explicit mask / is_causal as-is.
736+
if MASK_IS_CAUSAL:
737+
prefill_has_mask = False
738+
prefill_is_causal = True
739+
else:
740+
prefill_has_mask = HAS_MASK
741+
prefill_is_causal = is_causal
848742
_launch_tq4_kernel(
849743
q_rot,
850744
k_packed,
@@ -863,13 +757,13 @@ def tq4_sdpa(
863757
N_KV,
864758
D,
865759
sm_scale,
866-
HAS_MASK,
760+
prefill_has_mask,
867761
HAS_KV_LEN,
868-
MASK_IS_CAUSAL,
762+
False,
869763
stride_mb,
870764
stride_mq,
871765
stride_mk,
872-
is_causal,
766+
prefill_is_causal,
873767
num_groups,
874768
pack_gqa,
875769
)
@@ -889,17 +783,17 @@ def tq4_sdpa(
889783

890784
@triton.autotune(
891785
configs=[
892-
triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1),
893-
triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1),
894-
triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1),
895-
triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1),
896-
triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=2),
897-
triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1),
898-
triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=2),
899-
triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=3),
786+
# Split-K decode configs, curated to the profiled-optimal set so the
787+
# HAS_MASK=False specialization (decode passes attn_mask=None too, for the
788+
# AOTI weights-blob dedup) bakes a good config: BLOCK_N=32/w4/s2 is the
789+
# primary optimum (964us@127K, 344us@32K), BLOCK_N=64/w8/s3 wins at 127K
790+
# (914us), BLOCK_N=128/w8/s2 is a safe fallback. Other configs were pruned:
791+
# BLOCK_N=64/w2/s1 (12.8ms), 128/w4/s{1,2,3} (up to 9.4ms) and 32/w2/s1 are
792+
# catastrophic for HAS_MASK=False; the rest were not measured-optimal and
793+
# are dropped so AOTI cannot bake a slow one (no autotune lottery).
794+
triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=2),
795+
triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=3),
900796
triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=2),
901-
triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=2),
902-
triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=2),
903797
],
904798
key=["Lk", "HEAD_DIM", "NUM_GROUPS", "HAS_MASK", "PACK_GQA"],
905799
)

examples/models/gemma4_31b/cuda_source_transformations.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,15 @@ def _turboquant_attention_forward(
9494
# step (catastrophic at 128k: ~2.7 tok/s decode vs ~37+ when bounded).
9595
kv_len = input_pos[0] + input_pos.shape[0]
9696

97-
# ``scale=self.scaling`` (= 1.0 for Gemma 4) — overrides tq4_sdpa's
98-
# default ``1/sqrt(D)`` because Gemma's QK-norm has absorbed the
99-
# 1/sqrt(d) factor into trained weights.
97+
# attn_mask=None for BOTH prefill and decode: tq4_sdpa applies causal masking
98+
# analytically (mask_is_causal + kv_len, absolute causal-offset), so the SDPA
99+
# call is identical across the two exported methods — AOTI dedups the shared
100+
# weights blob (~26 GB). Prefill takes the no-spill analytic path; decode takes
101+
# split-K with HAS_MASK=False, whose autotune list is curated (tq4_sdpa.py) to
102+
# the profiled-optimal BLOCK_N configs, so HAS_MASK=False does not regress
103+
# decode. ``scale=self.scaling`` (= 1.0 for Gemma 4) overrides tq4_sdpa's
104+
# 1/sqrt(D) default (Gemma's QK-norm folded that factor into the weights).
105+
sdpa_attn_mask = None
100106
y = torch.ops.triton.tq4_sdpa(
101107
q,
102108
k_packed,
@@ -105,8 +111,8 @@ def _turboquant_attention_forward(
105111
v_norms,
106112
self.kv_cache.centroids,
107113
self.kv_cache.rotation,
108-
attn_mask,
109-
False, # is_causal: attn_mask already encodes causal masking
114+
sdpa_attn_mask,
115+
False, # is_causal: needs L_q==L_kv; causal comes from mask_is_causal
110116
self.scaling,
111117
kv_len,
112118
True, # mask_is_causal: Gemma full-attention mask is standard causal

0 commit comments

Comments
 (0)