Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 33 additions & 57 deletions backends/cuda/triton/kernels/sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,6 @@ def _sdpa_decode_splitk_kernel(
K_ptr,
V_ptr,
O_partial_ptr,
M_partial_ptr,
L_partial_ptr,
Mask_ptr,
B,
Expand All @@ -1102,13 +1101,14 @@ def _sdpa_decode_splitk_kernel(
stride_op_b,
stride_op_h,
stride_op_d,
stride_mp_s,
stride_mp_b,
stride_mp_h,
stride_lp_s,
stride_lp_b,
stride_lp_h,
stride_mb,
stride_mq,
stride_mk,
sm_scale: tl.float32,
phi: tl.float32,
chunk_size,
HAS_MASK: tl.constexpr,
BLOCK_N: tl.constexpr,
Expand Down Expand Up @@ -1138,7 +1138,7 @@ def _sdpa_decode_splitk_kernel(
)
q = tl.load(q_ptrs, mask=g_valid[:, None], other=0.0).to(tl.bfloat16)

m_i = tl.full([BLOCK_G], -float("inf"), dtype=tl.float32)
# FlashDecoding++ async softmax: use unified max phi instead of tracking m_i
l_i = tl.zeros([BLOCK_G], dtype=tl.float32)
acc = tl.zeros([BLOCK_G, HEAD_DIM], dtype=tl.float32)

Expand Down Expand Up @@ -1175,15 +1175,10 @@ def _sdpa_decode_splitk_kernel(
mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32)
)

# Online softmax update
m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32))
safe_diff = tl.where(
m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf")
)
# FlashDecoding++ async softmax: subtract unified phi instead of local max
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@digantdesai here we replace the online softmax with async softmax by using a unified phi.

safe_diff = tl.where(qk > -float("inf"), qk - phi, -float("inf"))
p_f32 = tl.exp(safe_diff).to(tl.float32)
l_ij = tl.sum(p_f32, axis=1).to(tl.float32)
safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0)
alpha = tl.exp(safe_alpha_diff).to(tl.float32)

v_ptrs = V_ptr + (
b * stride_vb
Expand All @@ -1194,9 +1189,8 @@ def _sdpa_decode_splitk_kernel(
v = tl.load(v_ptrs, mask=n_valid[:, None], other=0.0).to(tl.bfloat16)

p_bf16 = p_f32.to(tl.bfloat16)
acc = (acc * alpha[:, None] + tl.dot(p_bf16, v)).to(tl.float32)
l_i = (l_i * alpha + l_ij).to(tl.float32)
m_i = m_ij
acc = (acc + tl.dot(p_bf16, v)).to(tl.float32)
l_i = (l_i + l_ij).to(tl.float32)

# Store partial results for valid groups only
h_q_all = h_kv * NUM_GROUPS + offs_g # [BLOCK_G]
Expand All @@ -1208,31 +1202,25 @@ def _sdpa_decode_splitk_kernel(
)
tl.store(o_ptrs, acc, mask=g_valid[:, None])

ml_ptrs = M_partial_ptr + (
split_id * stride_mp_s + b * stride_mp_b + h_q_all * stride_mp_h
)
tl.store(ml_ptrs, m_i, mask=g_valid)

ll_ptrs = L_partial_ptr + (
split_id * stride_mp_s + b * stride_mp_b + h_q_all * stride_mp_h
split_id * stride_lp_s + b * stride_lp_b + h_q_all * stride_lp_h
)
tl.store(ll_ptrs, l_i, mask=g_valid)


@triton.jit
def _sdpa_decode_reduce_kernel(
O_partial_ptr,
M_partial_ptr,
L_partial_ptr,
O_ptr,
num_splits,
stride_op_s,
stride_op_b,
stride_op_h,
stride_op_d,
stride_mp_s,
stride_mp_b,
stride_mp_h,
stride_lp_s,
stride_lp_b,
stride_lp_h,
stride_ob,
stride_oh,
stride_om,
Expand All @@ -1242,40 +1230,25 @@ def _sdpa_decode_reduce_kernel(
pid = tl.program_id(axis=0)
offs_d = tl.arange(0, HEAD_DIM)

# pid indexes into flattened (B, H_q). Partial buffers are allocated
# contiguous in _launch_decode_splitk, so pid * stride_*_h is valid.
# Find global max across all splits
m_global = tl.full([1], -float("inf"), dtype=tl.float32)
for s in tl.range(0, num_splits):
m_ptr = M_partial_ptr + s * stride_mp_s + pid * stride_mp_h
m_s = tl.load(m_ptr)
m_global = tl.maximum(m_global, m_s)

# Accumulate rescaled outputs
# FlashDecoding++ async softmax: no rescaling needed, just sum partials
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
l_global = tl.zeros([1], dtype=tl.float32)

for s in tl.range(0, num_splits):
m_ptr = M_partial_ptr + s * stride_mp_s + pid * stride_mp_h
l_ptr = L_partial_ptr + s * stride_mp_s + pid * stride_mp_h
l_ptr = L_partial_ptr + s * stride_lp_s + pid * stride_lp_h
o_ptrs = O_partial_ptr + (
s * stride_op_s + pid * stride_op_h + offs_d * stride_op_d
)

m_s = tl.load(m_ptr)
l_s = tl.load(l_ptr)
o_s = tl.load(o_ptrs)

safe_diff = tl.where(m_global > -float("inf"), m_s - m_global, 0.0)
scale = tl.exp(safe_diff).to(tl.float32)
acc += o_s * scale
l_global += l_s * scale
acc += o_s
l_global += l_s

inv_l = tl.where(l_global > 0, 1.0 / l_global, 0.0)
acc = acc * inv_l

# pid = b*H_q + h_q. For contiguous output [B, H_q, 1, D] with L_q=1,
# stride_ob == H_q * stride_oh, so pid * stride_oh is correct.
# This relies on `out` being freshly allocated and contiguous.
o_out_ptrs = O_ptr + pid * stride_oh + offs_d * stride_od
tl.store(o_out_ptrs, acc.to(tl.bfloat16))

Expand All @@ -1297,16 +1270,14 @@ def _launch_decode_splitk(
stride_mq: int,
stride_mk: int,
num_groups: int,
phi: float,
) -> None:
num_splits = min(max(triton.cdiv(L_kv, 256), 1), 128)
chunk_size = triton.cdiv(L_kv, num_splits)

O_partial = torch.empty(
(num_splits, B, H_q, D), device=query.device, dtype=torch.float32
)
M_partial = torch.full(
(num_splits, B, H_q), -float("inf"), device=query.device, dtype=torch.float32
)
L_partial = torch.zeros(
(num_splits, B, H_q), device=query.device, dtype=torch.float32
)
Expand All @@ -1316,15 +1287,14 @@ def _launch_decode_splitk(
stride_vb, stride_vh, stride_vn, stride_vd = value.stride()
stride_ob, stride_oh, stride_om, stride_od = out.stride()
stride_op_s, stride_op_b, stride_op_h, stride_op_d = O_partial.stride()
stride_mp_s, stride_mp_b, stride_mp_h = M_partial.stride()
stride_lp_s, stride_lp_b, stride_lp_h = L_partial.stride()

grid_split = (num_splits, B * H_kv)
wrap_triton(_sdpa_decode_splitk_kernel)[grid_split](
query,
key,
value,
O_partial,
M_partial,
L_partial,
Mask_ptr if HAS_MASK else 0,
B,
Expand All @@ -1346,13 +1316,14 @@ def _launch_decode_splitk(
stride_op_b,
stride_op_h,
stride_op_d,
stride_mp_s,
stride_mp_b,
stride_mp_h,
stride_lp_s,
stride_lp_b,
stride_lp_h,
stride_mb,
stride_mq,
stride_mk,
sm_scale,
phi,
chunk_size,
HAS_MASK=HAS_MASK,
HEAD_DIM=D,
Expand All @@ -1363,17 +1334,16 @@ def _launch_decode_splitk(
grid_reduce = (B * H_q,)
wrap_triton(_sdpa_decode_reduce_kernel)[grid_reduce](
O_partial,
M_partial,
L_partial,
out,
num_splits,
stride_op_s,
stride_op_b,
stride_op_h,
stride_op_d,
stride_mp_s,
stride_mp_b,
stride_mp_h,
stride_lp_s,
stride_lp_b,
stride_lp_h,
stride_ob,
stride_oh,
stride_om,
Expand All @@ -1394,9 +1364,13 @@ def sdpa_decode_splitk(
is_causal: bool = False,
scale: float = 0.0,
enable_gqa: bool = False,
phi: float = 5.0,
) -> torch.Tensor:
"""Split-K flash-decoding SDPA for L_q=1 (decode step).

Uses FlashDecoding++ async softmax with unified maximum value (phi)
to eliminate per-split max tracking and cross-split rescaling.

Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
enable_gqa is accepted but ignored — GQA is handled natively via
H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
Expand Down Expand Up @@ -1452,6 +1426,7 @@ def sdpa_decode_splitk(
stride_mq,
stride_mk,
num_groups,
phi,
)
return out

Expand All @@ -1466,6 +1441,7 @@ def _sdpa_decode_splitk_abstract(
is_causal: bool = False,
scale: float = 0.0,
enable_gqa: bool = False,
phi: float = 5.0,
) -> torch.Tensor:
assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype"
B, H_q, L_q, D = query.shape
Expand Down
Loading