diff --git a/backends/cuda/triton/kernels/tq4_sdpa.py b/backends/cuda/triton/kernels/tq4_sdpa.py index 7a41eaf92c1..427f2eef4eb 100644 --- a/backends/cuda/triton/kernels/tq4_sdpa.py +++ b/backends/cuda/triton/kernels/tq4_sdpa.py @@ -194,7 +194,7 @@ def _tq4_sdpa_fwd_kernel_body( # causal mask); otherwise the full kv_len bound is kept, which is safe for an # arbitrary mask. loop_end = kv_len - if MASK_IS_CAUSAL: + if MASK_IS_CAUSAL or IS_CAUSAL: max_q_pos = (kv_len - Lq) + tl.max(seq_pos) loop_end = tl.minimum(kv_len, max_q_pos + 1) @@ -227,7 +227,12 @@ def _tq4_sdpa_fwd_kernel_body( qk = tl.where(mask_block, qk, float("-inf")) if IS_CAUSAL: - causal = offs_n[None, :] > seq_pos[:, None] + # Absolute causal-offset: a query row's KV position is + # (kv_len - Lq) + seq_pos, correct for chunked prefill (Lq < kv_len). + # For the square is_causal case (kv_len == Lq) it reduces to + # offs_n > seq_pos. This lets a caller that guarantees a standard + # causal mask skip the materialized mask read entirely. 
+ causal = offs_n[None, :] > (kv_len - Lq) + seq_pos[:, None] qk = tl.where(causal, float("-inf"), qk) qk = tl.where(kv_valid[None, :], qk, float("-inf")) @@ -283,143 +288,25 @@ def _tq4_sdpa_fwd_kernel_body( # --------------------------------------------------------------------------- -# Autotuned kernel wrappers (M64 and M32) +# Autotuned prefill kernel (single, no-spill) # --------------------------------------------------------------------------- @triton.autotune( configs=[ - triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=3), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=3), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=2), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=3), - ], - key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"], -) -@triton.jit -def _tq4_sdpa_fwd_kernel_m64( - Q_ptr, - KP_ptr, - KN_ptr, - VP_ptr, - VN_ptr, - LUT_hi_ptr, - LUT_lo_ptr, - Mask_ptr, - O_ptr, - KV_LEN_ptr, - B, - H_grid, - Lq, - Lk, - stride_qb, - stride_qh, - stride_qm, - stride_qd, - stride_kpb, - stride_kph, - stride_kpn, - stride_kpd, - stride_knb, - stride_knh, - stride_knn, - stride_vpb, - stride_vph, - stride_vpn, - stride_vpd, - stride_vnb, - stride_vnh, - stride_vnn, - stride_ob, - stride_oh, - stride_om, - stride_od, - stride_mb, - stride_mq, - stride_mk, - sm_scale: tl.float32, - HAS_MASK: tl.constexpr, - IS_CAUSAL: tl.constexpr, - HAS_KV_LEN: tl.constexpr, - MASK_IS_CAUSAL: tl.constexpr, - HEAD_DIM: tl.constexpr, - HALF_D: tl.constexpr, - NUM_GROUPS: tl.constexpr, - 
PACK_GQA: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - _tq4_sdpa_fwd_kernel_body( - Q_ptr, - KP_ptr, - KN_ptr, - VP_ptr, - VN_ptr, - LUT_hi_ptr, - LUT_lo_ptr, - Mask_ptr, - O_ptr, - KV_LEN_ptr, - B, - H_grid, - Lq, - Lk, - stride_qb, - stride_qh, - stride_qm, - stride_qd, - stride_kpb, - stride_kph, - stride_kpn, - stride_kpd, - stride_knb, - stride_knh, - stride_knn, - stride_vpb, - stride_vph, - stride_vpn, - stride_vpd, - stride_vnb, - stride_vnh, - stride_vnn, - stride_ob, - stride_oh, - stride_om, - stride_od, - stride_mb, - stride_mq, - stride_mk, - sm_scale, - HAS_MASK=HAS_MASK, - IS_CAUSAL=IS_CAUSAL, - HAS_KV_LEN=HAS_KV_LEN, - MASK_IS_CAUSAL=MASK_IS_CAUSAL, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - HEAD_DIM=HEAD_DIM, - HALF_D=HALF_D, - NUM_GROUPS=NUM_GROUPS, - PACK_GQA=PACK_GQA, - ) - - -@triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=3), triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 16}, num_warps=4, num_stages=3), + # Extra BLOCK_N in {32,64} configs for smaller-SMEM GPUs (e.g. RTX 5090); + # correctness-safe (cos~1.0), never BLOCK_N=16 (numerically wrong). 
+ triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=8, num_stages=3), ], key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"], ) @triton.jit -def _tq4_sdpa_fwd_kernel_m32( +def _tq4_sdpa_prefill_kernel( Q_ptr, KP_ptr, KN_ptr, @@ -570,15 +457,7 @@ def _launch_tq4_kernel( def grid(meta): return (triton.cdiv(Lq_packed, meta["BLOCK_M"]), B * H_grid) - total_ctas_m64 = ((Lq_packed + 63) // 64) * (B * H_grid) - threshold = 4 * 84 - kernel = ( - _tq4_sdpa_fwd_kernel_m32 - if total_ctas_m64 < threshold - else _tq4_sdpa_fwd_kernel_m64 - ) - - wrap_triton(kernel)[grid]( + wrap_triton(_tq4_sdpa_prefill_kernel)[grid]( q_rot, k_packed, k_norms, @@ -845,6 +724,19 @@ def tq4_sdpa( pack_gqa, ) else: + # Prefill path (N_Q > 1, plus the rare N_Q==1 && N_KV<256 fallthrough). + # When the caller guarantees a standard causal mask AND kv_len is known + # (MASK_IS_CAUSAL), use the kernel's analytic absolute causal-offset and + # skip loading the materialized mask — numerically identical, no mask HBM + # traffic. Causal is then applied via IS_CAUSAL (which also drives the + # per-tile loop-end clamp), so MASK_IS_CAUSAL is passed False to the + # launcher. Otherwise honor the explicit mask / is_causal as-is. 
+ if MASK_IS_CAUSAL: + prefill_has_mask = False + prefill_is_causal = True + else: + prefill_has_mask = HAS_MASK + prefill_is_causal = is_causal _launch_tq4_kernel( q_rot, k_packed, @@ -863,13 +755,13 @@ def tq4_sdpa( N_KV, D, sm_scale, - HAS_MASK, + prefill_has_mask, HAS_KV_LEN, - MASK_IS_CAUSAL, + False, stride_mb, stride_mq, stride_mk, - is_causal, + prefill_is_causal, num_groups, pack_gqa, ) @@ -889,17 +781,14 @@ def tq4_sdpa( @triton.autotune( configs=[ - triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1), - triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1), - triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1), - triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1), - triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1), - triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=3), triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=2), - triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=2), + # Extra BLOCK_N in {32,64} configs for smaller-SMEM GPUs (e.g. RTX 5090); + # correctness-safe (cos~1.0), never BLOCK_N=16 (numerically wrong). 
+ triton.Config({"BLOCK_N": 32}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=3), ], key=["Lk", "HEAD_DIM", "NUM_GROUPS", "HAS_MASK", "PACK_GQA"], ) diff --git a/examples/models/gemma4_31b/cuda_source_transformations.py b/examples/models/gemma4_31b/cuda_source_transformations.py index 635161390d7..132ddb33f1d 100644 --- a/examples/models/gemma4_31b/cuda_source_transformations.py +++ b/examples/models/gemma4_31b/cuda_source_transformations.py @@ -52,6 +52,9 @@ def _turboquant_attention_forward( Mirrors the default forward up to (and including) RoPE; only the cache update and SDPA call differ. + + NOTE: ``attn_mask`` is unused here and will be reconstructed in + the kernel to save data transfer, but is passed to the default forward """ B, T, _ = x.shape @@ -94,9 +97,6 @@ def _turboquant_attention_forward( # step (catastrophic at 128k: ~2.7 tok/s decode vs ~37+ when bounded). kv_len = input_pos[0] + input_pos.shape[0] - # ``scale=self.scaling`` (= 1.0 for Gemma 4) — overrides tq4_sdpa's - # default ``1/sqrt(D)`` because Gemma's QK-norm has absorbed the - # 1/sqrt(d) factor into trained weights. y = torch.ops.triton.tq4_sdpa( q, k_packed, @@ -105,8 +105,8 @@ def _turboquant_attention_forward( v_norms, self.kv_cache.centroids, self.kv_cache.rotation, - attn_mask, - False, # is_causal: attn_mask already encodes causal masking + None, # reconstruct attention mask in the kernel to save data transfer + False, # is_causal: needs L_q==L_kv; causal comes from mask_is_causal self.scaling, kv_len, True, # mask_is_causal: Gemma full-attention mask is standard causal