[FP8 BMM] aiter Triton kernel for small-B regime (V4 wo_a / grouped LoRA) #3000

@zufayu

Motivation

The existing aiter Triton batched FP8 GEMM kernels are tuned for the MLA regime (B = num_heads ≈ 128). At small B (B=2, e.g. DeepSeek V4's grouped output LoRA on tp=8), they lose decisively to PyTorch's BF16 grouped einsum (which lowers to rocBLAS BMM).

Concrete case: DeepSeek V4 wo_a (output LoRA)

Shape per TP=8 rank: B = n_local_groups = 2, K = d_per_group = 4096, N = o_lora_rank = 1024, M = token batch (1-16 decode, up to ~1024 chunked prefill).

Single-op micro-bench on V4-Pro / 8x MI355X (gfx950):

| M | einsum (BF16, rocBLAS) | batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant (FP8) | G-loop gemm_a8w8_blockscale_preshuffle (FP8) |
|---|---|---|---|
| 1 | 15 μs | 35 μs (0.43x) | 116 μs (0.13x) |
| 16 | 14 μs | 35 μs (0.43x) | 60 μs (0.24x) |
| 64 | 19 μs | 37 μs (0.51x) | 62 μs (0.30x) |
| 256 | 19 μs | 36 μs (0.54x) | 58 μs (0.33x) |
| 1024 | 37 μs | 39 μs (0.96x) | 55 μs (0.67x) |

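The FP8 paths are slower than the BF16 baseline at every M tested (parenthesized values are speed relative to einsum), ranging from near parity (0.96x at M=1024) to more than 7x slower (0.13x at M=1), even though FP8 should theoretically halve the W bandwidth.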

Why existing kernels lose

  • batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant: Grid is (B, M_tiles * N_tiles). At B=2 and decode M=4 (1 M-tile) with default BLOCK_SIZE_N=128 (8 N-tiles), total work-groups = 16 — about 5% utilization on 304 CU. Tuning the JSON config DB (verified via sweeps + @triton.autotune) does not close the gap; the B=2 grid bottleneck is fundamental.
  • G-loop gemm_a8w8_blockscale_preshuffle: Standard 2D GEMM per group, fully utilizes the chip on each call, but G_local kernel launches per layer (plus per-group act_quant) inflate eager-mode latency unacceptably at small M.
  • rocBLAS BMM (what torch.einsum lowers to): Wins decisively at B=2 small-M via vendor-tuned single-launch coverage.
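
The work-group arithmetic in the first bullet can be checked with a short sketch. The helper name and the BLOCK_SIZE_M value of 16 are assumptions for illustration (any BLOCK_SIZE_M of at least 4 gives one M-tile at M=4); the 304-CU figure is the one quoted above, so substitute the CU count of the device being measured.

```python
from math import ceil

def workgroups_2d_batched(B, M, N, block_m, block_n):
    """Work-groups launched by a (B, M_tiles * N_tiles) grid."""
    m_tiles = ceil(M / block_m)
    n_tiles = ceil(N / block_n)
    return B * m_tiles * n_tiles

NUM_CUS = 304  # CU count quoted in the issue text; adjust for your device

# V4 wo_a decode case: B=2, M=4, N=1024, BLOCK_SIZE_N=128.
# block_m=16 is an assumed value, not taken from the kernel config.
wgs = workgroups_2d_batched(B=2, M=4, N=1024, block_m=16, block_n=128)
print(f"{wgs} work-groups, {wgs / NUM_CUS:.1%} of {NUM_CUS} CUs")
```

With one work-group per CU as the rough occupancy model, most of the chip sits idle regardless of how the per-tile config is tuned.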

Proposal: batched_gemm_a8w8_smallB_blockscale_a_per_token_group_prequant_w_per_block_scale

A new aiter Triton kernel optimized for small B + medium-large K + medium N, using on-disk per-128-block W scale directly (no requant, no precision loss).

Key design choices

  1. Collapse B into the grid's M dimension. Grid (B*ceil(M/BM)*ceil(N/BN),) with internal program-id decomposition. Lets the launcher dispatch all B * M_tiles * N_tiles work-groups in a single 1D launch, removing the B=2 bottleneck.
  2. Mandatory split-K. K=4096 with M small is bandwidth-bound; split-K=4 or 8 (with a small reduction kernel) gives 4-8x more work-groups. Already implemented in gemm_a8w8_blockscale.py (single-GEMM); just not in any batched kernel.
  3. Native per-block W scale support. Eliminates the per-block→scalar collapse that current batched FP8 kernels force callers to do at load time. BLOCK_SIZE_K = 128 is already constrained by per-token-group X quant, so per-block W scale loads cleanly.
  4. Preserve fused per-token-group X act-quant. Inline quant of X removes the need for a separate act-quant kernel call from callers.
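
Design choices 1 and 2 can be sketched together as plain index arithmetic. This is only an illustration: folding split-K into the same 1D grid, the decomposition order (K-split fastest, batch slowest), and the function names are all assumptions, not the layout of any existing aiter kernel.

```python
from math import ceil

def grid_size(B, M, N, block_m, block_n, split_k):
    """Total programs in a 1D grid that covers batch, M-tiles, N-tiles and K-splits."""
    return B * ceil(M / block_m) * ceil(N / block_n) * split_k

def decompose_pid(pid, M, N, block_m, block_n, split_k):
    """Map a flat program id back to (batch, m_tile, n_tile, k_split).

    The ordering here is an assumed one: k_split varies fastest, batch slowest.
    """
    m_tiles = ceil(M / block_m)
    n_tiles = ceil(N / block_n)
    pid_k = pid % split_k
    pid //= split_k
    pid_n = pid % n_tiles
    pid //= n_tiles
    pid_m = pid % m_tiles
    pid_b = pid // m_tiles
    return pid_b, pid_m, pid_n, pid_k

# V4 wo_a decode case with split-K=8; block_m=16 is an assumed tile size.
shape = dict(M=4, N=1024, block_m=16, block_n=128, split_k=8)
total = grid_size(B=2, **shape)
pids = [decompose_pid(pid, **shape) for pid in range(total)]
print(total, pids[0], pids[-1])
```

In a Triton kernel the same arithmetic would run on the value returned by `tl.program_id(0)`; each (batch, m_tile, n_tile) output tile then receives `split_k` partial results that a reduction step must sum.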

Signature sketch

batched_gemm_a8w8_smallB_blockscale_a_per_token_group_prequant_w_per_block_scale(
    X,               # (B, M, K) BF16 — quantized to FP8 inline
    WQ,              # (B, N, K) FP8 — pre-quantized at load
    w_scale,         # (B, N//128, K//128) FP32 — per-128-block W scale
    group_size=128,  # X act-quant block size on K (= W block size on K)
    block_n=128,     # W block size on N
    transpose_bm=False,
    transpose_bm_in=False,
    splitK=4,        # required for small-B / small-M regimes
    dtype=torch.bfloat16,
)
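
To make the intended numerics concrete, here is a pure-Python reference for a single output element under the scale layout in the signature sketch. It is a sketch under assumptions: plain floats stand in for FP8 values, X is taken as already quantized with one scale per (token, K-group), and the helper names are hypothetical.

```python
def expected_w_scale_shape(B, N, K, block_n=128, group_size=128):
    """Shape of w_scale implied by the signature sketch: (B, N//block_n, K//group_size)."""
    assert N % block_n == 0 and K % group_size == 0
    return (B, N // block_n, K // group_size)

def blockscale_dot(xq_row, x_scale_row, wq_row, w_scale_row, group_size, split_k=1):
    """One output element: sum over K-blocks of x_scale * w_scale * <xq, wq>.

    xq_row, wq_row: length-K lists of quantized values (floats stand in for FP8).
    x_scale_row: per-token-group X scales, one per K-block.
    w_scale_row: per-block W scales for this output column's N-block, one per K-block.
    split_k: number of K partitions; each yields a partial sum that is reduced at the end.
    """
    k_blocks = len(xq_row) // group_size
    assert k_blocks % split_k == 0
    blocks_per_split = k_blocks // split_k
    partials = []
    for s in range(split_k):
        acc = 0.0
        for kb in range(s * blocks_per_split, (s + 1) * blocks_per_split):
            lo, hi = kb * group_size, (kb + 1) * group_size
            raw = sum(x * w for x, w in zip(xq_row[lo:hi], wq_row[lo:hi]))
            acc += x_scale_row[kb] * w_scale_row[kb] * raw
        partials.append(acc)
    return sum(partials)  # stands in for the split-K reduction kernel

# Tiny example with group_size=2 so the blocks are easy to follow.
xq, wq = [1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 2.0, 2.0]
xs, ws = [0.5, 2.0], [2.0, 0.25]
full = blockscale_dot(xq, xs, wq, ws, group_size=2, split_k=1)
split = blockscale_dot(xq, xs, wq, ws, group_size=2, split_k=2)
print(full, split, expected_w_scale_shape(B=2, N=1024, K=4096))
```

Because the W scale is applied per K-block inside the accumulation, no per-batch scalar collapse is needed, and splitting K only changes the order in which block contributions are summed.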

Expected benefits

  • Wins or matches einsum at small-B FP8 BMM. Grid coverage + split-K should restore the CU utilization that the existing batched kernel sacrifices.
  • Eliminates one-time precision loss of per-block→scalar W-scale requant in MLA-style kernels.
  • No load-time dequant + requant dance: process_weights_after_loading can be a no-op (LinearBase standard FP8 + shuffle handles the layer).
  • Reusable beyond V4 — any model with grouped output projections (small B, medium-large K) benefits.

Implementation effort

  • Triton kernel: ~150 LoC (start from batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant, swap scalar w_scale load → per-block load, add 1D grid + split-K + reduction kernel).
  • Wrapper + autotune config: ~50 LoC.
  • Unit tests: ~80 LoC (compare vs reference einsum).
  • Estimated 1-2 engineer-days for a working version, +1 day for autotune baseline.

Related

Reproduce

Bench scripts in ROCm/ATOM PR #676's branch git history (commits 26a83b4..761c69e on feat/deepseek-v4-wo-a-fp8-bmm):

  • scripts/v4_wo_a_tune.py --microbench — einsum vs current FP8 BMM
  • scripts/v4_wo_a_microbench_gloop.py — einsum vs G-loop blockscale

Happy to draft the kernel PR if there's appetite — would appreciate guidance on the new file naming convention and whether split-K reduction should reuse _triton_kernels infrastructure.

cc the V4 work in ROCm/ATOM #650 and #676.
