Motivation
The existing aiter Triton batched FP8 GEMM kernels are tuned for the MLA regime (B = num_heads ≈ 128). At small B (B=2, e.g. DeepSeek V4's grouped output LoRA on tp=8), they lose decisively to PyTorch's BF16 grouped einsum (which lowers to rocBLAS BMM).
Concrete case: DeepSeek V4 wo_a (output LoRA)
Shape per TP=8 rank: B = n_local_groups = 2, K = d_per_group = 4096, N = o_lora_rank = 1024, M = token batch (1-16 decode, up to ~1024 chunked prefill).
Single-op micro-bench on V4-Pro / 8x MI355X (gfx950):
| M | einsum (BF16, rocBLAS) | batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant (FP8) | G-loop gemm_a8w8_blockscale_preshuffle (FP8) |
|---|---|---|---|
| 1 | 15 μs | 35 μs (0.43x) | 116 μs (0.13x) |
| 16 | 14 μs | 35 μs (0.43x) | 60 μs (0.24x) |
| 64 | 19 μs | 37 μs (0.51x) | 62 μs (0.30x) |
| 256 | 19 μs | 36 μs (0.54x) | 58 μs (0.33x) |
| 1024 | 37 μs | 39 μs (0.96x) | 55 μs (0.67x) |

Values in parentheses are throughput relative to einsum (below 1x means slower than the BF16 baseline).
Apart from the batched kernel at M=1024 (0.96x, near parity), the FP8 paths are roughly 1.5x to 7.7x slower than the BF16 baseline, even though FP8 should in theory halve the bytes of W read from memory.
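To make the bandwidth argument concrete, here is the per-rank weight-byte arithmetic for wo_a using the shapes above. This is a back-of-envelope sketch that assumes 1-byte FP8 and 2-byte BF16 weights and one FP32 scale per 128x128 block of W; activation and output traffic are ignored.

```python
# Per-rank weight traffic for DeepSeek V4 wo_a at TP=8 (shapes from this issue).
B, K, N = 2, 4096, 1024     # n_local_groups, d_per_group, o_lora_rank
BLOCK = 128                 # W scale block size on both N and K

def weight_bytes(bytes_per_elem: int) -> int:
    """Bytes needed to read W of shape (B, N, K) once."""
    return B * N * K * bytes_per_elem

bf16_w = weight_bytes(2)
fp8_w = weight_bytes(1)
# One FP32 scale per (128 x 128) block of W: shape (B, N // 128, K // 128).
scale_bytes = B * (N // BLOCK) * (K // BLOCK) * 4

print(f"BF16 W:        {bf16_w / 2**20:.1f} MiB")
print(f"FP8 W + scale: {(fp8_w + scale_bytes) / 2**20:.3f} MiB")
```

The scales add only 2 KiB, so the FP8 path should read roughly half the weight bytes of the BF16 path; the slowdowns in the table are consistent with the next section's argument that scheduling (grid size, launch count), not memory traffic, is the limiter.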
Why existing kernels lose
- batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant: the grid is (B, M_tiles * N_tiles). At B=2 and decode M=4 (1 M-tile) with the default BLOCK_SIZE_N=128 (8 N-tiles), only 16 work-groups are launched, enough to occupy about 6% of the 256 CUs on MI355X. Tuning the JSON config DB (verified via sweeps + @triton.autotune) does not close the gap; the B=2 grid bottleneck is fundamental.
- G-loop gemm_a8w8_blockscale_preshuffle: a standard 2D GEMM per group that fully utilizes the chip on each call, but the G_local kernel launches per layer (plus a per-group act_quant) inflate eager-mode latency unacceptably at small M.
- rocBLAS BMM (what torch.einsum lowers to): wins decisively at B=2 and small M thanks to vendor-tuned, single-launch coverage.
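The work-group count is simply the product of the grid dimensions. A small sketch of the decode case, where block_m=16 is an assumed value (any BLOCK_SIZE_M of at least 4 yields one M-tile) and NUM_CUS=256 is taken as MI355X's compute-unit count:

```python
import math

NUM_CUS = 256  # assumed MI355X (gfx950) compute-unit count; adjust for other parts

def existing_grid(B: int, M: int, N: int, block_m: int, block_n: int) -> tuple:
    """Grid of the existing batched kernel: (B, M_tiles * N_tiles)."""
    return (B, math.ceil(M / block_m) * math.ceil(N / block_n))

def work_groups(grid: tuple) -> int:
    """Total programs launched for a Triton-style grid tuple."""
    return math.prod(grid)

# Decode case from the issue: B=2, M=4, N=1024, BLOCK_SIZE_N=128.
# block_m=16 is an illustrative assumption.
grid = existing_grid(B=2, M=4, N=1024, block_m=16, block_n=128)
wgs = work_groups(grid)
print(grid, wgs, f"{wgs / NUM_CUS:.1%} of CUs have a work-group")
```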
Proposal: batched_gemm_a8w8_smallB_blockscale_a_per_token_group_prequant_w_per_block_scale
A new aiter Triton kernel optimized for small B + medium-large K + medium N, using on-disk per-128-block W scale directly (no requant, no precision loss).
Key design choices
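- Collapse B into the grid's M dimension. A 1D grid of (B*ceil(M/BM)*ceil(N/BN),) with internal program-id decomposition lets one launch cover every (batch, M-tile, N-tile) work-group from a single flat index, which split-K can then multiply without a third grid axis. Combined with split-K, this removes the B=2 bottleneck.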
- Mandatory split-K. K=4096 with small M is bandwidth-bound; split-K=4 or 8 (with a small reduction kernel) gives 4-8x more work-groups. This is already implemented in gemm_a8w8_blockscale.py (single-GEMM), just not in any batched kernel.
- Native per-block W scale support. Eliminates the per-block→scalar collapse that current batched FP8 kernels force callers to perform at load time. BLOCK_SIZE_K is already fixed at 128 by the per-token-group X quant, so each K-tile maps to a single K-index of the per-block W scales and loads cleanly.
- Preserve fused per-token-group X act-quant. Inline quant of X removes the need for a separate act-quant kernel call from callers.
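One way the flat 1D grid and split-K could fit together is sketched below in plain Python, mirroring the index math a Triton kernel would do on tl.program_id(0). The names flat_grid_size, decompose, and k_tiles_for_split, the block_m=16 value, and the ordering (K-split fastest, then N-tile, M-tile, batch) are illustrative assumptions, not the layout of gemm_a8w8_blockscale.py or any other existing aiter kernel. Note that with split_k=1 the flat grid launches the same 16 programs as the existing (B, M_tiles * N_tiles) grid; the extra work-groups come from the split-K factor.

```python
import math

def flat_grid_size(B, M, N, block_m, block_n, split_k):
    """Programs in the 1D grid: one per (batch, M-tile, N-tile, K-split)."""
    return B * math.ceil(M / block_m) * math.ceil(N / block_n) * split_k

def decompose(pid, B, M, N, block_m, block_n, split_k):
    """Map a flat program id back to (b, pid_m, pid_n, pid_k).

    Illustrative ordering: K-split varies fastest, then N-tile, M-tile, batch.
    """
    num_m = math.ceil(M / block_m)
    num_n = math.ceil(N / block_n)
    pid_k = pid % split_k
    pid //= split_k
    pid_n = pid % num_n
    pid //= num_n
    pid_m = pid % num_m
    b = pid // num_m
    assert b < B, "pid out of range for this grid"
    return b, pid_m, pid_n, pid_k

def k_tiles_for_split(pid_k, K, split_k, block_k=128):
    """K-tile indices owned by one K-split; each tile uses one K-index of w_scale."""
    num_tiles = K // block_k
    per_split = math.ceil(num_tiles / split_k)
    start = pid_k * per_split
    return range(start, min(start + per_split, num_tiles))

# V4 wo_a decode shape: B=2, M=4, N=1024 (block_m=16 is assumed).
for split_k in (1, 4, 8):
    cfg = dict(B=2, M=4, N=1024, block_m=16, block_n=128, split_k=split_k)
    total = flat_grid_size(**cfg)
    coords = {decompose(pid, **cfg) for pid in range(total)}
    assert len(coords) == total  # every program gets a distinct (b, m, n, k) slot
    print(f"split_k={split_k}: {total} work-groups")
```

With split_k=4 the decode case goes from 16 to 64 work-groups and split_k=8 gives 128, the 4-8x increase described above. Each K-split produces a partial (B, M, N) result that the reduction kernel then sums.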
Signature sketch
```python
import torch

def batched_gemm_a8w8_smallB_blockscale_a_per_token_group_prequant_w_per_block_scale(
    X: torch.Tensor,          # (B, M, K) BF16, quantized to FP8 inline
    WQ: torch.Tensor,         # (B, N, K) FP8, pre-quantized at load
    w_scale: torch.Tensor,    # (B, N//128, K//128) FP32, per-128-block W scale
    group_size: int = 128,    # X act-quant block size on K (= W block size on K)
    block_n: int = 128,       # W block size on N
    transpose_bm: bool = False,
    transpose_bm_in: bool = False,
    splitK: int = 4,          # required for small-B / small-M regimes
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    ...  # proposed kernel, not implemented yet
```
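The signature implies a shape contract between X, WQ, and w_scale. Below is a hypothetical validation helper (check_smallb_shapes is not an existing aiter function) that a wrapper could run before launching. It works on plain shape tuples (torch.Size is a tuple subclass, so tensor.shape can be passed directly) and assumes N and K are exact multiples of the block sizes, as they are for V4 wo_a.

```python
def check_smallb_shapes(x_shape, wq_shape, w_scale_shape, group_size=128, block_n=128):
    """Validate X (B, M, K), WQ (B, N, K), w_scale (B, N // block_n, K // group_size).

    Hypothetical helper; assumes N % block_n == 0 and K % group_size == 0.
    Returns (B, M, N, K) on success, raises ValueError otherwise.
    """
    B, M, K = x_shape
    Bw, N, Kw = wq_shape
    if (Bw, Kw) != (B, K):
        raise ValueError(f"WQ {tuple(wq_shape)} does not match X {tuple(x_shape)} on B/K")
    if K % group_size or N % block_n:
        raise ValueError("K and N must be multiples of group_size and block_n")
    expected = (B, N // block_n, K // group_size)
    if tuple(w_scale_shape) != expected:
        raise ValueError(f"w_scale must be {expected}, got {tuple(w_scale_shape)}")
    return B, M, N, K

# V4 wo_a per-rank shapes from the issue (M=16 decode).
print(check_smallb_shapes((2, 16, 4096), (2, 1024, 4096), (2, 8, 32)))
```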
Expected benefits
- Wins or matches einsum at small-B FP8 BMM. Grid coverage + split-K should restore the CU utilization that the existing batched kernel sacrifices.
- Eliminates the precision loss of the one-time per-block→scalar W-scale requant that MLA-style kernels require.
- No load-time dequant + requant dance: process_weights_after_loading can be a no-op (LinearBase standard FP8 + shuffle handles the layer).
- Reusable beyond V4 — any model with grouped output projections (small B, medium-large K) benefits.
Implementation effort
- Triton kernel: ~150 LoC (start from batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant, swap the scalar w_scale load → per-block load, add 1D grid + split-K + reduction kernel).
- Wrapper + autotune config: ~50 LoC.
- Unit tests: ~80 LoC (compare vs reference einsum).
- Estimated 1-2 engineer-days for a working version, +1 day for autotune baseline.
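For the unit tests, a NumPy reference can check per-block scale indexing and split-K accumulation against a plain einsum. This is a sketch under stated assumptions: block_scales, dequant_w, and reference_bmm are hypothetical names; FP8 rounding is not emulated (wq stays float32), and the inline per-token-group X quant is omitted, so it validates scale layout and the reduction rather than quantization error. FP8_MAX=448 is the largest finite float8_e4m3fn value.

```python
import numpy as np

FP8_MAX = 448.0  # largest finite float8_e4m3fn value

def block_scales(w, block_n=128, block_k=128):
    """Per-(block_n x block_k) amax scales for W of shape (B, N, K)."""
    B, N, K = w.shape
    blocks = w.reshape(B, N // block_n, block_n, K // block_k, block_k)
    amax = np.abs(blocks).max(axis=(2, 4))
    return np.maximum(amax, 1e-12) / FP8_MAX  # (B, N // block_n, K // block_k)

def dequant_w(wq, w_scale, block_n=128, block_k=128):
    """Expand per-block scales to W's shape and multiply elementwise."""
    s = np.repeat(np.repeat(w_scale, block_n, axis=1), block_k, axis=2)
    return wq * s

def reference_bmm(x, wq, w_scale, split_k=1, block_k=128):
    """out[b, m, n] = sum_k x[b, m, k] * W[b, n, k], accumulated per K-split."""
    w = dequant_w(wq, w_scale, block_k=block_k)
    K = x.shape[-1]
    chunk = K // split_k
    partials = [
        np.einsum("bmk,bnk->bmn", x[..., i:i + chunk], w[..., i:i + chunk])
        for i in range(0, K, chunk)
    ]
    return np.sum(partials, axis=0)  # the split-K reduction step

rng = np.random.default_rng(0)
B, M, N, K = 2, 4, 256, 512  # small shapes for a fast test
x = rng.standard_normal((B, M, K)).astype(np.float32)
w = rng.standard_normal((B, N, K)).astype(np.float32)
w_scale = block_scales(w)
wq = w / dequant_w(np.ones_like(w), w_scale)  # divide by expanded scales: |wq| <= 448
out = reference_bmm(x, wq, w_scale, split_k=4)
expected = np.einsum("bmk,bnk->bmn", x, w)
print(np.abs(out - expected).max())
```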
Related
- … batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as a placeholder (kernel-level slower than einsum but unblocks memory savings); waiting on this aiter kernel for a real perf win.
- … gemm_a8w8_blockscale_preshuffle path (closed in favor of ROCm/ATOM #676 due to G+1 launch overhead being even worse at small M).
Reproduce
Bench scripts in ROCm/ATOM PR #676's branch git history (commits 26a83b4..761c69e on feat/deepseek-v4-wo-a-fp8-bmm):
- scripts/v4_wo_a_tune.py --microbench: einsum vs current FP8 BMM
- scripts/v4_wo_a_microbench_gloop.py: einsum vs G-loop blockscale
Happy to draft the kernel PR if there's appetite. I would appreciate guidance on the new file naming convention and on whether the split-K reduction should reuse the _triton_kernels infrastructure.
cc the V4 work in ROCm/ATOM #650 and #676.