
Commit 11f4b32

TimDettmers and claude committed
Optimize MMA kernel for small M: TILE_N=64 + multi-block-per-SM k_splits
For M<=16, the MMA kernel now uses TILE_N=64 (4 warps, 128 threads) instead of
TILE_N=128 (8 warps, 256 threads). This doubles n_tiles for better SM coverage.
Combined with an aggressive k_splits heuristic targeting 4 blocks per SM,
occupancy jumps from 8% to ~28%.

Key changes:
- Template kbit_gemm_prod on TILE_N_VAL (default 128, use 64 for M<=16)
- Derive NUM_WARPS and COLS_PER_WARP from TILE_N instead of hardcoding
- k_splits heuristic targets 4 blocks/SM for TILE_N=64 (128-thread blocks)
- Python workspace allocation uses TILE_N=64 worst case for tile_counters

Benchmark (ncu, RTX 4090, dense_down K=5120 N=2048):
- k=2 M=4: 10.34 us (vs scalar GEMV 17.82 us = 1.72x faster)
- k=4 M=3: 12.93 us (vs scalar GEMV 16.58 us = 1.28x faster)

Also attempted dequant-to-shmem (Phase 2) but reverted: serializing the full
dequant pass before MMA eliminates pipeline interleaving, resulting in a 2.6x
regression. Inline dequant is superior.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 05638f5 commit 11f4b32
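The heuristics described in the commit message can be sketched in Python. This is a hedged reconstruction, not the actual C++ launcher code: the names `warp_config` and `choose_k_splits`, the `cols_per_warp=16` value, and the split-K cap are illustrative assumptions, chosen so the numbers match the figures quoted above (TILE_N=128 → 8 warps/256 threads, TILE_N=64 → 4 warps/128 threads, target of 4 blocks per SM).

```python
# Illustrative sketch of the heuristics in the commit message. The real
# logic lives in the C++ kernel launcher; names and the cols_per_warp=16
# assumption are hypothetical.

def warp_config(tile_n: int, cols_per_warp: int = 16):
    """Derive warp/thread counts from TILE_N instead of hardcoding them."""
    num_warps = tile_n // cols_per_warp
    return num_warps, num_warps * 32  # 32 threads per warp

def choose_k_splits(m_tiles: int, n_tiles: int, num_sms: int,
                    max_k_splits: int = 16) -> int:
    """Split the K dimension until roughly 4 blocks land on every SM."""
    target_blocks = 4 * num_sms
    base_blocks = m_tiles * n_tiles
    if base_blocks >= target_blocks:
        return 1  # enough output tiles already; no split-K needed
    return min(-(-target_blocks // base_blocks), max_k_splits)  # ceil-divide

# TILE_N=128 -> 8 warps / 256 threads; TILE_N=64 -> 4 warps / 128 threads
print(warp_config(128), warp_config(64))
# M=4 (1 m-tile), N=2048 at TILE_N=64 (32 n-tiles), 128 SMs on an RTX 4090
print(choose_k_splits(1, 32, 128))  # -> 16 (hits the cap)
```

With only 32 output tiles on a 128-SM GPU, split-K is the only way to generate enough blocks, which is why the heuristic is aggressive for small M.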

File tree

2 files changed: +147, -156 lines

bitsandbytes/backends/cuda/ops.py

Lines changed: 4 additions & 5 deletions
@@ -1042,17 +1042,16 @@ def _(
     torch._check(B_packed.dtype == torch.int32, lambda: f"B_packed must be int32, got {B_packed.dtype}")
     torch._check(B_absmax.dtype == torch.uint8, lambda: f"B_absmax must be uint8 (E4M4), got {B_absmax.dtype}")
     torch._check(codebook.dtype == torch.float32, lambda: f"codebook must be float32, got {codebook.dtype}")
-    torch._check(N % 128 == 0, lambda: f"N ({N}) must be divisible by 128")
+    torch._check(N % 64 == 0, lambda: f"N ({N}) must be divisible by 64")
     torch._check(k_chunks >= 1, lambda: f"k_chunks must be >= 1, got {k_chunks}")

     M = A.shape[0]
     C = torch.empty(M, N, device=A.device, dtype=A.dtype)

-    # The persistent kernel auto-selects k_splits internally. When
-    # k_splits > 1, it needs a zeroed fp32 workspace and tile counters.
-    # Always allocate these since the C++ decides at runtime.
+    # The persistent kernel auto-selects k_splits and TILE_N internally.
+    # TILE_N=64 for M<=16 gives more tiles; allocate for worst case.
     TILE_M = 16
-    TILE_N = 128
+    TILE_N = 64  # worst case (most tiles)
     m_tiles = (M + TILE_M - 1) // TILE_M
     n_tiles = N // TILE_N
