Skip to content
191 changes: 40 additions & 151 deletions backends/cuda/triton/kernels/tq4_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _tq4_sdpa_fwd_kernel_body(
# causal mask); otherwise the full kv_len bound is kept, which is safe for an
# arbitrary mask.
loop_end = kv_len
if MASK_IS_CAUSAL:
if MASK_IS_CAUSAL or IS_CAUSAL:
max_q_pos = (kv_len - Lq) + tl.max(seq_pos)
loop_end = tl.minimum(kv_len, max_q_pos + 1)

Expand Down Expand Up @@ -227,7 +227,12 @@ def _tq4_sdpa_fwd_kernel_body(
qk = tl.where(mask_block, qk, float("-inf"))

if IS_CAUSAL:
causal = offs_n[None, :] > seq_pos[:, None]
# Absolute causal-offset: a query row's KV position is
# (kv_len - Lq) + seq_pos, correct for chunked prefill (Lq < kv_len).
# For the square is_causal case (kv_len == Lq) it reduces to
# offs_n > seq_pos. This lets a caller that guarantees a standard
# causal mask skip the materialized mask read entirely.
causal = offs_n[None, :] > (kv_len - Lq) + seq_pos[:, None]
qk = tl.where(causal, float("-inf"), qk)

qk = tl.where(kv_valid[None, :], qk, float("-inf"))
Expand Down Expand Up @@ -283,143 +288,25 @@ def _tq4_sdpa_fwd_kernel_body(


# ---------------------------------------------------------------------------
# Autotuned kernel wrappers (M64 and M32)
# Autotuned prefill kernel (single, no-spill)
# ---------------------------------------------------------------------------


@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=3),
],
key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"],
)
@triton.jit
def _tq4_sdpa_fwd_kernel_m64(
Q_ptr,
KP_ptr,
KN_ptr,
VP_ptr,
VN_ptr,
LUT_hi_ptr,
LUT_lo_ptr,
Mask_ptr,
O_ptr,
KV_LEN_ptr,
B,
H_grid,
Lq,
Lk,
stride_qb,
stride_qh,
stride_qm,
stride_qd,
stride_kpb,
stride_kph,
stride_kpn,
stride_kpd,
stride_knb,
stride_knh,
stride_knn,
stride_vpb,
stride_vph,
stride_vpn,
stride_vpd,
stride_vnb,
stride_vnh,
stride_vnn,
stride_ob,
stride_oh,
stride_om,
stride_od,
stride_mb,
stride_mq,
stride_mk,
sm_scale: tl.float32,
HAS_MASK: tl.constexpr,
IS_CAUSAL: tl.constexpr,
HAS_KV_LEN: tl.constexpr,
MASK_IS_CAUSAL: tl.constexpr,
HEAD_DIM: tl.constexpr,
HALF_D: tl.constexpr,
NUM_GROUPS: tl.constexpr,
PACK_GQA: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
_tq4_sdpa_fwd_kernel_body(
Q_ptr,
KP_ptr,
KN_ptr,
VP_ptr,
VN_ptr,
LUT_hi_ptr,
LUT_lo_ptr,
Mask_ptr,
O_ptr,
KV_LEN_ptr,
B,
H_grid,
Lq,
Lk,
stride_qb,
stride_qh,
stride_qm,
stride_qd,
stride_kpb,
stride_kph,
stride_kpn,
stride_kpd,
stride_knb,
stride_knh,
stride_knn,
stride_vpb,
stride_vph,
stride_vpn,
stride_vpd,
stride_vnb,
stride_vnh,
stride_vnn,
stride_ob,
stride_oh,
stride_om,
stride_od,
stride_mb,
stride_mq,
stride_mk,
sm_scale,
HAS_MASK=HAS_MASK,
IS_CAUSAL=IS_CAUSAL,
HAS_KV_LEN=HAS_KV_LEN,
MASK_IS_CAUSAL=MASK_IS_CAUSAL,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
HEAD_DIM=HEAD_DIM,
HALF_D=HALF_D,
NUM_GROUPS=NUM_GROUPS,
PACK_GQA=PACK_GQA,
)


@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 16}, num_warps=4, num_stages=3),
# Extra BLOCK_N in {32,64} configs for smaller-SMEM GPUs (e.g. RTX 5090);
# correctness-safe (cos~1.0), never BLOCK_N=16 (numerically wrong).
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=4),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=8, num_stages=3),
],
key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"],
)
@triton.jit
def _tq4_sdpa_fwd_kernel_m32(
def _tq4_sdpa_prefill_kernel(
Q_ptr,
KP_ptr,
KN_ptr,
Expand Down Expand Up @@ -570,15 +457,7 @@ def _launch_tq4_kernel(
def grid(meta):
return (triton.cdiv(Lq_packed, meta["BLOCK_M"]), B * H_grid)

total_ctas_m64 = ((Lq_packed + 63) // 64) * (B * H_grid)
threshold = 4 * 84
kernel = (
_tq4_sdpa_fwd_kernel_m32
if total_ctas_m64 < threshold
else _tq4_sdpa_fwd_kernel_m64
)

wrap_triton(kernel)[grid](
wrap_triton(_tq4_sdpa_prefill_kernel)[grid](
q_rot,
k_packed,
k_norms,
Expand Down Expand Up @@ -845,6 +724,19 @@ def tq4_sdpa(
pack_gqa,
)
else:
# Prefill path (N_Q > 1, plus the rare N_Q==1 && N_KV<256 fallthrough).
# When the caller guarantees a standard causal mask AND kv_len is known
# (MASK_IS_CAUSAL), use the kernel's analytic absolute causal-offset and
# skip loading the materialized mask — numerically identical, no mask HBM
# traffic. Causal is then applied via IS_CAUSAL (which also drives the
# per-tile loop-end clamp), so MASK_IS_CAUSAL is passed False to the
# launcher. Otherwise honor the explicit mask / is_causal as-is.
if MASK_IS_CAUSAL:
prefill_has_mask = False
prefill_is_causal = True
else:
prefill_has_mask = HAS_MASK
prefill_is_causal = is_causal
_launch_tq4_kernel(
q_rot,
k_packed,
Expand All @@ -863,13 +755,13 @@ def tq4_sdpa(
N_KV,
D,
sm_scale,
HAS_MASK,
prefill_has_mask,
HAS_KV_LEN,
MASK_IS_CAUSAL,
False,
stride_mb,
stride_mq,
stride_mk,
is_causal,
prefill_is_causal,
num_groups,
pack_gqa,
)
Expand All @@ -889,17 +781,14 @@ def tq4_sdpa(

@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=3),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=3),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=2),
# Extra BLOCK_N in {32,64} configs for smaller-SMEM GPUs (e.g. RTX 5090);
# correctness-safe (cos~1.0), never BLOCK_N=16 (numerically wrong).
triton.Config({"BLOCK_N": 32}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=3),
],
key=["Lk", "HEAD_DIM", "NUM_GROUPS", "HAS_MASK", "PACK_GQA"],
)
Expand Down
10 changes: 5 additions & 5 deletions examples/models/gemma4_31b/cuda_source_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def _turboquant_attention_forward(

Mirrors the default forward up to (and including) RoPE; only the
cache update and SDPA call differ.

NOTE: ``attn_mask`` is unused here and will be reconstucted in
the kernel to save data transfer, but is passed to the default forward
"""
B, T, _ = x.shape

Expand Down Expand Up @@ -94,9 +97,6 @@ def _turboquant_attention_forward(
# step (catastrophic at 128k: ~2.7 tok/s decode vs ~37+ when bounded).
kv_len = input_pos[0] + input_pos.shape[0]

# ``scale=self.scaling`` (= 1.0 for Gemma 4) — overrides tq4_sdpa's
# default ``1/sqrt(D)`` because Gemma's QK-norm has absorbed the
# 1/sqrt(d) factor into trained weights.
y = torch.ops.triton.tq4_sdpa(
q,
k_packed,
Expand All @@ -105,8 +105,8 @@ def _turboquant_attention_forward(
v_norms,
self.kv_cache.centroids,
self.kv_cache.rotation,
attn_mask,
False, # is_causal: attn_mask already encodes causal masking
None, # reconstuct attention mask in the kernel to save data transfer
False, # is_causal: needs L_q==L_kv; causal comes from mask_is_causal
self.scaling,
kv_len,
True, # mask_is_causal: Gemma full-attention mask is standard causal
Expand Down
Loading