Skip to content

Commit 30adc16

Browse files
author
Zhuoming Chen
committed
glm5
1 parent ecce53a commit 30adc16

32 files changed

Lines changed: 2895 additions & 247 deletions

cuda_mla/PROGRESS.md

Lines changed: 203 additions & 0 deletions
Large diffs are not rendered by default.

cuda_mla/REPORT.md

Lines changed: 197 additions & 113 deletions
Large diffs are not rendered by default.

cuda_mla/spec/bs_sweep.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""Validate MLADecoder is general across bs and block_size (vs Triton).
2+
Correctness + throughput on uniform and ragged, bs in {8,32,64,128,256}, blk in {32,64}."""
3+
import os, statistics, random, torch
4+
from torch.utils.cpp_extension import load
5+
from vortex_torch.engine.sgl.attention_backend.triton_mla_kernel import (
6+
decode_blocktable_mla_opt, decode_blocktable_mla_split)
7+
# Source location and the MLA dimensions shared by every helper in this script.
HERE = "cuda_mla/spec"
KV_DIM, KV_LORA, H = 576, 512, 20
sm = 1.0 / (KV_DIM ** 0.5)  # scale applied to q @ k^T before the softmax

# JIT-compile the MLADecoder extension into its own build directory.
bd = HERE + "/build_decoder"
os.makedirs(bd, exist_ok=True)
mod = load(
    name="vortex_mla_decoder",
    sources=[HERE + "/mla_decoder.cu"],
    extra_cuda_cflags=["-O3", "-arch=sm_100a", "--use_fast_math", "-lineinfo"],
    extra_include_paths=[HERE],
    build_directory=bd,
    verbose=False,
)
12+
13+
def mk(bs, blk, sls):
    """Create random paged-decode inputs for ``bs`` requests.

    Args:
        bs: number of requests.
        blk: tokens per page.
        sls: per-request sequence lengths; the longest one fixes the number
            of pages reserved for every request.

    Returns:
        ``(q, latent, bt, sl, nb)``: bf16 queries of shape ``(bs, H, KV_DIM)``,
        bf16 latent rows of shape ``(bs * nb * blk, KV_DIM)``, an int32
        ``(bs, nb)`` block table holding a random permutation of all page ids,
        the int32 lengths tensor built from ``sls``, and ``nb`` itself.
    """
    longest = int(max(sls))
    nb = (longest + blk - 1) // blk  # pages needed by the longest request
    total_pages = bs * nb
    # Random draws stay in the order latent -> block table -> queries.
    latent = torch.randn(total_pages * blk, KV_DIM, device='cuda', dtype=torch.bfloat16)
    page_ids = torch.randperm(total_pages, device='cuda', dtype=torch.int32)
    bt = page_ids.view(bs, nb).contiguous()
    sl = torch.tensor(sls, device='cuda', dtype=torch.int32)
    q = torch.randn(bs, H, KV_DIM, device='cuda', dtype=torch.bfloat16)
    return q, latent, bt, sl, nb
20+
21+
def ref(q, latent, bt, sl, blk):
    """Compute the fp32 reference attention output, one request at a time.

    For each request the first ``sl[b]`` token rows are gathered through the
    block table, scored against the queries with scale ``sm``, softmaxed over
    tokens, and multiplied by the first ``KV_LORA`` columns of the same rows.

    Returns:
        float32 tensor of shape ``(bs, H, KV_LORA)``.
    """
    n_req = q.size(0)
    q32 = q.float()
    kv32 = latent.float()
    out = torch.empty(n_req, H, KV_LORA, device='cuda', dtype=torch.float32)
    for b in range(n_req):
        s = int(sl[b])
        n_pages = (s + blk - 1) // blk
        # Read this request's block-table row on the host once.
        row = bt[b].tolist()
        starts = [row[j] * blk for j in range(n_pages)]
        spans = [torch.arange(lo, lo + blk, device='cuda') for lo in starts]
        slots = torch.cat(spans)[:s]  # drop the padding tail of the last page
        k = kv32[slots]
        scores = (q32[b] @ k.t()) * sm
        out[b] = torch.softmax(scores, -1) @ k[:, :KV_LORA]
    return out
30+
31+
def bench(call, q, latent, bt, sl, reps=8):
    """Return the median effective bandwidth (GB/s) of ``call`` over ``reps`` trials.

    Each trial runs 15 untimed warm-up calls, then times 40 calls between two
    CUDA events. Bytes per call are counted as
    ``sum(sl) * KV_DIM * 2`` (two bytes per bf16 element).

    Args:
        call: callable invoked as ``call(q, latent, bt, sl, out)``.
        q, latent, bt, sl: inputs forwarded unchanged to ``call``.
        reps: number of timed trials.
    """
    out = torch.empty(q.size(0), H, KV_LORA, device='cuda', dtype=torch.bfloat16)
    warmup_calls, timed_calls = 15, 40
    samples = []
    for _trial in range(reps):
        for _ in range(warmup_calls):
            call(q, latent, bt, sl, out)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        stop = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(timed_calls):
            call(q, latent, bt, sl, out)
        stop.record()
        torch.cuda.synchronize()
        ms_per_call = start.elapsed_time(stop) / timed_calls
        samples.append(int(sl.sum()) * KV_DIM * 2 / (ms_per_call * 1e-3) / 1e9)
    return statistics.median(samples)
41+
42+
# Sweep block size x batch size x length pattern; for each case check the
# decoder against the fp32 reference and compare throughput with both Triton kernels.
random.seed(0)
for blk in (32, 64):
    print(f"\n========== block_size={blk} (GB/s; mine/triton_best) ==========")
    # NOTE(review): the 'splits' column is filled with dec.target_ctas below — confirm
    # that this is the value the header is meant to describe.
    print(f"{'bs':>5} {'pattern':>10} | {'splits':>6} {'mine':>6} {'tri_sp':>6} {'tri_ks':>6} | ratio err")
    for bs in (8, 32, 64, 128, 256):
        for pat in ("uniform", "ragged"):
            if pat == "uniform":
                sls = [2048] * bs
            else:
                sls = [random.choice([256, 512, 1024, 2048, 4096]) for _ in range(bs)]
            nb = (max(sls) + blk - 1) // blk
            q, latent, bt, sl, _ = mk(bs, blk, sls)
            dec = mod.MLADecoder(bs, H, blk, nb)

            # Correctness: one plan+run, compared against the reference.
            o = torch.empty(bs, H, KV_LORA, device='cuda', dtype=torch.bfloat16)
            dec.plan(sl)
            dec.run(q, latent, bt, o, sm)
            torch.cuda.synchronize()
            err = (o.float() - ref(q, latent, bt, sl, blk)).abs().max().item()

            # Throughput: plan() is timed together with run() for the decoder.
            def run_mine(q, l, bt, sl, o):
                dec.plan(sl)
                dec.run(q, l, bt, o, sm)

            def run_tri_opt(q, l, bt, sl, o):
                decode_blocktable_mla_opt(q, l, bt, sl, sm, blk, KV_LORA, o)

            def run_tri_split(q, l, bt, sl, o):
                decode_blocktable_mla_split(q, l, bt, sl, sm, blk, KV_LORA, o)

            me = bench(run_mine, q, latent, bt, sl)
            tsp = bench(run_tri_opt, q, latent, bt, sl)
            tks = bench(run_tri_split, q, latent, bt, sl)
            tb = max(tsp, tks)
            tag = "OK" if err < 3e-2 else "FAIL<<"
            print(f"{bs:>5} {pat:>10} | {dec.target_ctas:>6} {me:>6.0f} {tsp:>6.0f} {tks:>6.0f} | {me/tb:.2f}x {err:.1e} {tag}")

cuda_mla/spec/final_h2h.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""FINAL apples-to-apples on one empty GPU: the delivered kernel (k_h20_bs128_blk64
2+
run() = bf16-O + MINB3 + sp3 + vectorized Q-load) vs Triton best."""
3+
import os, statistics, torch
4+
from torch.utils.cpp_extension import load
5+
from vortex_torch.engine.sgl.attention_backend.triton_mla_kernel import (
6+
decode_blocktable_mla_opt, decode_blocktable_mla_split)
7+
# Source location and the MLA dimensions shared by every helper in this script.
HERE = "cuda_mla/spec"
KV_DIM, KV_LORA, H = 576, 512, 20
sm = 1.0 / (KV_DIM ** 0.5)  # softmax scale handed to every kernel

# JIT-compile the specialised h20/bs128/blk64 kernel into its own build directory.
bd = HERE + "/build_k_h20_bs128_blk64"
os.makedirs(bd, exist_ok=True)
mod = load(
    name="vortex_k_h20_bs128_blk64",
    sources=[HERE + "/k_h20_bs128_blk64.cu"],
    extra_cuda_cflags=["-O3", "-arch=sm_100a", "--use_fast_math", "-lineinfo"],
    extra_include_paths=[HERE],
    build_directory=bd,
    verbose=False,
)
12+
13+
def mk(bs, blk, tok, ragged=False):
    """Create random paged-decode inputs sized for ``tok`` tokens per request.

    Args:
        bs: number of requests.
        blk: tokens per page.
        tok: maximum sequence length; fixes the pages reserved per request.
        ragged: when true, lengths are drawn uniformly from ``[tok // 2, tok]``;
            otherwise every request has exactly ``tok`` tokens.

    Returns:
        ``(q, latent, bt, sl)``: bf16 queries ``(bs, H, KV_DIM)``, bf16 latent rows
        ``(bs * nb * blk, KV_DIM)``, an int32 ``(bs, nb)`` block table holding a
        random permutation of page ids, and the int32 lengths tensor.
    """
    nb = (tok + blk - 1) // blk
    total_pages = bs * nb
    # Random draws stay in the order latent -> block table -> lengths -> queries.
    latent = torch.randn(total_pages * blk, KV_DIM, device='cuda', dtype=torch.bfloat16)
    bt = torch.randperm(total_pages, device='cuda', dtype=torch.int32).view(bs, nb).contiguous()
    if ragged:
        sl = torch.randint(tok // 2, tok + 1, (bs,), device='cuda', dtype=torch.int32)
    else:
        sl = torch.full((bs,), tok, device='cuda', dtype=torch.int32)
    q = torch.randn(bs, H, KV_DIM, device='cuda', dtype=torch.bfloat16)
    return q, latent, bt, sl
21+
22+
def bench(call, bs, blk, tok, reps=12, ragged=False):
    """Return the median effective bandwidth (GB/s) of ``call`` over ``reps`` trials.

    Every trial builds fresh inputs with ``mk``, runs 20 untimed warm-up calls,
    then times 50 calls between two CUDA events.

    Bytes per call are counted from the actual sequence lengths
    (``sum(sl) * KV_DIM * 2``). For uniform inputs this equals
    ``bs * tok * KV_DIM * 2``; for ragged inputs, whose lengths lie in
    ``[tok // 2, tok]``, counting ``bs * tok`` would overstate the bandwidth.

    Args:
        call: callable invoked as ``call(q, latent, bt, sl, o)``.
        bs: number of requests.
        blk: tokens per page.
        tok: maximum sequence length per request.
        reps: number of timed trials.
        ragged: forwarded to ``mk`` to select ragged sequence lengths.
    """
    warmup_calls, timed_calls = 20, 50
    vals = []
    for _trial in range(reps):
        q, latent, bt, sl = mk(bs, blk, tok, ragged)
        o = torch.empty(bs, H, KV_LORA, device='cuda', dtype=torch.bfloat16)
        # Tokens actually read this trial; taken before timing so the host sync
        # from int(...) never lands inside the measured region.
        n_tokens = int(sl.sum())
        for _ in range(warmup_calls):
            call(q, latent, bt, sl, o)
        torch.cuda.synchronize()
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        for _ in range(timed_calls):
            call(q, latent, bt, sl, o)
        e.record()
        torch.cuda.synchronize()
        ms_per_call = s.elapsed_time(e) / timed_calls
        vals.append(n_tokens * KV_DIM * 2 / (ms_per_call * 1e-3) / 1e9)
    return statistics.median(vals)
35+
36+
# The three contenders, all sharing the signature bench() expects.
def mine(q, l, bt, sl, o):
    return mod.run(q, l, bt, sl, o, sm)


def trsp(q, l, bt, sl, o):
    return decode_blocktable_mla_opt(q, l, bt, sl, sm, 64, KV_LORA, o)


def trks(q, l, bt, sl, o):
    return decode_blocktable_mla_split(q, l, bt, sl, sm, 64, KV_LORA, o)


print("=== FINAL h20 bs=128 blk=64 (one empty GPU, GB/s) ===")
print(f"{'sel':>8} {'mine':>7} {'tri_sp':>7} {'tri_ks':>7} mine/best")
for tok in (1024, 2048, 3072, 4096):
    m = bench(mine, 128, 64, tok)
    a = bench(trsp, 128, 64, tok)
    k = bench(trks, 128, 64, tok)
    print(f"{tok:>8} {m:>7.0f} {a:>7.0f} {k:>7.0f} {m/max(a,k):.3f}")

# NOTE(review): the ragged row benchmarks only tri_sp, so its last column is
# mine/tri_sp rather than mine/best — confirm this is intentional.
m = bench(mine, 128, 64, 2048, ragged=True)
a = bench(trsp, 128, 64, 2048, ragged=True)
print(f"{'ragged':>8} {m:>7.0f} {a:>7.0f} {'':>7} {m/a:.3f}")

cuda_mla/spec/full_sweep.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""Full bs x blk table: MLADecoder (plan+run, bs-general work-queue) vs Triton best.
2+
bs in {1..8} U {8*i, 1<=i<=16}; blk in {16,32,64}; sel=2048 uniform. Writes CSV + grid."""
3+
import os, statistics, json, torch
4+
from torch.utils.cpp_extension import load
5+
from vortex_torch.engine.sgl.attention_backend.triton_mla_kernel import (
6+
decode_blocktable_mla_opt, decode_blocktable_mla_split)
7+
# Source location and the MLA dimensions shared by every helper in this script.
HERE = "cuda_mla/spec"
KV_DIM, KV_LORA, H = 576, 512, 20
sm = 1.0 / (KV_DIM ** 0.5)  # softmax scale handed to every kernel

# JIT-compile the MLADecoder extension into its own build directory.
bd = HERE + "/build_decoder"
os.makedirs(bd, exist_ok=True)
mod = load(
    name="vortex_mla_decoder",
    sources=[HERE + "/mla_decoder.cu"],
    extra_cuda_cflags=["-O3", "-arch=sm_100a", "--use_fast_math", "-lineinfo"],
    extra_include_paths=[HERE],
    build_directory=bd,
    verbose=False,
)

# Uniform sequence length used for every cell of the sweep.
SEL = 2048
13+
14+
def mk(bs, blk, tok):
    """Create random paged-decode inputs where every request has ``tok`` tokens.

    Args:
        bs: number of requests.
        blk: tokens per page.
        tok: sequence length of every request.

    Returns:
        ``(q, latent, bt, sl, nb)``: bf16 queries ``(bs, H, KV_DIM)``, bf16 latent
        rows ``(bs * nb * blk, KV_DIM)``, an int32 ``(bs, nb)`` block table holding
        a random permutation of page ids, an int32 lengths tensor filled with
        ``tok``, and the pages-per-request count ``nb``.
    """
    nb = (tok + blk - 1) // blk
    total_pages = bs * nb
    # Random draws stay in the order latent -> block table -> queries.
    latent = torch.randn(total_pages * blk, KV_DIM, device='cuda', dtype=torch.bfloat16)
    page_ids = torch.randperm(total_pages, device='cuda', dtype=torch.int32)
    bt = page_ids.view(bs, nb).contiguous()
    sl = torch.full((bs,), tok, device='cuda', dtype=torch.int32)
    q = torch.randn(bs, H, KV_DIM, device='cuda', dtype=torch.bfloat16)
    return q, latent, bt, sl, nb
21+
22+
def bench(call, q, latent, bt, sl, reps=6):
    """Return the median effective bandwidth (GB/s) of ``call`` over ``reps`` trials.

    Each trial runs 12 untimed warm-up calls, then times 40 calls between two
    CUDA events. Bytes per call are counted as
    ``sum(sl) * KV_DIM * 2`` (two bytes per bf16 element).

    Args:
        call: callable invoked as ``call(q, latent, bt, sl, out)``.
        q, latent, bt, sl: inputs forwarded unchanged to ``call``.
        reps: number of timed trials.
    """
    out = torch.empty(q.size(0), H, KV_LORA, device='cuda', dtype=torch.bfloat16)
    warmup_calls, timed_calls = 12, 40
    samples = []
    for _trial in range(reps):
        for _ in range(warmup_calls):
            call(q, latent, bt, sl, out)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        stop = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(timed_calls):
            call(q, latent, bt, sl, out)
        stop.record()
        torch.cuda.synchronize()
        ms_per_call = start.elapsed_time(stop) / timed_calls
        samples.append(int(sl.sum()) * KV_DIM * 2 / (ms_per_call * 1e-3) / 1e9)
    return statistics.median(samples)
32+
33+
# Batch sizes: 1..8 plus every multiple of 8 up to 128 (8 appears once).
bss = sorted({*range(1, 9), *range(8, 129, 8)})
rows = []
for blk in (16, 32, 64):
    for bs in bss:
        nb = (SEL + blk - 1) // blk
        q, latent, bt, sl, _ = mk(bs, blk, SEL)
        dec = mod.MLADecoder(bs, H, blk, nb)

        # plan() is timed together with run() for the decoder.
        def run_mine(q, l, bt, sl, o):
            dec.plan(sl)
            dec.run(q, l, bt, o, sm)

        def run_tri_opt(q, l, bt, sl, o):
            decode_blocktable_mla_opt(q, l, bt, sl, sm, blk, KV_LORA, o)

        def run_tri_split(q, l, bt, sl, o):
            decode_blocktable_mla_split(q, l, bt, sl, sm, blk, KV_LORA, o)

        me = bench(run_mine, q, latent, bt, sl)
        tsp = bench(run_tri_opt, q, latent, bt, sl)
        tks = bench(run_tri_split, q, latent, bt, sl)
        tb = max(tsp, tks)
        best = 'sp' if tsp >= tks else 'ks'
        rows.append({
            'blk': blk,
            'bs': bs,
            'mine': me,
            'tri_sp': tsp,
            'tri_ks': tks,
            'ratio': me / tb,
            'tbest': best,
        })
        print(f"blk={blk:2d} bs={bs:3d}: mine={me:5.0f} tri_sp={tsp:5.0f} tri_ks={tks:5.0f} ratio={me/tb:.2f}x")
46+
47+
# Persist the raw measurements. The context manager closes (and therefore
# flushes) the file deterministically instead of leaving it to garbage collection.
with open(HERE + "/full_sweep.json", "w") as fh:
    json.dump(rows, fh, indent=0)
48+
# Human-readable grid: one table per block size, rows in sweep order.
print("\n\n##### TABLE (GB/s mine | ratio vs Triton-best), sel=2048 uniform #####")
for blk in (16, 32, 64):
    print(f"\n### block_size = {blk} ###")
    print(f"{'bs':>4} | {'mine':>6} {'tri_sp':>6} {'tri_ks':>6} | {'ratio':>6}")
    for r in (row for row in rows if row['blk'] == blk):
        print(f"{r['bs']:>4} | {r['mine']:>6.0f} {r['tri_sp']:>6.0f} {r['tri_ks']:>6.0f} | {r['ratio']:>5.2f}x")

cuda_mla/spec/k_h16.cu

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// H=16 standalone flagship (MTILES=1, no head padding). vs H=20: M=16 halves the Q
2+
// smem (~18.7KB) and Oreg (32 regs) => ~55KB smem => 4 CTAs/SM (vs 3) => 44% DRAM.
3+
// Optimum: NWARPS=4, STAGES=2, MINB=5, splits=4 (~3000-3023 GB/s @ bs128, 1.1-1.9x Triton).
4+
// (MINB=5: smem caps occupancy at 4 CTAs, but the lower reg target schedules ~1% better
5+
// than MINB=4; the tiny M=16 Oreg keeps it spill-free.) BLK template must match block_size.
6+
#include "mla_ldm.cuh"
7+
#define V(nm, BLK, MB) \
8+
void nm(torch::Tensor q, torch::Tensor l, torch::Tensor bt, torch::Tensor sl, \
9+
torch::Tensor o, double s, int sp) { ldm::launch<BLK,16,4,2,1,MB>(q,l,bt,sl,o,s,sp); }
10+
V(b64_b4,64,4) V(b64_b5,64,5) V(b32_b4,32,4) V(b32_b5,32,5) V(b16_b4,16,4) V(b16_b5,16,5)
11+
// defaults: the measured optimum per block size (MINB=5, splits=4).
12+
void run64(torch::Tensor q,torch::Tensor l,torch::Tensor bt,torch::Tensor sl,torch::Tensor o,double s){ ldm::launch<64,16,4,2,1,5>(q,l,bt,sl,o,s,4); }
13+
void run32(torch::Tensor q,torch::Tensor l,torch::Tensor bt,torch::Tensor sl,torch::Tensor o,double s){ ldm::launch<32,16,4,2,1,5>(q,l,bt,sl,o,s,4); }
14+
void run16(torch::Tensor q,torch::Tensor l,torch::Tensor bt,torch::Tensor sl,torch::Tensor o,double s){ ldm::launch<16,16,4,2,1,5>(q,l,bt,sl,o,s,4); }
15+
// Python bindings: the three tuned defaults first, then the tunable variants.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run64", &run64);
  m.def("run32", &run32);
  m.def("run16", &run16);

  m.def("b64_b4", &b64_b4);
  m.def("b64_b5", &b64_b5);
  m.def("b32_b4", &b32_b4);
  m.def("b32_b5", &b32_b5);
  m.def("b16_b4", &b16_b4);
  m.def("b16_b5", &b16_b5);
}

cuda_mla/spec/k_h20_bs128_blk64.cu

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1-
// H=20, bs=128, block_size=64. ldmatrix + reg-O + reg-softmax + cp.async + split-KV.
1+
// H=20, bs=128, block_size=64. ldmatrix + bf16-packed reg-O + reg-softmax +
2+
// cp.async + split-KV. WINNING CONFIG (ncu-tuned): NWARPS=4, STAGES=2, MINB=3
3+
// (3 CTAs/SM => 18.75% occ, the occupancy wall the bandwidth-bound decode hit),
4+
// splits=3 (populates the 3-block capacity). 2059 GB/s @ sel=2048 vs Triton 1971.
25
#include "mla_ldm.cuh"
3-
#define V(nm, NW, ST) \
6+
#define V(nm, NW, ST, MB) \
47
void nm(torch::Tensor q, torch::Tensor l, torch::Tensor bt, torch::Tensor sl, \
5-
torch::Tensor o, double s, int sp) { ldm::launch<64,16,NW,ST>(q,l,bt,sl,o,s,sp); }
6-
V(r_w4_s1,4,1) V(r_w4_s2,4,2) V(r_w8_s1,8,1) V(r_w8_s2,8,2)
8+
torch::Tensor o, double s, int sp) { ldm::launch<64,16,NW,ST,1,MB>(q,l,bt,sl,o,s,sp); }
9+
V(r_w4_s2_b2,4,2,2) V(r_w4_s2_b3,4,2,3) V(r_w8_s2_b2,8,2,2) V(r_w8_s2_b3,8,2,3)
10+
// default: the measured optimum for h20/bs128/blk64 (MINB=3, splits=3).
711
void run(torch::Tensor q, torch::Tensor l, torch::Tensor bt, torch::Tensor sl,
8-
torch::Tensor o, double s) { ldm::launch<64,16,4,2>(q,l,bt,sl,o,s,2); }
12+
torch::Tensor o, double s) { ldm::launch<64,16,4,2,1,3>(q,l,bt,sl,o,s,3); }
913
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
10-
m.def("run",&run); m.def("r_w4_s1",&r_w4_s1); m.def("r_w4_s2",&r_w4_s2);
11-
m.def("r_w8_s1",&r_w8_s1); m.def("r_w8_s2",&r_w8_s2);
14+
m.def("run",&run); m.def("r_w4_s2_b2",&r_w4_s2_b2); m.def("r_w4_s2_b3",&r_w4_s2_b3);
15+
m.def("r_w8_s2_b2",&r_w8_s2_b2); m.def("r_w8_s2_b3",&r_w8_s2_b3);
1216
}

cuda_mla/spec/mla_decoder.cu

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// MLADecoder: flashinfer-style init/plan/run for ragged-batch MLA decode,
2+
// general across batch size and block size.
3+
//
4+
// __init__(bs, H, block_size, max_blocks, ...) -- ALLOCATE once + fix geometry.
5+
// A bs-aware policy sets the schedule knobs: target #active CTAs (one
6+
// MINB=3 wave ~ 3*SM, so LOW bs auto-gets many splits/request and HIGH bs
7+
// gets few), a chunk_min floor (avoid tiny-chunk overhead), a per-request
8+
// split cap, and the MINB to launch run() with.
9+
// plan(seqlens) -- POPULATE the load-balanced work queue from live seqlens.
10+
// run(q, latent, block_table, o, sm_scale) -- EXECUTE; dispatches the decode
11+
// kernel by (block_size, MINB). run() takes no seqlens, so a single plan() call can feed every layer's run().
12+
//
13+
// Both plan() and run() are fixed-grid launches on the current stream with the
14+
// pre-allocated buffers => both CUDA-graph-capturable.
15+
#include "mla_ldm.cuh"
16+
#include <torch/extension.h>
17+
#include <algorithm>
18+
19+
struct MLADecoder {
20+
int bs = 0, H = 0, block_size = 0, max_blocks = 0;
21+
int target = 0, target_ctas = 0, max_split_cap = 0, chunk_min = 0, minb = 3;
22+
int MTILES = 0, M = 0, sm_count = 0;
23+
torch::Tensor work_batch, work_kv_start, work_kv_end, work_offset, mid_o, mid_m, mid_l;
24+
25+
// ---- init: bs-aware schedule policy + allocation. Negative knob args => auto. ----
26+
MLADecoder(int bs_, int H_, int block_size_, int max_blocks_,
27+
int max_split_cap_ = -1, int chunk_min_ = -1, int minb_ = -1) {
28+
bs = bs_; H = H_; block_size = block_size_; max_blocks = max_blocks_;
29+
MTILES = (H + 15) / 16; M = MTILES * 16;
30+
int dev; cudaGetDevice(&dev);
31+
cudaDeviceProp prop; cudaGetDeviceProperties(&prop, dev);
32+
sm_count = prop.multiProcessorCount;
33+
34+
// Achievable CTAs/SM is set by smem (the run_wq footprint, STAGES=2/NT=16), which
35+
// scales with M = MTILES*16: H<=16 (M=16) => ~55KB => 4 CTAs/SM; H<=32 (M=32) =>
36+
// ~74KB => 3. We fill exactly one such wave: target active CTAs = ctas*SM, and
37+
// MINB=ctas (launch_bounds forces that occupancy; the small-M Oreg keeps it spill-
38+
// free). So H=16 auto-uses 4 CTAs/splits~4, H=20 uses 3 CTAs/splits~3.
39+
int smem_cta = 2 * 16 * ldm::HDP * 2 + M * ldm::HDP * 2 + M * 16 * 2 + 3 * M * 4;
40+
int smem_sm = (int)prop.sharedMemPerMultiprocessor;
41+
int ctas = std::max(1, std::min(6, smem_sm / smem_cta));
42+
minb = (minb_ > 0) ? minb_ : ctas; // CTAs/SM occupancy target
43+
target = ctas * sm_count; // active CTAs to fill one wave
44+
chunk_min = (chunk_min_ > 0) ? chunk_min_ : 128; // don't split below 128 tokens
45+
// per-request cap: enough for low bs to fill a wave, bounded so one request can't
46+
// starve others on skew. ~ceil(target/bs) headroom, clamped to [ctas, target].
47+
int auto_cap = std::min(target, std::max(ctas, 2 * ((target + bs - 1) / bs)));
48+
max_split_cap = (max_split_cap_ > 0) ? max_split_cap_ : auto_cap;
49+
// queue length: safe upper bound on sum(nsplits) (rounding + the >=1 clamp).
50+
target_ctas = std::max(target, bs) + bs;
51+
52+
auto i32 = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
53+
auto f32 = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
54+
work_batch = torch::empty({target_ctas}, i32);
55+
work_kv_start = torch::empty({target_ctas}, i32);
56+
work_kv_end = torch::empty({target_ctas}, i32);
57+
work_offset = torch::empty({bs + 1}, i32);
58+
mid_o = torch::empty({target_ctas, M, ldm::CKV}, f32);
59+
mid_m = torch::empty({target_ctas, M}, f32);
60+
mid_l = torch::empty({target_ctas, M}, f32);
61+
}
62+
63+
// ---- plan: populate the work queue from current seqlens. ----
64+
void plan(torch::Tensor seqlens) {
65+
TORCH_CHECK(seqlens.size(0) == bs, "plan: seqlens batch ", seqlens.size(0), " != init bs ", bs);
66+
ldm::run_schedule_wq(seqlens, work_batch, work_kv_start, work_kv_end, work_offset,
67+
target, max_split_cap, chunk_min);
68+
}
69+
70+
// ---- run: dispatch the decode kernel by (block_size, MINB). ----
71+
void run(torch::Tensor q, torch::Tensor latent, torch::Tensor block_table,
72+
torch::Tensor o, double sm_scale) {
73+
#define RUN(BLK, MB) ldm::run_wq<BLK, 16, 4, 2, 1, MB>(q, latent, block_table, o, work_batch, \
74+
work_kv_start, work_kv_end, work_offset, mid_o, mid_m, mid_l, sm_scale)
75+
// MINB is the occupancy target chosen in init from M (3 for H<=32, 4 for H<=16).
76+
#define DISPATCH_MB(BLK) do { \
77+
if (minb <= 2) { RUN(BLK, 2); } else if (minb == 3) { RUN(BLK, 3); } \
78+
else if (minb == 4) { RUN(BLK, 4); } else { RUN(BLK, 5); } } while (0)
79+
if (block_size == 64) { DISPATCH_MB(64); }
80+
else if (block_size == 32) { DISPATCH_MB(32); }
81+
else if (block_size == 16) { DISPATCH_MB(16); }
82+
else TORCH_CHECK(false, "MLADecoder: unsupported block_size ", block_size, " (need 16/32/64)");
83+
#undef DISPATCH_MB
84+
#undef RUN
85+
}
86+
};
87+
88+
// Python bindings for MLADecoder: constructor (knob arguments default to -1),
// plan/run, and read-only views of the chosen schedule knobs.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  namespace py = pybind11;

  py::class_<MLADecoder> decoder(m, "MLADecoder");

  decoder.def(py::init<int, int, int, int, int, int, int>(),
              py::arg("bs"),
              py::arg("H"),
              py::arg("block_size"),
              py::arg("max_blocks"),
              py::arg("max_split_cap") = -1,
              py::arg("chunk_min") = -1,
              py::arg("minb") = -1);

  decoder.def("plan", &MLADecoder::plan, py::arg("seqlens"));
  decoder.def("run", &MLADecoder::run,
              py::arg("q"),
              py::arg("latent"),
              py::arg("block_table"),
              py::arg("o"),
              py::arg("sm_scale"));

  decoder.def_readonly("target", &MLADecoder::target);
  decoder.def_readonly("target_ctas", &MLADecoder::target_ctas);
  decoder.def_readonly("max_split_cap", &MLADecoder::max_split_cap);
  decoder.def_readonly("chunk_min", &MLADecoder::chunk_min);
  decoder.def_readonly("minb", &MLADecoder::minb);
}

0 commit comments

Comments
 (0)