Skip to content

Commit 1014985

Browse files
Gasoonjia and gasoonjia authored
Implement FlashDecoding++ async softmax for split-K SDPA (#18867)
Replace online softmax (per-tile max tracking + cross-split rescaling) with a unified maximum value (phi=5.0) approach from FlashDecoding++. Key changes: - Split kernel: subtract fixed phi instead of tracking running max m_i, eliminating alpha rescaling between tiles - Reduce kernel: simple summation of partial outputs instead of max-aware weighted combination; removes M_partial buffer - ~12.9% average kernel-level speedup (6.8%-20.1% range) by saving HBM bandwidth (no M_partial reads/writes) and reducing ALU ops The unified phi works because exp(qk - phi) is numerically stable for typical attention score ranges, and the fixed constant allows all splits to compute independently without synchronization. Also used KernelAgent(https://github.com/meta-pytorch/KernelAgent) to further optimized the kernel. Config | absolute perf (gain) -- | -- p=128 d=128 | 149.9 (+9.7) p=128 d=512 | 153.7 (+10.3) p=256 d=128 | 149.6 (+9.5) p=256 d=512 | 153.5 (+11.6) p=512 d=128 | 149.1 (+9.1) p=512 d=512 | 152.9 (+10.1) p=1024 d=128 | 148.6 (+8.9) p=1024 d=512 | 153.4 (+10.2) p=2048 d=128 | 148.3 (+9.6) p=2048 d=512 | 153.1 (+10.2) Average | 151.2 (+9.9) cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell --------- Co-authored-by: gasoonjia <gasoonjia@fb.com>
1 parent c1731fd commit 1014985

1 file changed

Lines changed: 33 additions & 57 deletions

File tree

  • backends/cuda/triton/kernels

backends/cuda/triton/kernels/sdpa.py

Lines changed: 33 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,7 +1080,6 @@ def _sdpa_decode_splitk_kernel(
10801080
K_ptr,
10811081
V_ptr,
10821082
O_partial_ptr,
1083-
M_partial_ptr,
10841083
L_partial_ptr,
10851084
Mask_ptr,
10861085
B,
@@ -1102,13 +1101,14 @@ def _sdpa_decode_splitk_kernel(
11021101
stride_op_b,
11031102
stride_op_h,
11041103
stride_op_d,
1105-
stride_mp_s,
1106-
stride_mp_b,
1107-
stride_mp_h,
1104+
stride_lp_s,
1105+
stride_lp_b,
1106+
stride_lp_h,
11081107
stride_mb,
11091108
stride_mq,
11101109
stride_mk,
11111110
sm_scale: tl.float32,
1111+
phi: tl.float32,
11121112
chunk_size,
11131113
HAS_MASK: tl.constexpr,
11141114
BLOCK_N: tl.constexpr,
@@ -1138,7 +1138,7 @@ def _sdpa_decode_splitk_kernel(
11381138
)
11391139
q = tl.load(q_ptrs, mask=g_valid[:, None], other=0.0).to(tl.bfloat16)
11401140

1141-
m_i = tl.full([BLOCK_G], -float("inf"), dtype=tl.float32)
1141+
# FlashDecoding++ async softmax: use unified max phi instead of tracking m_i
11421142
l_i = tl.zeros([BLOCK_G], dtype=tl.float32)
11431143
acc = tl.zeros([BLOCK_G, HEAD_DIM], dtype=tl.float32)
11441144

@@ -1175,15 +1175,10 @@ def _sdpa_decode_splitk_kernel(
11751175
mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32)
11761176
)
11771177

1178-
# Online softmax update
1179-
m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32))
1180-
safe_diff = tl.where(
1181-
m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf")
1182-
)
1178+
# FlashDecoding++ async softmax: subtract unified phi instead of local max
1179+
safe_diff = tl.where(qk > -float("inf"), qk - phi, -float("inf"))
11831180
p_f32 = tl.exp(safe_diff).to(tl.float32)
11841181
l_ij = tl.sum(p_f32, axis=1).to(tl.float32)
1185-
safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0)
1186-
alpha = tl.exp(safe_alpha_diff).to(tl.float32)
11871182

11881183
v_ptrs = V_ptr + (
11891184
b * stride_vb
@@ -1194,9 +1189,8 @@ def _sdpa_decode_splitk_kernel(
11941189
v = tl.load(v_ptrs, mask=n_valid[:, None], other=0.0).to(tl.bfloat16)
11951190

11961191
p_bf16 = p_f32.to(tl.bfloat16)
1197-
acc = (acc * alpha[:, None] + tl.dot(p_bf16, v)).to(tl.float32)
1198-
l_i = (l_i * alpha + l_ij).to(tl.float32)
1199-
m_i = m_ij
1192+
acc = (acc + tl.dot(p_bf16, v)).to(tl.float32)
1193+
l_i = (l_i + l_ij).to(tl.float32)
12001194

12011195
# Store partial results for valid groups only
12021196
h_q_all = h_kv * NUM_GROUPS + offs_g # [BLOCK_G]
@@ -1208,31 +1202,25 @@ def _sdpa_decode_splitk_kernel(
12081202
)
12091203
tl.store(o_ptrs, acc, mask=g_valid[:, None])
12101204

1211-
ml_ptrs = M_partial_ptr + (
1212-
split_id * stride_mp_s + b * stride_mp_b + h_q_all * stride_mp_h
1213-
)
1214-
tl.store(ml_ptrs, m_i, mask=g_valid)
1215-
12161205
ll_ptrs = L_partial_ptr + (
1217-
split_id * stride_mp_s + b * stride_mp_b + h_q_all * stride_mp_h
1206+
split_id * stride_lp_s + b * stride_lp_b + h_q_all * stride_lp_h
12181207
)
12191208
tl.store(ll_ptrs, l_i, mask=g_valid)
12201209

12211210

12221211
@triton.jit
12231212
def _sdpa_decode_reduce_kernel(
12241213
O_partial_ptr,
1225-
M_partial_ptr,
12261214
L_partial_ptr,
12271215
O_ptr,
12281216
num_splits,
12291217
stride_op_s,
12301218
stride_op_b,
12311219
stride_op_h,
12321220
stride_op_d,
1233-
stride_mp_s,
1234-
stride_mp_b,
1235-
stride_mp_h,
1221+
stride_lp_s,
1222+
stride_lp_b,
1223+
stride_lp_h,
12361224
stride_ob,
12371225
stride_oh,
12381226
stride_om,
@@ -1242,40 +1230,25 @@ def _sdpa_decode_reduce_kernel(
12421230
pid = tl.program_id(axis=0)
12431231
offs_d = tl.arange(0, HEAD_DIM)
12441232

1245-
# pid indexes into flattened (B, H_q). Partial buffers are allocated
1246-
# contiguous in _launch_decode_splitk, so pid * stride_*_h is valid.
1247-
# Find global max across all splits
1248-
m_global = tl.full([1], -float("inf"), dtype=tl.float32)
1249-
for s in tl.range(0, num_splits):
1250-
m_ptr = M_partial_ptr + s * stride_mp_s + pid * stride_mp_h
1251-
m_s = tl.load(m_ptr)
1252-
m_global = tl.maximum(m_global, m_s)
1253-
1254-
# Accumulate rescaled outputs
1233+
# FlashDecoding++ async softmax: no rescaling needed, just sum partials
12551234
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
12561235
l_global = tl.zeros([1], dtype=tl.float32)
1236+
12571237
for s in tl.range(0, num_splits):
1258-
m_ptr = M_partial_ptr + s * stride_mp_s + pid * stride_mp_h
1259-
l_ptr = L_partial_ptr + s * stride_mp_s + pid * stride_mp_h
1238+
l_ptr = L_partial_ptr + s * stride_lp_s + pid * stride_lp_h
12601239
o_ptrs = O_partial_ptr + (
12611240
s * stride_op_s + pid * stride_op_h + offs_d * stride_op_d
12621241
)
12631242

1264-
m_s = tl.load(m_ptr)
12651243
l_s = tl.load(l_ptr)
12661244
o_s = tl.load(o_ptrs)
12671245

1268-
safe_diff = tl.where(m_global > -float("inf"), m_s - m_global, 0.0)
1269-
scale = tl.exp(safe_diff).to(tl.float32)
1270-
acc += o_s * scale
1271-
l_global += l_s * scale
1246+
acc += o_s
1247+
l_global += l_s
12721248

12731249
inv_l = tl.where(l_global > 0, 1.0 / l_global, 0.0)
12741250
acc = acc * inv_l
12751251

1276-
# pid = b*H_q + h_q. For contiguous output [B, H_q, 1, D] with L_q=1,
1277-
# stride_ob == H_q * stride_oh, so pid * stride_oh is correct.
1278-
# This relies on `out` being freshly allocated and contiguous.
12791252
o_out_ptrs = O_ptr + pid * stride_oh + offs_d * stride_od
12801253
tl.store(o_out_ptrs, acc.to(tl.bfloat16))
12811254

@@ -1297,16 +1270,14 @@ def _launch_decode_splitk(
12971270
stride_mq: int,
12981271
stride_mk: int,
12991272
num_groups: int,
1273+
phi: float,
13001274
) -> None:
13011275
num_splits = min(max(triton.cdiv(L_kv, 256), 1), 128)
13021276
chunk_size = triton.cdiv(L_kv, num_splits)
13031277

13041278
O_partial = torch.empty(
13051279
(num_splits, B, H_q, D), device=query.device, dtype=torch.float32
13061280
)
1307-
M_partial = torch.full(
1308-
(num_splits, B, H_q), -float("inf"), device=query.device, dtype=torch.float32
1309-
)
13101281
L_partial = torch.zeros(
13111282
(num_splits, B, H_q), device=query.device, dtype=torch.float32
13121283
)
@@ -1316,15 +1287,14 @@ def _launch_decode_splitk(
13161287
stride_vb, stride_vh, stride_vn, stride_vd = value.stride()
13171288
stride_ob, stride_oh, stride_om, stride_od = out.stride()
13181289
stride_op_s, stride_op_b, stride_op_h, stride_op_d = O_partial.stride()
1319-
stride_mp_s, stride_mp_b, stride_mp_h = M_partial.stride()
1290+
stride_lp_s, stride_lp_b, stride_lp_h = L_partial.stride()
13201291

13211292
grid_split = (num_splits, B * H_kv)
13221293
wrap_triton(_sdpa_decode_splitk_kernel)[grid_split](
13231294
query,
13241295
key,
13251296
value,
13261297
O_partial,
1327-
M_partial,
13281298
L_partial,
13291299
Mask_ptr if HAS_MASK else 0,
13301300
B,
@@ -1346,13 +1316,14 @@ def _launch_decode_splitk(
13461316
stride_op_b,
13471317
stride_op_h,
13481318
stride_op_d,
1349-
stride_mp_s,
1350-
stride_mp_b,
1351-
stride_mp_h,
1319+
stride_lp_s,
1320+
stride_lp_b,
1321+
stride_lp_h,
13521322
stride_mb,
13531323
stride_mq,
13541324
stride_mk,
13551325
sm_scale,
1326+
phi,
13561327
chunk_size,
13571328
HAS_MASK=HAS_MASK,
13581329
HEAD_DIM=D,
@@ -1363,17 +1334,16 @@ def _launch_decode_splitk(
13631334
grid_reduce = (B * H_q,)
13641335
wrap_triton(_sdpa_decode_reduce_kernel)[grid_reduce](
13651336
O_partial,
1366-
M_partial,
13671337
L_partial,
13681338
out,
13691339
num_splits,
13701340
stride_op_s,
13711341
stride_op_b,
13721342
stride_op_h,
13731343
stride_op_d,
1374-
stride_mp_s,
1375-
stride_mp_b,
1376-
stride_mp_h,
1344+
stride_lp_s,
1345+
stride_lp_b,
1346+
stride_lp_h,
13771347
stride_ob,
13781348
stride_oh,
13791349
stride_om,
@@ -1394,9 +1364,13 @@ def sdpa_decode_splitk(
13941364
is_causal: bool = False,
13951365
scale: float = 0.0,
13961366
enable_gqa: bool = False,
1367+
phi: float = 5.0,
13971368
) -> torch.Tensor:
13981369
"""Split-K flash-decoding SDPA for L_q=1 (decode step).
13991370
1371+
Uses FlashDecoding++ async softmax with unified maximum value (phi)
1372+
to eliminate per-split max tracking and cross-split rescaling.
1373+
14001374
Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
14011375
enable_gqa is accepted but ignored — GQA is handled natively via
14021376
H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
@@ -1452,6 +1426,7 @@ def sdpa_decode_splitk(
14521426
stride_mq,
14531427
stride_mk,
14541428
num_groups,
1429+
phi,
14551430
)
14561431
return out
14571432

@@ -1466,6 +1441,7 @@ def _sdpa_decode_splitk_abstract(
14661441
is_causal: bool = False,
14671442
scale: float = 0.0,
14681443
enable_gqa: bool = False,
1444+
phi: float = 5.0,
14691445
) -> torch.Tensor:
14701446
assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype"
14711447
B, H_q, L_q, D = query.shape

0 commit comments

Comments (0)