@@ -194,7 +194,7 @@ def _tq4_sdpa_fwd_kernel_body(
194194 # causal mask); otherwise the full kv_len bound is kept, which is safe for an
195195 # arbitrary mask.
196196 loop_end = kv_len
197- if MASK_IS_CAUSAL :
197+ if MASK_IS_CAUSAL or IS_CAUSAL :
198198 max_q_pos = (kv_len - Lq ) + tl .max (seq_pos )
199199 loop_end = tl .minimum (kv_len , max_q_pos + 1 )
200200
@@ -227,7 +227,12 @@ def _tq4_sdpa_fwd_kernel_body(
227227 qk = tl .where (mask_block , qk , float ("-inf" ))
228228
229229 if IS_CAUSAL :
230- causal = offs_n [None , :] > seq_pos [:, None ]
230+ # Absolute causal-offset: a query row's KV position is
231+ # (kv_len - Lq) + seq_pos, correct for chunked prefill (Lq < kv_len).
232+ # For the square is_causal case (kv_len == Lq) it reduces to
233+ # offs_n > seq_pos. This lets a caller that guarantees a standard
234+ # causal mask skip the materialized mask read entirely.
235+ causal = offs_n [None , :] > (kv_len - Lq ) + seq_pos [:, None ]
231236 qk = tl .where (causal , float ("-inf" ), qk )
232237
233238 qk = tl .where (kv_valid [None , :], qk , float ("-inf" ))
@@ -283,143 +288,27 @@ def _tq4_sdpa_fwd_kernel_body(
283288
284289
285290# ---------------------------------------------------------------------------
286- # Autotuned kernel wrappers (M64 and M32 )
291+ # Autotuned prefill kernel (single, no-spill )
287292# ---------------------------------------------------------------------------
288293
289294
290295@triton .autotune (
291296 configs = [
292- triton .Config ({"BLOCK_M" : 64 , "BLOCK_N" : 64 }, num_warps = 4 , num_stages = 2 ),
293- triton .Config ({"BLOCK_M" : 64 , "BLOCK_N" : 128 }, num_warps = 4 , num_stages = 3 ),
294- triton .Config ({"BLOCK_M" : 64 , "BLOCK_N" : 128 }, num_warps = 8 , num_stages = 2 ),
295- triton .Config ({"BLOCK_M" : 64 , "BLOCK_N" : 256 }, num_warps = 8 , num_stages = 3 ),
296- triton .Config ({"BLOCK_M" : 64 , "BLOCK_N" : 32 }, num_warps = 4 , num_stages = 2 ),
297- triton .Config ({"BLOCK_M" : 64 , "BLOCK_N" : 16 }, num_warps = 4 , num_stages = 2 ),
298- triton .Config ({"BLOCK_M" : 64 , "BLOCK_N" : 16 }, num_warps = 4 , num_stages = 3 ),
299- triton .Config ({"BLOCK_M" : 64 , "BLOCK_N" : 16 }, num_warps = 8 , num_stages = 2 ),
300- triton .Config ({"BLOCK_M" : 64 , "BLOCK_N" : 16 }, num_warps = 8 , num_stages = 3 ),
301- ],
302- key = ["Lq" , "Lk" , "HEAD_DIM" , "HAS_MASK" , "IS_CAUSAL" , "NUM_GROUPS" , "PACK_GQA" ],
303- )
304- @triton .jit
305- def _tq4_sdpa_fwd_kernel_m64 (
306- Q_ptr ,
307- KP_ptr ,
308- KN_ptr ,
309- VP_ptr ,
310- VN_ptr ,
311- LUT_hi_ptr ,
312- LUT_lo_ptr ,
313- Mask_ptr ,
314- O_ptr ,
315- KV_LEN_ptr ,
316- B ,
317- H_grid ,
318- Lq ,
319- Lk ,
320- stride_qb ,
321- stride_qh ,
322- stride_qm ,
323- stride_qd ,
324- stride_kpb ,
325- stride_kph ,
326- stride_kpn ,
327- stride_kpd ,
328- stride_knb ,
329- stride_knh ,
330- stride_knn ,
331- stride_vpb ,
332- stride_vph ,
333- stride_vpn ,
334- stride_vpd ,
335- stride_vnb ,
336- stride_vnh ,
337- stride_vnn ,
338- stride_ob ,
339- stride_oh ,
340- stride_om ,
341- stride_od ,
342- stride_mb ,
343- stride_mq ,
344- stride_mk ,
345- sm_scale : tl .float32 ,
346- HAS_MASK : tl .constexpr ,
347- IS_CAUSAL : tl .constexpr ,
348- HAS_KV_LEN : tl .constexpr ,
349- MASK_IS_CAUSAL : tl .constexpr ,
350- HEAD_DIM : tl .constexpr ,
351- HALF_D : tl .constexpr ,
352- NUM_GROUPS : tl .constexpr ,
353- PACK_GQA : tl .constexpr ,
354- BLOCK_M : tl .constexpr ,
355- BLOCK_N : tl .constexpr ,
356- ):
357- _tq4_sdpa_fwd_kernel_body (
358- Q_ptr ,
359- KP_ptr ,
360- KN_ptr ,
361- VP_ptr ,
362- VN_ptr ,
363- LUT_hi_ptr ,
364- LUT_lo_ptr ,
365- Mask_ptr ,
366- O_ptr ,
367- KV_LEN_ptr ,
368- B ,
369- H_grid ,
370- Lq ,
371- Lk ,
372- stride_qb ,
373- stride_qh ,
374- stride_qm ,
375- stride_qd ,
376- stride_kpb ,
377- stride_kph ,
378- stride_kpn ,
379- stride_kpd ,
380- stride_knb ,
381- stride_knh ,
382- stride_knn ,
383- stride_vpb ,
384- stride_vph ,
385- stride_vpn ,
386- stride_vpd ,
387- stride_vnb ,
388- stride_vnh ,
389- stride_vnn ,
390- stride_ob ,
391- stride_oh ,
392- stride_om ,
393- stride_od ,
394- stride_mb ,
395- stride_mq ,
396- stride_mk ,
397- sm_scale ,
398- HAS_MASK = HAS_MASK ,
399- IS_CAUSAL = IS_CAUSAL ,
400- HAS_KV_LEN = HAS_KV_LEN ,
401- MASK_IS_CAUSAL = MASK_IS_CAUSAL ,
402- BLOCK_M = BLOCK_M ,
403- BLOCK_N = BLOCK_N ,
404- HEAD_DIM = HEAD_DIM ,
405- HALF_D = HALF_D ,
406- NUM_GROUPS = NUM_GROUPS ,
407- PACK_GQA = PACK_GQA ,
408- )
409-
410-
411- @triton .autotune (
412- configs = [
413- triton .Config ({"BLOCK_M" : 32 , "BLOCK_N" : 64 }, num_warps = 4 , num_stages = 2 ),
414- triton .Config ({"BLOCK_M" : 32 , "BLOCK_N" : 128 }, num_warps = 4 , num_stages = 2 ),
415- triton .Config ({"BLOCK_M" : 32 , "BLOCK_N" : 256 }, num_warps = 4 , num_stages = 2 ),
297+ # No-spill prefill configs, pruned to the profiled-optimal set for the
298+ # gemma4 global shape (heavy-shape optimum = BLOCK_M=32/BLOCK_N=32/w4/s2).
299+ # BLOCK_M=32 keeps the fp32 acc[BLOCK_M, HEAD_DIM] in registers (BLOCK_M=64
300+ # at HEAD_DIM=512 = 128 KB/CTA spills to local memory) and BLOCK_N<=64
301+ # keeps the staged decompressed K/V tile within the A100 SMEM budget.
302+ # BLOCK_M=16 / BLOCK_N=16 configs were pruned (slower; BLOCK_N=16 also
303+ # measured low cosine ~0.79-0.93 at this shape).
304+ triton .Config ({"BLOCK_M" : 32 , "BLOCK_N" : 64 }, num_warps = 8 , num_stages = 2 ),
305+ triton .Config ({"BLOCK_M" : 32 , "BLOCK_N" : 32 }, num_warps = 4 , num_stages = 3 ),
416306 triton .Config ({"BLOCK_M" : 32 , "BLOCK_N" : 32 }, num_warps = 4 , num_stages = 2 ),
417- triton .Config ({"BLOCK_M" : 32 , "BLOCK_N" : 16 }, num_warps = 4 , num_stages = 3 ),
418307 ],
419308 key = ["Lq" , "Lk" , "HEAD_DIM" , "HAS_MASK" , "IS_CAUSAL" , "NUM_GROUPS" , "PACK_GQA" ],
420309)
421310@triton .jit
422- def _tq4_sdpa_fwd_kernel_m32 (
311+ def _tq4_sdpa_prefill_kernel (
423312 Q_ptr ,
424313 KP_ptr ,
425314 KN_ptr ,
@@ -570,15 +459,7 @@ def _launch_tq4_kernel(
570459 def grid (meta ):
571460 return (triton .cdiv (Lq_packed , meta ["BLOCK_M" ]), B * H_grid )
572461
573- total_ctas_m64 = ((Lq_packed + 63 ) // 64 ) * (B * H_grid )
574- threshold = 4 * 84
575- kernel = (
576- _tq4_sdpa_fwd_kernel_m32
577- if total_ctas_m64 < threshold
578- else _tq4_sdpa_fwd_kernel_m64
579- )
580-
581- wrap_triton (kernel )[grid ](
462+ wrap_triton (_tq4_sdpa_prefill_kernel )[grid ](
582463 q_rot ,
583464 k_packed ,
584465 k_norms ,
@@ -845,6 +726,19 @@ def tq4_sdpa(
845726 pack_gqa ,
846727 )
847728 else :
729+ # Prefill path (N_Q > 1, plus the rare N_Q==1 && N_KV<256 fallthrough).
730+ # When the caller guarantees a standard causal mask AND kv_len is known
731+ # (MASK_IS_CAUSAL), use the kernel's analytic absolute causal-offset and
732+ # skip loading the materialized mask — numerically identical, no mask HBM
733+ # traffic. Causal is then applied via IS_CAUSAL (which also drives the
734+ # per-tile loop-end clamp), so MASK_IS_CAUSAL is passed False to the
735+ # launcher. Otherwise honor the explicit mask / is_causal as-is.
736+ if MASK_IS_CAUSAL :
737+ prefill_has_mask = False
738+ prefill_is_causal = True
739+ else :
740+ prefill_has_mask = HAS_MASK
741+ prefill_is_causal = is_causal
848742 _launch_tq4_kernel (
849743 q_rot ,
850744 k_packed ,
@@ -863,13 +757,13 @@ def tq4_sdpa(
863757 N_KV ,
864758 D ,
865759 sm_scale ,
866- HAS_MASK ,
760+ prefill_has_mask ,
867761 HAS_KV_LEN ,
868- MASK_IS_CAUSAL ,
762+ False ,
869763 stride_mb ,
870764 stride_mq ,
871765 stride_mk ,
872- is_causal ,
766+ prefill_is_causal ,
873767 num_groups ,
874768 pack_gqa ,
875769 )
@@ -889,17 +783,17 @@ def tq4_sdpa(
889783
890784@triton .autotune (
891785 configs = [
892- triton .Config ({"BLOCK_N" : 32 }, num_warps = 2 , num_stages = 1 ),
893- triton .Config ({"BLOCK_N" : 32 }, num_warps = 4 , num_stages = 1 ),
894- triton .Config ({"BLOCK_N" : 64 }, num_warps = 2 , num_stages = 1 ),
895- triton .Config ({"BLOCK_N" : 64 }, num_warps = 4 , num_stages = 1 ),
896- triton .Config ({"BLOCK_N" : 64 }, num_warps = 4 , num_stages = 2 ),
897- triton .Config ({"BLOCK_N" : 128 }, num_warps = 4 , num_stages = 1 ),
898- triton .Config ({"BLOCK_N" : 128 }, num_warps = 4 , num_stages = 2 ),
899- triton .Config ({"BLOCK_N" : 128 }, num_warps = 4 , num_stages = 3 ),
786+ # Split-K decode configs, curated to the profiled-optimal set so the
787+ # HAS_MASK=False specialization (decode passes attn_mask=None too, for the
788+ # AOTI weights-blob dedup) bakes a good config: BLOCK_N=32/w4/s2 is the
789+ # primary optimum (964us@127K, 344us@32K), BLOCK_N=64/w8/s3 wins at 127K
790+ # (914us), BLOCK_N=128/w8/s2 is a safe fallback. Other configs were pruned:
791+ # BLOCK_N=64/w2/s1 (12.8ms), 128/w4/s{1,2,3} (up to 9.4ms) and 32/w2/s1 are
792+ # catastrophic for HAS_MASK=False; the rest were not measured-optimal and
793+ # are dropped so AOTI cannot bake a slow one (no autotune lottery).
794+ triton .Config ({"BLOCK_N" : 32 }, num_warps = 4 , num_stages = 2 ),
795+ triton .Config ({"BLOCK_N" : 64 }, num_warps = 8 , num_stages = 3 ),
900796 triton .Config ({"BLOCK_N" : 128 }, num_warps = 8 , num_stages = 2 ),
901- triton .Config ({"BLOCK_N" : 256 }, num_warps = 4 , num_stages = 2 ),
902- triton .Config ({"BLOCK_N" : 256 }, num_warps = 8 , num_stages = 2 ),
903797 ],
904798 key = ["Lk" , "HEAD_DIM" , "NUM_GROUPS" , "HAS_MASK" , "PACK_GQA" ],
905799)
0 commit comments