@@ -294,16 +294,14 @@ def _tq4_sdpa_fwd_kernel_body(
294294
295295@triton .autotune (
296296 configs = [
297- # No-spill prefill configs, pruned to the profiled-optimal set for the
298- # gemma4 global shape (heavy-shape optimum = BLOCK_M=32/BLOCK_N=32/w4/s2).
299- # BLOCK_M=32 keeps the fp32 acc[BLOCK_M, HEAD_DIM] in registers (BLOCK_M=64
300- # at HEAD_DIM=512 = 128 KB/CTA spills to local memory) and BLOCK_N<=64
301- # keeps the staged decompressed K/V tile within the A100 SMEM budget.
302- # BLOCK_M=16 / BLOCK_N=16 configs were pruned (slower; BLOCK_N=16 also
303- # measured low cosine ~0.79-0.93 at this shape).
304297 triton .Config ({"BLOCK_M" : 32 , "BLOCK_N" : 64 }, num_warps = 8 , num_stages = 2 ),
305298 triton .Config ({"BLOCK_M" : 32 , "BLOCK_N" : 32 }, num_warps = 4 , num_stages = 3 ),
306299 triton .Config ({"BLOCK_M" : 32 , "BLOCK_N" : 32 }, num_warps = 4 , num_stages = 2 ),
300+ # Extra BLOCK_N in {32,64} configs for smaller-SMEM GPUs (e.g. RTX 5090);
301+ # accuracy-safe (cosine ~1.0). Never add BLOCK_N=16 (measured cosine ~0.79-0.93).
302+ triton .Config ({"BLOCK_M" : 32 , "BLOCK_N" : 32 }, num_warps = 8 , num_stages = 2 ),
303+ triton .Config ({"BLOCK_M" : 32 , "BLOCK_N" : 32 }, num_warps = 4 , num_stages = 4 ),
304+ triton .Config ({"BLOCK_M" : 32 , "BLOCK_N" : 64 }, num_warps = 8 , num_stages = 3 ),
307305 ],
308306 key = ["Lq" , "Lk" , "HEAD_DIM" , "HAS_MASK" , "IS_CAUSAL" , "NUM_GROUPS" , "PACK_GQA" ],
309307)
@@ -783,17 +781,14 @@ def tq4_sdpa(
783781
784782@triton .autotune (
785783 configs = [
786- # Split-K decode configs, curated to the profiled-optimal set so the
787- # HAS_MASK=False specialization (decode passes attn_mask=None too, for the
788- # AOTI weights-blob dedup) bakes a good config: BLOCK_N=32/w4/s2 is the
789- # primary optimum (964us@127K, 344us@32K), BLOCK_N=64/w8/s3 wins at 127K
790- # (914us), BLOCK_N=128/w8/s2 is a safe fallback. Other configs were pruned:
791- # BLOCK_N=64/w2/s1 (12.8ms), 128/w4/s{1,2,3} (up to 9.4ms) and 32/w2/s1 are
792- # catastrophic for HAS_MASK=False; the rest were not measured-optimal and
793- # are dropped so AOTI cannot bake a slow one (no autotune lottery).
794784 triton .Config ({"BLOCK_N" : 32 }, num_warps = 4 , num_stages = 2 ),
795785 triton .Config ({"BLOCK_N" : 64 }, num_warps = 8 , num_stages = 3 ),
796786 triton .Config ({"BLOCK_N" : 128 }, num_warps = 8 , num_stages = 2 ),
787+ # Extra BLOCK_N in {32,64} configs for smaller-SMEM GPUs (e.g. RTX 5090);
788+ # accuracy-safe (cosine ~1.0). Never add BLOCK_N=16 (numerically wrong).
789+ triton .Config ({"BLOCK_N" : 32 }, num_warps = 8 , num_stages = 2 ),
790+ triton .Config ({"BLOCK_N" : 64 }, num_warps = 4 , num_stages = 2 ),
791+ triton .Config ({"BLOCK_N" : 32 }, num_warps = 4 , num_stages = 3 ),
797792 ],
798793 key = ["Lk" , "HEAD_DIM" , "NUM_GROUPS" , "HAS_MASK" , "PACK_GQA" ],
799794)
0 commit comments