Skip to content

Commit 39589ae

Browse files
committed
finetuned using KA
1 parent c93f8ae commit 39589ae

1 file changed

Lines changed: 32 additions & 55 deletions

File tree

  • backends/cuda/triton/kernels

backends/cuda/triton/kernels/sdpa.py

Lines changed: 32 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -97,26 +97,6 @@ def round_up(a, b):
9797
return nopack_eff < 0.9 * pack_eff
9898

9999

100-
def _compute_num_splits(L_kv: int, B: int, H_kv: int, device: torch.device) -> int:
101-
"""Compute optimal KV-split count for flash-decoding on A100 / RTX 4090.
102-
103-
Balances GPU occupancy against per-split work:
104-
* Targets >= 2 full SM waves (2 x SM-count CTAs) so the GPU stays
105-
saturated even with tail effects.
106-
* Enforces a minimum of 64 KV tokens per split to amortise
107-
kernel-launch and reduce overhead.
108-
* Caps at 128 splits to bound reduce-kernel cost.
109-
110-
A100 -> 108 SMs, RTX 4090 -> 128 SMs. The heuristic adapts to
111-
whatever GPU is present via ``torch.cuda.get_device_properties``.
112-
"""
113-
sm_count = torch.cuda.get_device_properties(device).multi_processor_count
114-
ctas_per_split = max(B * H_kv, 1)
115-
target = max(triton.cdiv(sm_count * 2, ctas_per_split), 1)
116-
max_by_work = max(L_kv // 64, 1)
117-
return min(target, max_by_work, 128)
118-
119-
120100
def _validate_qkv_shapes(
121101
query: torch.Tensor,
122102
key: torch.Tensor,
@@ -1091,8 +1071,6 @@ def _sdpa_abstract(
10911071
triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=2),
10921072
triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=2),
10931073
triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=2),
1094-
triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=3),
1095-
triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=3),
10961074
],
10971075
key=["Lk", "HEAD_DIM", "NUM_GROUPS", "HAS_MASK"],
10981076
)
@@ -1129,24 +1107,19 @@ def _sdpa_decode_splitk_kernel(
11291107
stride_mb,
11301108
stride_mq,
11311109
stride_mk,
1132-
sm_scale_log2: tl.float32,
1133-
phi_log2: tl.float32,
1110+
sm_scale: tl.float32,
1111+
phi: tl.float32,
11341112
chunk_size,
11351113
HAS_MASK: tl.constexpr,
11361114
BLOCK_N: tl.constexpr,
11371115
HEAD_DIM: tl.constexpr,
11381116
NUM_GROUPS: tl.constexpr,
11391117
BLOCK_G: tl.constexpr,
1140-
BATCH_ONE: tl.constexpr,
11411118
):
11421119
split_id = tl.program_id(axis=0)
11431120
pid_bh = tl.program_id(axis=1)
1144-
if BATCH_ONE:
1145-
b = 0
1146-
h_kv = pid_bh
1147-
else:
1148-
b = pid_bh // H_kv
1149-
h_kv = pid_bh % H_kv
1121+
b = pid_bh // H_kv
1122+
h_kv = pid_bh % H_kv
11501123

11511124
start_n = split_id * chunk_size
11521125
end_n = tl.minimum(start_n + chunk_size, Lk)
@@ -1163,11 +1136,9 @@ def _sdpa_decode_splitk_kernel(
11631136
+ 0 * stride_qm
11641137
+ offs_d[None, :] * stride_qd
11651138
)
1166-
q = tl.load(q_ptrs, mask=g_valid[:, None], other=0.0)
1167-
# Pre-scale Q so the inner loop avoids a per-element multiply on [G,N] QK
1168-
q = (q.to(tl.float32) * sm_scale_log2).to(tl.bfloat16)
1139+
q = tl.load(q_ptrs, mask=g_valid[:, None], other=0.0).to(tl.bfloat16)
11691140

1170-
# FlashDecoding++ async softmax with exp2: all scores in log2 space
1141+
# FlashDecoding++ async softmax: use unified max phi instead of tracking m_i
11711142
l_i = tl.zeros([BLOCK_G], dtype=tl.float32)
11721143
acc = tl.zeros([BLOCK_G, HEAD_DIM], dtype=tl.float32)
11731144

@@ -1185,8 +1156,8 @@ def _sdpa_decode_splitk_kernel(
11851156
)
11861157
k = tl.load(k_ptrs, mask=n_valid[:, None], other=0.0).to(tl.bfloat16)
11871158

1188-
# QK: [BLOCK_G, BLOCK_N] — Q already scaled, result in log2 space
1189-
qk = tl.dot(q, tl.trans(k)).to(tl.float32)
1159+
# QK: [BLOCK_G, BLOCK_N]
1160+
qk = (tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale).to(tl.float32)
11901161

11911162
# Mask out-of-bounds KV positions
11921163
qk = tl.where(
@@ -1204,9 +1175,9 @@ def _sdpa_decode_splitk_kernel(
12041175
mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32)
12051176
)
12061177

1207-
# FlashDecoding++ async softmax: exp2 maps to single PTX ex2 instruction
1208-
safe_diff = tl.where(qk > -float("inf"), qk - phi_log2, -float("inf"))
1209-
p_f32 = tl.math.exp2(safe_diff).to(tl.float32)
1178+
# FlashDecoding++ async softmax: subtract unified phi instead of local max
1179+
safe_diff = tl.where(qk > -float("inf"), qk - phi, -float("inf"))
1180+
p_f32 = tl.exp(safe_diff).to(tl.float32)
12101181
l_ij = tl.sum(p_f32, axis=1).to(tl.float32)
12111182

12121183
v_ptrs = V_ptr + (
@@ -1263,7 +1234,7 @@ def _sdpa_decode_reduce_kernel(
12631234
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
12641235
l_global = tl.zeros([1], dtype=tl.float32)
12651236

1266-
for s in tl.range(0, num_splits, num_stages=2):
1237+
for s in tl.range(0, num_splits):
12671238
l_ptr = L_partial_ptr + s * stride_lp_s + pid * stride_lp_h
12681239
o_ptrs = O_partial_ptr + (
12691240
s * stride_op_s + pid * stride_op_h + offs_d * stride_op_d
@@ -1282,6 +1253,9 @@ def _sdpa_decode_reduce_kernel(
12821253
tl.store(o_out_ptrs, acc.to(tl.bfloat16))
12831254

12841255

1256+
_splitk_buf_cache: dict = {}
1257+
1258+
12851259
def _launch_decode_splitk(
12861260
query: torch.Tensor,
12871261
key: torch.Tensor,
@@ -1301,19 +1275,23 @@ def _launch_decode_splitk(
13011275
num_groups: int,
13021276
phi: float,
13031277
) -> None:
1304-
num_splits = _compute_num_splits(L_kv, B, H_kv, query.device)
1278+
num_splits = min(max(triton.cdiv(L_kv, 256), 1), 128)
13051279
chunk_size = triton.cdiv(L_kv, num_splits)
13061280

1307-
_LOG2E = 1.4426950408889634
1308-
sm_scale_log2 = sm_scale * _LOG2E
1309-
phi_log2 = phi * _LOG2E
1310-
1311-
O_partial = torch.empty(
1312-
(num_splits, B, H_q, D), device=query.device, dtype=torch.float32
1313-
)
1314-
L_partial = torch.zeros(
1315-
(num_splits, B, H_q), device=query.device, dtype=torch.float32
1316-
)
1281+
# Cache partial buffers to avoid CUDA allocator overhead per call.
1282+
# The split kernel fully writes every entry before the reduce kernel
1283+
# reads, so stale data from a previous call is harmless.
1284+
buf_key = (num_splits, B, H_q, D, query.device.index)
1285+
bufs = _splitk_buf_cache.get(buf_key)
1286+
if bufs is None:
1287+
bufs = (
1288+
torch.empty(
1289+
(num_splits, B, H_q, D), device=query.device, dtype=torch.float32
1290+
),
1291+
torch.empty((num_splits, B, H_q), device=query.device, dtype=torch.float32),
1292+
)
1293+
_splitk_buf_cache[buf_key] = bufs
1294+
O_partial, L_partial = bufs
13171295

13181296
stride_qb, stride_qh, stride_qm, stride_qd = query.stride()
13191297
stride_kb, stride_kh, stride_kn, stride_kd = key.stride()
@@ -1355,14 +1333,13 @@ def _launch_decode_splitk(
13551333
stride_mb,
13561334
stride_mq,
13571335
stride_mk,
1358-
sm_scale_log2,
1359-
phi_log2,
1336+
sm_scale,
1337+
phi,
13601338
chunk_size,
13611339
HAS_MASK=HAS_MASK,
13621340
HEAD_DIM=D,
13631341
NUM_GROUPS=num_groups,
13641342
BLOCK_G=_next_power_of_2_unclamped(num_groups),
1365-
BATCH_ONE=B == 1,
13661343
)
13671344

13681345
grid_reduce = (B * H_q,)

0 commit comments

Comments (0)