Skip to content

Commit ecce53a

Browse files
mla cuda kernels for GLM4.7
1 parent 76bb307 commit ecce53a

30 files changed

Lines changed: 2855 additions & 37 deletions

bench_mla_opt/diag_mask.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
"""Diagnostic: isolate the cost of (load-mask, where-mask, loop-peel) in the
2+
single-pass MLA kernel, in the benchmark regime (seqlen % BLOCK_N == 0, so the
3+
fully-unmasked path is numerically correct). Run under CUDA_VISIBLE_DEVICES.
4+
"""
5+
import statistics
6+
import torch
7+
import triton
8+
import triton.language as tl
9+
from bench_triton_mla import make_inputs, KV_DIM, KV_LORA
10+
11+
12+
@triton.jit
13+
def _diag_kernel(
14+
Q, K_Buffer, V_Buffer, sm_scale, Seqlens, Block_Table, O,
15+
stride_qbs, stride_qh, stride_buf_kbs, stride_buf_vbs,
16+
stride_obs, stride_oh, stride_bt_b,
17+
q_head_num: tl.constexpr, BLOCK_SIZE: tl.constexpr,
18+
BLOCK_DMODEL: tl.constexpr, BLOCK_DPE: tl.constexpr, BLOCK_DV: tl.constexpr,
19+
BLOCK_N: tl.constexpr, BLOCK_H: tl.constexpr,
20+
USE_LOAD_MASK: tl.constexpr, USE_WHERE: tl.constexpr, PEEL: tl.constexpr,
21+
WITH_TAIL: tl.constexpr,
22+
):
23+
cur_batch = tl.program_id(0)
24+
cur_head_id = tl.program_id(1)
25+
cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
26+
mask_h = cur_head < q_head_num
27+
28+
offs_d = tl.arange(0, BLOCK_DMODEL)
29+
offs_dv = tl.arange(0, BLOCK_DV)
30+
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
31+
seqlen = tl.load(Seqlens + cur_batch)
32+
33+
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
34+
q = tl.load(Q + offs_q, mask=mask_h[:, None], other=0.0)
35+
off_qpe = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
36+
qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0)
37+
38+
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
39+
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
40+
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
41+
bt_base = Block_Table + cur_batch * stride_bt_b
42+
43+
if PEEL:
44+
loop_end = (seqlen // BLOCK_N) * BLOCK_N
45+
else:
46+
loop_end = seqlen
47+
48+
for start_n in range(0, loop_end, BLOCK_N):
49+
offs_n = start_n + tl.arange(0, BLOCK_N)
50+
valid = offs_n < seqlen
51+
if USE_LOAD_MASK:
52+
page = tl.load(bt_base + (offs_n // BLOCK_SIZE), mask=valid, other=0)
53+
else:
54+
page = tl.load(bt_base + (offs_n // BLOCK_SIZE))
55+
kv_loc = page * BLOCK_SIZE + (offs_n % BLOCK_SIZE)
56+
if USE_LOAD_MASK:
57+
k = tl.load(K_Buffer + kv_loc[None, :] * stride_buf_kbs + offs_d[:, None],
58+
mask=valid[None, :], other=0.0)
59+
kpe = tl.load(K_Buffer + kv_loc[None, :] * stride_buf_kbs + offs_dpe[:, None],
60+
mask=valid[None, :], other=0.0)
61+
v = tl.load(V_Buffer + kv_loc[:, None] * stride_buf_vbs + offs_dv[None, :],
62+
mask=valid[:, None], other=0.0)
63+
else:
64+
k = tl.load(K_Buffer + kv_loc[None, :] * stride_buf_kbs + offs_d[:, None])
65+
kpe = tl.load(K_Buffer + kv_loc[None, :] * stride_buf_kbs + offs_dpe[:, None])
66+
v = tl.load(V_Buffer + kv_loc[:, None] * stride_buf_vbs + offs_dv[None, :])
67+
qk = tl.dot(q, k.to(q.dtype))
68+
qk += tl.dot(qpe, kpe.to(qpe.dtype))
69+
qk *= sm_scale
70+
if USE_WHERE:
71+
qk = tl.where(valid[None, :], qk, float("-inf"))
72+
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
73+
re_scale = tl.exp(e_max - n_e_max)
74+
p = tl.exp(qk - n_e_max[:, None])
75+
acc *= re_scale[:, None]
76+
acc += tl.dot(p.to(v.dtype), v)
77+
e_sum = e_sum * re_scale + tl.sum(p, 1)
78+
e_max = n_e_max
79+
80+
if WITH_TAIL:
81+
# separate duplicated masked tail block (mirrors _fwd_blocktable_mla_kernel_opt)
82+
if loop_end < seqlen:
83+
offs_n = loop_end + tl.arange(0, BLOCK_N)
84+
valid = offs_n < seqlen
85+
page = tl.load(bt_base + (offs_n // BLOCK_SIZE), mask=valid, other=0)
86+
kv_loc = page * BLOCK_SIZE + (offs_n % BLOCK_SIZE)
87+
k = tl.load(K_Buffer + kv_loc[None, :] * stride_buf_kbs + offs_d[:, None],
88+
mask=valid[None, :], other=0.0)
89+
qk = tl.dot(q, k.to(q.dtype))
90+
kpe = tl.load(K_Buffer + kv_loc[None, :] * stride_buf_kbs + offs_dpe[:, None],
91+
mask=valid[None, :], other=0.0)
92+
qk += tl.dot(qpe, kpe.to(qpe.dtype))
93+
qk *= sm_scale
94+
qk = tl.where(valid[None, :], qk, float("-inf"))
95+
v = tl.load(V_Buffer + kv_loc[:, None] * stride_buf_vbs + offs_dv[None, :],
96+
mask=valid[:, None], other=0.0)
97+
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
98+
re_scale = tl.exp(e_max - n_e_max)
99+
p = tl.exp(qk - n_e_max[:, None])
100+
acc *= re_scale[:, None]
101+
acc += tl.dot(p.to(v.dtype), v)
102+
e_sum = e_sum * re_scale + tl.sum(p, 1)
103+
e_max = n_e_max
104+
105+
offs_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_dv[None, :]
106+
tl.store(O + offs_o, acc / e_sum[:, None], mask=mask_h[:, None])
107+
108+
109+
def run(q, latent, bt, sl, sm, blk, o, BLOCK_H, lm, wh, peel, tail=False, num_warps=8):
110+
bs, H, _ = q.shape
111+
grid = (bs, triton.cdiv(H, BLOCK_H))
112+
_diag_kernel[grid](
113+
q, latent, latent, sm, sl, bt, o,
114+
q.stride(0), q.stride(1), latent.stride(0), latent.stride(0),
115+
o.stride(0), o.stride(1), bt.stride(0),
116+
q_head_num=H, BLOCK_SIZE=blk, BLOCK_DMODEL=512, BLOCK_DPE=64,
117+
BLOCK_DV=512, BLOCK_N=blk, BLOCK_H=BLOCK_H,
118+
USE_LOAD_MASK=lm, USE_WHERE=wh, PEEL=peel, WITH_TAIL=tail, num_warps=num_warps,
119+
)
120+
return o
121+
122+
123+
def bench(fn, q, latent, bt, sl, o, blk, iters=50, warmup=20):
124+
sm = 1.0 / (KV_DIM ** 0.5)
125+
for _ in range(warmup):
126+
fn(q, latent, bt, sl, sm, blk, o)
127+
torch.cuda.synchronize()
128+
s = torch.cuda.Event(enable_timing=True); e = torch.cuda.Event(enable_timing=True)
129+
s.record()
130+
for _ in range(iters):
131+
fn(q, latent, bt, sl, sm, blk, o)
132+
e.record(); torch.cuda.synchronize()
133+
return s.elapsed_time(e) / iters
134+
135+
136+
def gbps(q, latent, bt, sl, o, blk, BLOCK_H, lm, wh, peel, tail=False):
137+
bs, H, _ = q.shape
138+
tok = int(sl[0])
139+
best = []
140+
for _ in range(3):
141+
ms = bench(lambda *a: run(*a, BLOCK_H, lm, wh, peel, tail), q, latent, bt, sl, o, blk)
142+
best.append(bs * tok * KV_DIM * 2 / (ms * 1e-3) / 1e9)
143+
return statistics.median(best)
144+
145+
146+
if __name__ == "__main__":
147+
dev, dt = "cuda", torch.bfloat16
148+
bs = 128
149+
configs = [
150+
# name, load_mask, where, peel, with_tail
151+
("baseline single-loop lm+wh ", True, True, False, False),
152+
("unmasked single-loop ", False, False, False, False),
153+
("peel, NO tail block ", False, False, True, False),
154+
("peel + SEPARATE tail block ", False, False, True, True),
155+
]
156+
for H, BLOCK_H in [(16, 16), (20, 16), (20, 32)]:
157+
print(f"\n=== H={H} BLOCK_H={BLOCK_H} bs={bs} (median of 3 GB/s) ===")
158+
hdr = f"{'config':<30}" + "".join(f"{f'b{b}/t{t}':>11}" for b in (16, 64) for t in (1024, 4096))
159+
print(hdr)
160+
for name, lm, wh, peel, tail in configs:
161+
row = f"{name:<30}"
162+
for blk in (16, 64):
163+
for tok in (1024, 4096):
164+
q, latent, bt, sl, o = make_inputs(bs, H, blk, tok, dev, dt)
165+
g = gbps(q, latent, bt, sl, o, blk, BLOCK_H, lm, wh, peel, tail)
166+
row += f"{g:>11.0f}"
167+
print(row)

bench_triton_mla.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""Microbenchmark for the vortex block-table MLA decode kernel(s).
2+
3+
Pure kernel efficiency — NO sglang / RULER. Builds synthetic decode inputs
4+
(q, fused latent pool, sparse block_table, seqlens), times each kernel variant
5+
with CUDA events, and reports achieved HBM bandwidth.
6+
7+
Bandwidth model (decode is KV-read bound): the kernel must read, per request,
8+
its `selected_tokens` fused-latent rows (kv_lora_rank + qk_rope = 576 bf16).
9+
Pages are distinct + scattered across requests (no L2 reuse), so
10+
11+
KV_bytes = bs * selected_tokens * 576 * dtype_size
12+
achieved_BW = KV_bytes / kernel_time
13+
14+
is the meaningful "useful" HBM bandwidth (a perfect kernel reads each latent row
15+
once per request). B200 HBM3e peak ~= 8 TB/s.
16+
17+
export HF_HOME=/raid/catalyst/models/
18+
CUDA_VISIBLE_DEVICES=<gpu> python marks/mla/bench_triton_mla.py
19+
"""
20+
import argparse
21+
import json
22+
import torch
23+
24+
from vortex_torch.engine.sgl.attention_backend.triton_mla_kernel import (
25+
KERNELS, # {name: fn(q, latent, block_table, seqlens, sm_scale, block_size, kv_lora_rank, o)}
26+
)
27+
28+
KV_DIM = 576
29+
KV_LORA = 512
30+
PEAK_BW_GBPS = 8000.0 # B200 HBM3e ~8 TB/s
31+
32+
33+
def make_inputs(bs, num_heads, block_size, selected_tokens, device, dtype):
34+
assert selected_tokens % block_size == 0
35+
n_blocks = selected_tokens // block_size
36+
num_pages = bs * n_blocks # distinct page per (req, slot)
37+
latent = torch.randn(num_pages * block_size, KV_DIM, device=device, dtype=dtype)
38+
pages = torch.randperm(num_pages, device=device, dtype=torch.int32)
39+
block_table = pages.view(bs, n_blocks).contiguous()
40+
seqlens = torch.full((bs,), selected_tokens, device=device, dtype=torch.int32)
41+
q = torch.randn(bs, num_heads, KV_DIM, device=device, dtype=dtype)
42+
o = torch.empty(bs, num_heads, KV_LORA, device=device, dtype=dtype)
43+
return q, latent, block_table, seqlens, o
44+
45+
46+
def torch_reference(q, latent, block_table, seqlens, sm_scale, block_size):
47+
"""Dense block-sparse MLA attention reference (fp32) for correctness."""
48+
bs, H, _ = q.shape
49+
out = torch.empty(bs, H, KV_LORA, device=q.device, dtype=torch.float32)
50+
qf = q.float()
51+
lf = latent.float()
52+
for b in range(bs):
53+
sl = int(seqlens[b])
54+
nb = (sl + block_size - 1) // block_size
55+
rows = []
56+
for j in range(nb):
57+
page = int(block_table[b, j])
58+
rows.append(torch.arange(page * block_size, page * block_size + block_size,
59+
device=q.device))
60+
slots = torch.cat(rows)[:sl]
61+
k = lf[slots] # [sl, 576]
62+
scores = (qf[b] @ k.t()) * sm_scale # [H, sl]
63+
p = torch.softmax(scores, dim=-1)
64+
out[b] = p @ k[:, :KV_LORA] # [H, 512]
65+
return out
66+
67+
68+
def bench_one(fn, q, latent, block_table, seqlens, o, block_size, iters=50, warmup=20):
69+
sm_scale = 1.0 / (KV_DIM ** 0.5)
70+
for _ in range(warmup):
71+
fn(q, latent, block_table, seqlens, sm_scale, block_size, KV_LORA, o)
72+
torch.cuda.synchronize()
73+
start = torch.cuda.Event(enable_timing=True)
74+
end = torch.cuda.Event(enable_timing=True)
75+
start.record()
76+
for _ in range(iters):
77+
fn(q, latent, block_table, seqlens, sm_scale, block_size, KV_LORA, o)
78+
end.record()
79+
torch.cuda.synchronize()
80+
return start.elapsed_time(end) / iters
81+
82+
83+
def main():
84+
ap = argparse.ArgumentParser()
85+
ap.add_argument("--bs", type=int, default=128)
86+
ap.add_argument("--heads", type=int, default=16)
87+
ap.add_argument("--dtype", default="bf16")
88+
ap.add_argument("--kernels", default="all")
89+
ap.add_argument("--out", default="marks/mla/bench_results.jsonl")
90+
args = ap.parse_args()
91+
92+
device = "cuda"
93+
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}[args.dtype]
94+
names = list(KERNELS) if args.kernels == "all" else args.kernels.split(",")
95+
96+
blocks = [16, 32, 64]
97+
token_budgets = [512, 1024, 2048, 4096]
98+
sm_scale = 1.0 / (KV_DIM ** 0.5)
99+
100+
print(f"B200 MLA decode kernel bench | bs={args.bs} heads={args.heads} "
101+
f"dtype={args.dtype} | peak~{PEAK_BW_GBPS/1000:.0f}TB/s")
102+
print(f"{'kernel':<14}{'blk':>4}{'sel_tok':>8}{'ms':>9}{'GB/s':>9}{'%peak':>7}{'maxerr':>10}")
103+
104+
results = []
105+
# correctness on a small case once per kernel
106+
for name in names:
107+
fn = KERNELS[name]
108+
# --- correctness (small) ---
109+
q, latent, bt, sl, o = make_inputs(4, args.heads, 32, 256, device, dtype)
110+
ref = torch_reference(q, latent, bt, sl, sm_scale, 32)
111+
out = fn(q, latent, bt, sl, sm_scale, 32, KV_LORA, o)
112+
maxerr = (out.float() - ref).abs().max().item()
113+
# --- sweep ---
114+
for blk in blocks:
115+
for tok in token_budgets:
116+
q, latent, bt, sl, o = make_inputs(args.bs, args.heads, blk, tok, device, dtype)
117+
ms = bench_one(fn, q, latent, bt, sl, o, blk)
118+
kv_bytes = args.bs * tok * KV_DIM * (2 if dtype != torch.float32 else 4)
119+
gbps = kv_bytes / (ms * 1e-3) / 1e9
120+
pct = gbps / PEAK_BW_GBPS * 100
121+
print(f"{name:<14}{blk:>4}{tok:>8}{ms:>9.3f}{gbps:>9.0f}{pct:>6.1f}%"
122+
f"{maxerr:>10.2e}")
123+
results.append({"kernel": name, "bs": args.bs, "heads": args.heads,
124+
"block": blk, "sel_tok": tok, "ms": round(ms, 4),
125+
"gbps": round(gbps, 1), "pct_peak": round(pct, 1),
126+
"maxerr": maxerr})
127+
with open(args.out, "w") as f:
128+
for r in results:
129+
f.write(json.dumps(r) + "\n")
130+
print(f"[bench] wrote {len(results)} rows to {args.out}")
131+
132+
133+
if __name__ == "__main__":
134+
main()
135+

cuda_mla/.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# JIT build artifacts (torch cpp_extension / ninja output)
2+
build/
3+
build_*/
4+
spec/build_*/
5+
**/__pycache__/
6+
*.so
7+
*.o

0 commit comments

Comments
 (0)