|
| 1 | +"""Diagnostic: isolate the cost of (load-mask, where-mask, loop-peel) in the |
| 2 | +single-pass MLA kernel, in the benchmark regime (seqlen % BLOCK_N == 0, so the |
| 3 | +fully-unmasked path is numerically correct). Run under CUDA_VISIBLE_DEVICES. |
| 4 | +""" |
| 5 | +import statistics |
| 6 | +import torch |
| 7 | +import triton |
| 8 | +import triton.language as tl |
| 9 | +from bench_triton_mla import make_inputs, KV_DIM, KV_LORA |
| 10 | + |
| 11 | + |
@triton.jit
def _diag_kernel(
    Q, K_Buffer, V_Buffer, sm_scale, Seqlens, Block_Table, O,
    stride_qbs, stride_qh, stride_buf_kbs, stride_buf_vbs,
    stride_obs, stride_oh, stride_bt_b,
    q_head_num: tl.constexpr, BLOCK_SIZE: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr, BLOCK_DPE: tl.constexpr, BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr, BLOCK_H: tl.constexpr,
    USE_LOAD_MASK: tl.constexpr, USE_WHERE: tl.constexpr, PEEL: tl.constexpr,
    WITH_TAIL: tl.constexpr,
):
    """Single-pass attention over a block-table-indexed K/V buffer, with toggles.

    Program axis 0 selects the batch entry and axis 1 selects a group of
    BLOCK_H heads. Each program walks the key positions of its sequence in
    steps of BLOCK_N, keeps a running row maximum (e_max), a running
    denominator (e_sum) and an accumulator (acc), rescaling the latter two
    whenever the maximum changes, and finally stores acc / e_sum to O.

    Compile-time switches (each one removes or adds a piece of the kernel so
    that its cost can be timed in isolation):
        USE_LOAD_MASK: pass ``mask=offs_n < seqlen`` to the block-table and
            K/V loads in the main loop; otherwise those loads are unmasked.
        USE_WHERE: overwrite scores at positions >= seqlen with -inf in the
            main loop.
        PEEL: stop the main loop at the last multiple of BLOCK_N that is
            <= seqlen instead of at seqlen.
        WITH_TAIL: after the main loop, process positions
            [loop_end, seqlen) in one extra, always-masked block.

    NOTE(review): with USE_LOAD_MASK/USE_WHERE off and seqlen not a multiple
    of BLOCK_N, the last main-loop block reads and scores positions past
    seqlen; with PEEL on and WITH_TAIL off, positions in [loop_end, seqlen)
    are never processed. The module docstring restricts use to
    seqlen % BLOCK_N == 0, where neither case arises.
    NOTE(review): if seqlen is 0 no block is processed and the final store
    evaluates acc / e_sum with e_sum == 0.
    """
    cur_batch = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    # Heads handled by this program; mask_h disables lanes past q_head_num.
    cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
    mask_h = cur_head < q_head_num

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    # The "pe" columns start immediately after the first BLOCK_DMODEL columns.
    offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
    seqlen = tl.load(Seqlens + cur_batch)

    # Query tiles for this (batch, head group): main part and "pe" part.
    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
    q = tl.load(Q + offs_q, mask=mask_h[:, None], other=0.0)
    off_qpe = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
    qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0)

    # Running softmax state per head: max, denominator, weighted-V accumulator.
    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
    # Start of this batch entry's row in the block table.
    bt_base = Block_Table + cur_batch * stride_bt_b

    if PEEL:
        # Only whole BLOCK_N-sized blocks are visited by the main loop.
        loop_end = (seqlen // BLOCK_N) * BLOCK_N
    else:
        loop_end = seqlen

    for start_n in range(0, loop_end, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        valid = offs_n < seqlen
        # Translate each logical position to its page via the block table.
        if USE_LOAD_MASK:
            page = tl.load(bt_base + (offs_n // BLOCK_SIZE), mask=valid, other=0)
        else:
            page = tl.load(bt_base + (offs_n // BLOCK_SIZE))
        # Row index into K_Buffer / V_Buffer: page base plus offset in page.
        kv_loc = page * BLOCK_SIZE + (offs_n % BLOCK_SIZE)
        if USE_LOAD_MASK:
            # k / kpe are indexed [feature, position]; v is [position, feature].
            k = tl.load(K_Buffer + kv_loc[None, :] * stride_buf_kbs + offs_d[:, None],
                        mask=valid[None, :], other=0.0)
            kpe = tl.load(K_Buffer + kv_loc[None, :] * stride_buf_kbs + offs_dpe[:, None],
                          mask=valid[None, :], other=0.0)
            v = tl.load(V_Buffer + kv_loc[:, None] * stride_buf_vbs + offs_dv[None, :],
                        mask=valid[:, None], other=0.0)
        else:
            k = tl.load(K_Buffer + kv_loc[None, :] * stride_buf_kbs + offs_d[:, None])
            kpe = tl.load(K_Buffer + kv_loc[None, :] * stride_buf_kbs + offs_dpe[:, None])
            v = tl.load(V_Buffer + kv_loc[:, None] * stride_buf_vbs + offs_dv[None, :])
        # Scores are the sum of the main and "pe" dot products, then scaled.
        qk = tl.dot(q, k.to(q.dtype))
        qk += tl.dot(qpe, kpe.to(qpe.dtype))
        qk *= sm_scale
        if USE_WHERE:
            # Positions past seqlen get -inf so they contribute exp(-inf) = 0.
            qk = tl.where(valid[None, :], qk, float("-inf"))
        # Update the running max and rescale the previous partial results to it.
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc += tl.dot(p.to(v.dtype), v)
        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max

    if WITH_TAIL:
        # separate duplicated masked tail block (mirrors _fwd_blocktable_mla_kernel_opt)
        # Only reachable when the main loop stopped short of seqlen (PEEL on
        # and seqlen not a multiple of BLOCK_N).
        if loop_end < seqlen:
            offs_n = loop_end + tl.arange(0, BLOCK_N)
            valid = offs_n < seqlen
            page = tl.load(bt_base + (offs_n // BLOCK_SIZE), mask=valid, other=0)
            kv_loc = page * BLOCK_SIZE + (offs_n % BLOCK_SIZE)
            k = tl.load(K_Buffer + kv_loc[None, :] * stride_buf_kbs + offs_d[:, None],
                        mask=valid[None, :], other=0.0)
            qk = tl.dot(q, k.to(q.dtype))
            kpe = tl.load(K_Buffer + kv_loc[None, :] * stride_buf_kbs + offs_dpe[:, None],
                          mask=valid[None, :], other=0.0)
            qk += tl.dot(qpe, kpe.to(qpe.dtype))
            qk *= sm_scale
            # The tail is always masked, independent of USE_LOAD_MASK / USE_WHERE.
            qk = tl.where(valid[None, :], qk, float("-inf"))
            v = tl.load(V_Buffer + kv_loc[:, None] * stride_buf_vbs + offs_dv[None, :],
                        mask=valid[:, None], other=0.0)
            n_e_max = tl.maximum(tl.max(qk, 1), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            acc *= re_scale[:, None]
            acc += tl.dot(p.to(v.dtype), v)
            e_sum = e_sum * re_scale + tl.sum(p, 1)
            e_max = n_e_max

    # Normalise by the softmax denominator and write only the real heads.
    offs_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_dv[None, :]
    tl.store(O + offs_o, acc / e_sum[:, None], mask=mask_h[:, None])
| 107 | + |
| 108 | + |
def run(q, latent, bt, sl, sm, blk, o, BLOCK_H, lm, wh, peel, tail=False, num_warps=8):
    """Launch ``_diag_kernel`` once with the requested toggle combination.

    The grid has one program per (batch entry, group of ``BLOCK_H`` heads).
    ``latent`` is passed as both the K and the V buffer, with its stride(0)
    used for both buffer strides. ``blk`` is used for both ``BLOCK_SIZE`` and
    ``BLOCK_N``; the feature widths are fixed at 512 / 64 / 512.

    Args:
        q: query tensor; its shape must unpack into three dimensions, the
            first two being batch and head count.
        latent: tensor supplying both K and V.
        bt: block table tensor.
        sl: per-batch sequence lengths tensor.
        sm: softmax scale passed to the kernel.
        blk: value used for the kernel's BLOCK_SIZE and BLOCK_N.
        o: output tensor, written in place by the kernel.
        BLOCK_H: number of heads handled per program.
        lm, wh, peel, tail: values for USE_LOAD_MASK, USE_WHERE, PEEL and
            WITH_TAIL respectively.
        num_warps: forwarded to the Triton launch.

    Returns:
        ``o``, the same tensor object that was passed in.
    """
    n_seqs, n_heads, _ = q.shape
    launch_grid = (n_seqs, triton.cdiv(n_heads, BLOCK_H))
    # K and V both come from `latent`, so they share one row stride.
    kv_row_stride = latent.stride(0)
    meta = dict(
        q_head_num=n_heads,
        BLOCK_SIZE=blk,
        BLOCK_DMODEL=512,
        BLOCK_DPE=64,
        BLOCK_DV=512,
        BLOCK_N=blk,
        BLOCK_H=BLOCK_H,
        USE_LOAD_MASK=lm,
        USE_WHERE=wh,
        PEEL=peel,
        WITH_TAIL=tail,
        num_warps=num_warps,
    )
    _diag_kernel[launch_grid](
        q, latent, latent, sm, sl, bt, o,
        q.stride(0), q.stride(1), kv_row_stride, kv_row_stride,
        o.stride(0), o.stride(1), bt.stride(0),
        **meta,
    )
    return o
| 121 | + |
| 122 | + |
def bench(fn, q, latent, bt, sl, o, blk, iters=50, warmup=20):
    """Return the mean time in milliseconds of one ``fn`` call.

    ``fn`` is invoked as ``fn(q, latent, bt, sl, sm, blk, o)`` where ``sm`` is
    ``1 / sqrt(KV_DIM)``. It is first called ``warmup`` times untimed, then
    ``iters`` times between two CUDA timing events; the elapsed time between
    the events is divided by ``iters``.
    """
    softmax_scale = 1.0 / (KV_DIM ** 0.5)

    def launch(times):
        # Issue `times` back-to-back calls without synchronising in between.
        for _ in range(times):
            fn(q, latent, bt, sl, softmax_scale, blk, o)

    launch(warmup)
    # Make sure the warm-up work has finished before the timed region starts.
    torch.cuda.synchronize()

    start_evt = torch.cuda.Event(enable_timing=True)
    stop_evt = torch.cuda.Event(enable_timing=True)
    start_evt.record()
    launch(iters)
    stop_evt.record()
    # elapsed_time needs both events to have completed.
    torch.cuda.synchronize()
    return start_evt.elapsed_time(stop_evt) / iters
| 134 | + |
| 135 | + |
def gbps(q, latent, bt, sl, o, blk, BLOCK_H, lm, wh, peel, tail=False):
    """Return the median of three throughput measurements for one config.

    Each measurement times ``run`` via ``bench`` and converts the mean
    milliseconds per call into ``bs * tok * KV_DIM * 2 / seconds / 1e9``,
    where ``bs`` is the first dimension of ``q`` and ``tok`` is ``sl[0]``.
    The factor 2 is presumably the byte width of one element; confirm against
    the dtype produced by ``make_inputs``.
    """
    bs, _, _ = q.shape
    # Uses the first sequence length for the whole batch.
    tok = int(sl[0])
    traffic = bs * tok * KV_DIM * 2

    def launcher(*bench_args):
        # bench supplies (q, latent, bt, sl, sm, blk, o); append the toggles.
        return run(*bench_args, BLOCK_H, lm, wh, peel, tail)

    samples = []
    for _ in range(3):
        ms = bench(launcher, q, latent, bt, sl, o, blk)
        samples.append(traffic / (ms * 1e-3) / 1e9)
    return statistics.median(samples)
| 144 | + |
| 145 | + |
if __name__ == "__main__":
    # Sweep every toggle configuration over several head layouts and input
    # sizes, printing one table of median GB/s figures per head layout.
    dev, dt = "cuda", torch.bfloat16
    bs = 128
    # Column order of each table: (blk, tok) pairs, blk outermost.
    shapes = [(blk, tok) for blk in (16, 64) for tok in (1024, 4096)]
    configs = [
        # name, load_mask, where, peel, with_tail
        ("baseline single-loop lm+wh ", True, True, False, False),
        ("unmasked single-loop ", False, False, False, False),
        ("peel, NO tail block ", False, False, True, False),
        ("peel + SEPARATE tail block ", False, False, True, True),
    ]
    for H, BLOCK_H in [(16, 16), (20, 16), (20, 32)]:
        print(f"\n=== H={H} BLOCK_H={BLOCK_H} bs={bs} (median of 3 GB/s) ===")
        labels = [f"b{blk}/t{tok}" for blk, tok in shapes]
        print(f"{'config':<30}" + "".join(f"{label:>11}" for label in labels))
        for name, lm, wh, peel, tail in configs:
            cells = []
            for blk, tok in shapes:
                # Fresh inputs for every cell of the table.
                q, latent, bt, sl, o = make_inputs(bs, H, blk, tok, dev, dt)
                g = gbps(q, latent, bt, sl, o, blk, BLOCK_H, lm, wh, peel, tail)
                cells.append(f"{g:>11.0f}")
            print(f"{name:<30}" + "".join(cells))
0 commit comments