|
| 1 | +// MLADecoder: flashinfer-style init/plan/run for ragged-batch MLA decode, |
| 2 | +// general across batch size and block size. |
| 3 | +// |
| 4 | +// __init__(bs, H, block_size, max_blocks, ...) -- ALLOCATE once + fix geometry. |
| 5 | +// A bs-aware policy sets the schedule knobs: target #active CTAs (one |
| 6 | +// MINB=3 wave ~ 3*SM, so LOW bs auto-gets many splits/request and HIGH bs |
| 7 | +// gets few), a chunk_min floor (avoid tiny-chunk overhead), a per-request |
| 8 | +// split cap, and the MINB to launch run() with. |
| 9 | +// plan(seqlens) -- POPULATE the load-balanced work queue from live seqlens. |
| 10 | +// run(q, latent, block_table, o, sm_scale) -- EXECUTE; dispatches the decode |
| 11 | +// kernel by (block_size, MINB). No seqlens => one plan() feeds all layers. |
| 12 | +// |
| 13 | +// Both plan() and run() are fixed-grid launches on the current stream with the |
| 14 | +// pre-allocated buffers => both CUDA-graph-capturable. |
| 15 | +#include "mla_ldm.cuh" |
| 16 | +#include <torch/extension.h> |
| 17 | +#include <algorithm> |
| 18 | + |
| 19 | +struct MLADecoder { |
| 20 | + int bs = 0, H = 0, block_size = 0, max_blocks = 0; |
| 21 | + int target = 0, target_ctas = 0, max_split_cap = 0, chunk_min = 0, minb = 3; |
| 22 | + int MTILES = 0, M = 0, sm_count = 0; |
| 23 | + torch::Tensor work_batch, work_kv_start, work_kv_end, work_offset, mid_o, mid_m, mid_l; |
| 24 | + |
| 25 | + // ---- init: bs-aware schedule policy + allocation. Negative knob args => auto. ---- |
| 26 | + MLADecoder(int bs_, int H_, int block_size_, int max_blocks_, |
| 27 | + int max_split_cap_ = -1, int chunk_min_ = -1, int minb_ = -1) { |
| 28 | + bs = bs_; H = H_; block_size = block_size_; max_blocks = max_blocks_; |
| 29 | + MTILES = (H + 15) / 16; M = MTILES * 16; |
| 30 | + int dev; cudaGetDevice(&dev); |
| 31 | + cudaDeviceProp prop; cudaGetDeviceProperties(&prop, dev); |
| 32 | + sm_count = prop.multiProcessorCount; |
| 33 | + |
| 34 | + // Achievable CTAs/SM is set by smem (the run_wq footprint, STAGES=2/NT=16), which |
| 35 | + // scales with M = MTILES*16: H<=16 (M=16) => ~55KB => 4 CTAs/SM; H<=32 (M=32) => |
| 36 | + // ~74KB => 3. We fill exactly one such wave: target active CTAs = ctas*SM, and |
| 37 | + // MINB=ctas (launch_bounds forces that occupancy; the small-M Oreg keeps it spill- |
| 38 | + // free). So H=16 auto-uses 4 CTAs/splits~4, H=20 uses 3 CTAs/splits~3. |
| 39 | + int smem_cta = 2 * 16 * ldm::HDP * 2 + M * ldm::HDP * 2 + M * 16 * 2 + 3 * M * 4; |
| 40 | + int smem_sm = (int)prop.sharedMemPerMultiprocessor; |
| 41 | + int ctas = std::max(1, std::min(6, smem_sm / smem_cta)); |
| 42 | + minb = (minb_ > 0) ? minb_ : ctas; // CTAs/SM occupancy target |
| 43 | + target = ctas * sm_count; // active CTAs to fill one wave |
| 44 | + chunk_min = (chunk_min_ > 0) ? chunk_min_ : 128; // don't split below 128 tokens |
| 45 | + // per-request cap: enough for low bs to fill a wave, bounded so one request can't |
| 46 | + // starve others on skew. ~ceil(target/bs) headroom, clamped to [ctas, target]. |
| 47 | + int auto_cap = std::min(target, std::max(ctas, 2 * ((target + bs - 1) / bs))); |
| 48 | + max_split_cap = (max_split_cap_ > 0) ? max_split_cap_ : auto_cap; |
| 49 | + // queue length: safe upper bound on sum(nsplits) (rounding + the >=1 clamp). |
| 50 | + target_ctas = std::max(target, bs) + bs; |
| 51 | + |
| 52 | + auto i32 = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); |
| 53 | + auto f32 = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); |
| 54 | + work_batch = torch::empty({target_ctas}, i32); |
| 55 | + work_kv_start = torch::empty({target_ctas}, i32); |
| 56 | + work_kv_end = torch::empty({target_ctas}, i32); |
| 57 | + work_offset = torch::empty({bs + 1}, i32); |
| 58 | + mid_o = torch::empty({target_ctas, M, ldm::CKV}, f32); |
| 59 | + mid_m = torch::empty({target_ctas, M}, f32); |
| 60 | + mid_l = torch::empty({target_ctas, M}, f32); |
| 61 | + } |
| 62 | + |
| 63 | + // ---- plan: populate the work queue from current seqlens. ---- |
| 64 | + void plan(torch::Tensor seqlens) { |
| 65 | + TORCH_CHECK(seqlens.size(0) == bs, "plan: seqlens batch ", seqlens.size(0), " != init bs ", bs); |
| 66 | + ldm::run_schedule_wq(seqlens, work_batch, work_kv_start, work_kv_end, work_offset, |
| 67 | + target, max_split_cap, chunk_min); |
| 68 | + } |
| 69 | + |
| 70 | + // ---- run: dispatch the decode kernel by (block_size, MINB). ---- |
| 71 | + void run(torch::Tensor q, torch::Tensor latent, torch::Tensor block_table, |
| 72 | + torch::Tensor o, double sm_scale) { |
| 73 | +#define RUN(BLK, MB) ldm::run_wq<BLK, 16, 4, 2, 1, MB>(q, latent, block_table, o, work_batch, \ |
| 74 | + work_kv_start, work_kv_end, work_offset, mid_o, mid_m, mid_l, sm_scale) |
| 75 | +// MINB is the occupancy target chosen in init from M (3 for H<=32, 4 for H<=16). |
| 76 | +#define DISPATCH_MB(BLK) do { \ |
| 77 | + if (minb <= 2) { RUN(BLK, 2); } else if (minb == 3) { RUN(BLK, 3); } \ |
| 78 | + else if (minb == 4) { RUN(BLK, 4); } else { RUN(BLK, 5); } } while (0) |
| 79 | + if (block_size == 64) { DISPATCH_MB(64); } |
| 80 | + else if (block_size == 32) { DISPATCH_MB(32); } |
| 81 | + else if (block_size == 16) { DISPATCH_MB(16); } |
| 82 | + else TORCH_CHECK(false, "MLADecoder: unsupported block_size ", block_size, " (need 16/32/64)"); |
| 83 | +#undef DISPATCH_MB |
| 84 | +#undef RUN |
| 85 | + } |
| 86 | +}; |
| 87 | + |
| 88 | +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { |
| 89 | + pybind11::class_<MLADecoder>(m, "MLADecoder") |
| 90 | + .def(pybind11::init<int, int, int, int, int, int, int>(), |
| 91 | + pybind11::arg("bs"), pybind11::arg("H"), pybind11::arg("block_size"), |
| 92 | + pybind11::arg("max_blocks"), pybind11::arg("max_split_cap") = -1, |
| 93 | + pybind11::arg("chunk_min") = -1, pybind11::arg("minb") = -1) |
| 94 | + .def("plan", &MLADecoder::plan, pybind11::arg("seqlens")) |
| 95 | + .def("run", &MLADecoder::run, |
| 96 | + pybind11::arg("q"), pybind11::arg("latent"), pybind11::arg("block_table"), |
| 97 | + pybind11::arg("o"), pybind11::arg("sm_scale")) |
| 98 | + .def_readonly("target", &MLADecoder::target) |
| 99 | + .def_readonly("target_ctas", &MLADecoder::target_ctas) |
| 100 | + .def_readonly("max_split_cap", &MLADecoder::max_split_cap) |
| 101 | + .def_readonly("chunk_min", &MLADecoder::chunk_min) |
| 102 | + .def_readonly("minb", &MLADecoder::minb); |
| 103 | +} |
0 commit comments