Skip to content

Commit 2a93d53

Browse files
committed
gemma4_31b TQ4 SDPA: add 5090-feasible autotune configs + comment updates
1 parent 41371d0 commit 2a93d53

2 files changed

Lines changed: 14 additions & 25 deletions

File tree

backends/cuda/triton/kernels/tq4_sdpa.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -294,16 +294,14 @@ def _tq4_sdpa_fwd_kernel_body(
294294

295295
@triton.autotune(
296296
configs=[
297-
# No-spill prefill configs, pruned to the profiled-optimal set for the
298-
# gemma4 global shape (heavy-shape optimum = BLOCK_M=32/BLOCK_N=32/w4/s2).
299-
# BLOCK_M=32 keeps the fp32 acc[BLOCK_M, HEAD_DIM] in registers (BLOCK_M=64
300-
# at HEAD_DIM=512 = 128 KB/CTA spills to local memory) and BLOCK_N<=64
301-
# keeps the staged decompressed K/V tile within the A100 SMEM budget.
302-
# BLOCK_M=16 / BLOCK_N=16 configs were pruned (slower; BLOCK_N=16 also
303-
# measured low cosine ~0.79-0.93 at this shape).
304297
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=8, num_stages=2),
305298
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=3),
306299
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2),
300+
# Extra BLOCK_N in {32,64} configs for smaller-SMEM GPUs (e.g. RTX 5090);
301+
# correctness-safe (cos~1.0), never BLOCK_N=16 (numerically wrong).
302+
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=8, num_stages=2),
303+
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=4),
304+
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=8, num_stages=3),
307305
],
308306
key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"],
309307
)
@@ -783,17 +781,14 @@ def tq4_sdpa(
783781

784782
@triton.autotune(
785783
configs=[
786-
# Split-K decode configs, curated to the profiled-optimal set so the
787-
# HAS_MASK=False specialization (decode passes attn_mask=None too, for the
788-
# AOTI weights-blob dedup) bakes a good config: BLOCK_N=32/w4/s2 is the
789-
# primary optimum (964us@127K, 344us@32K), BLOCK_N=64/w8/s3 wins at 127K
790-
# (914us), BLOCK_N=128/w8/s2 is a safe fallback. Other configs were pruned:
791-
# BLOCK_N=64/w2/s1 (12.8ms), 128/w4/s{1,2,3} (up to 9.4ms) and 32/w2/s1 are
792-
# catastrophic for HAS_MASK=False; the rest were not measured-optimal and
793-
# are dropped so AOTI cannot bake a slow one (no autotune lottery).
794784
triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=2),
795785
triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=3),
796786
triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=2),
787+
# Extra BLOCK_N in {32,64} configs for smaller-SMEM GPUs (e.g. RTX 5090);
788+
# correctness-safe (cos~1.0), never BLOCK_N=16 (numerically wrong).
789+
triton.Config({"BLOCK_N": 32}, num_warps=8, num_stages=2),
790+
triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=2),
791+
triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=3),
797792
],
798793
key=["Lk", "HEAD_DIM", "NUM_GROUPS", "HAS_MASK", "PACK_GQA"],
799794
)

examples/models/gemma4_31b/cuda_source_transformations.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ def _turboquant_attention_forward(
5252
5353
Mirrors the default forward up to (and including) RoPE; only the
5454
cache update and SDPA call differ.
55+
56+
NOTE: ``attn_mask`` is unused here and will be reconstucted in
57+
the kernel to save data transfer, but is passed to the default forward
5558
"""
5659
B, T, _ = x.shape
5760

@@ -94,15 +97,6 @@ def _turboquant_attention_forward(
9497
# step (catastrophic at 128k: ~2.7 tok/s decode vs ~37+ when bounded).
9598
kv_len = input_pos[0] + input_pos.shape[0]
9699

97-
# attn_mask=None for BOTH prefill and decode: tq4_sdpa applies causal masking
98-
# analytically (mask_is_causal + kv_len, absolute causal-offset), so the SDPA
99-
# call is identical across the two exported methods — AOTI dedups the shared
100-
# weights blob (~26 GB). Prefill takes the no-spill analytic path; decode takes
101-
# split-K with HAS_MASK=False, whose autotune list is curated (tq4_sdpa.py) to
102-
# the profiled-optimal BLOCK_N configs, so HAS_MASK=False does not regress
103-
# decode. ``scale=self.scaling`` (= 1.0 for Gemma 4) overrides tq4_sdpa's
104-
# 1/sqrt(D) default (Gemma's QK-norm folded that factor into the weights).
105-
sdpa_attn_mask = None
106100
y = torch.ops.triton.tq4_sdpa(
107101
q,
108102
k_packed,
@@ -111,7 +105,7 @@ def _turboquant_attention_forward(
111105
v_norms,
112106
self.kv_cache.centroids,
113107
self.kv_cache.rotation,
114-
sdpa_attn_mask,
108+
None, # reconstuct attention mask in the kernel to save data transfer
115109
False, # is_causal: needs L_q==L_kv; causal comes from mask_is_causal
116110
self.scaling,
117111
kv_len,

0 commit comments

Comments
 (0)