
Commit 23f92e5

TimDettmers and claude committed
Grouped MMA: TILE_N=64 + k_splits for small-N MoE shapes
For moe_gu (K=2048, N=512), the old TILE_N=128 gave only 4 N-tiles per expert × 8 experts = 32 blocks, i.e. 25% SM utilization on a 128-SM GPU. Now uses TILE_N=64 (128 threads, 4 warps) when m_blocks == 1, doubling the number of N-tiles. Combined with automatic k_splits, which splits K into chunks when the MN-tiles alone are insufficient, this achieves full SM occupancy.

Results: moe_gu drops from a constant 26 us to 9-14 us (2.6-2.9x faster). Per-block total at k=4, M=5-8 improves from 1.03x to 1.24-1.36x vs fp16.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
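The occupancy arithmetic from the message can be sketched as a small illustration. The shape (N=512, 8 experts), tile sizes, and SM count are taken from the commit message; the helper `grid_blocks` is hypothetical, not a function in the codebase:

```python
def grid_blocks(n, tile_n, num_experts, m_tiles=1):
    """Blocks launched for one grouped GEMM: N-tiles x experts x M-tiles."""
    return (n // tile_n) * num_experts * m_tiles

# moe_gu shape from the commit message: N=512, 8 experts, one M-tile
old = grid_blocks(512, 128, 8)  # 4 N-tiles x 8 experts = 32 blocks
new = grid_blocks(512, 64, 8)   # 8 N-tiles x 8 experts = 64 blocks

sm_count = 128  # the 128-SM GPU mentioned in the message
print(old, old / sm_count)  # 32 blocks -> 0.25 occupancy
print(new, new / sm_count)  # 64 blocks -> 0.50; k_splits fills the rest
```

With TILE_N=64 alone the grid still covers only half the SMs, which is why the commit pairs it with split-K: splitting K across additional blocks multiplies the grid size until the machine is saturated.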
1 parent df55daf commit 23f92e5

File tree

3 files changed: +188 −90 lines


bitsandbytes/backends/cuda/ops.py

Lines changed: 20 additions & 1 deletion

@@ -1096,11 +1096,28 @@ def _(
     torch._check(B_absmax_all.dtype == torch.uint8, lambda: f"B_absmax must be uint8 (E4M4), got {B_absmax_all.dtype}")
     torch._check(codebook.dtype == torch.float32, lambda: f"codebook must be float32, got {codebook.dtype}")
     torch._check(expert_offsets.dtype == torch.int32, lambda: f"expert_offsets must be int32, got {expert_offsets.dtype}")
-    torch._check(N % 128 == 0, lambda: f"N ({N}) must be divisible by 128")
+    torch._check(N % 64 == 0, lambda: f"N ({N}) must be divisible by 64")

     total_M = A_concat.shape[0]
     C_concat = torch.empty(total_M, N, device=A_concat.device, dtype=A_concat.dtype)

+    # Workspace for split-K atomicAdd reduction (zeroed each call)
+    C_workspace = torch.zeros(total_M, N, device=A_concat.device, dtype=torch.float32)
+    # Tile counters for split-K last-block detection
+    # Upper bound: num_experts * max_m_tiles * max_n_tiles
+    m_blocks = 1
+    if max_M > 48:
+        m_blocks = 4
+    elif max_M > 32:
+        m_blocks = 3
+    elif max_M > 16:
+        m_blocks = 2
+    tile_n = 64 if (m_blocks == 1 and N % 64 == 0) else 128
+    n_tiles = N // tile_n
+    m_tiles = (max_M + m_blocks * 16 - 1) // (m_blocks * 16)
+    mn_tiles = num_experts * m_tiles * n_tiles
+    tile_counters = torch.zeros(mn_tiles, device=A_concat.device, dtype=torch.int32)
+
     dtype_suffix = "fp16" if A_concat.dtype == torch.float16 else "bf16"

     with _cuda_device_of(A_concat):
@@ -1111,6 +1128,8 @@ def _(
         get_ptr(B_absmax_all),
         get_ptr(codebook),
         get_ptr(C_concat),
+        get_ptr(C_workspace),
+        get_ptr(tile_counters),
         get_ptr(expert_offsets),
         ct.c_int(K_dim),
         ct.c_int(N),
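The tiling heuristic added in the diff can be isolated as a standalone sketch, using the same thresholds and the 16-rows-per-m-block granularity implied by the `m_blocks * 16` rounding. `pick_tiling` is a hypothetical extraction, not a function in ops.py, and the k_splits computation itself (done at launch time) is not shown:

```python
def pick_tiling(max_M: int, N: int, num_experts: int):
    """Choose m_blocks from the largest per-expert M, shrink TILE_N to 64
    only in the single-m-block case, and size the split-K tile-counter
    workspace as num_experts * m_tiles * n_tiles (an upper bound)."""
    if max_M > 48:
        m_blocks = 4
    elif max_M > 32:
        m_blocks = 3
    elif max_M > 16:
        m_blocks = 2
    else:
        m_blocks = 1
    tile_n = 64 if (m_blocks == 1 and N % 64 == 0) else 128
    n_tiles = N // tile_n
    m_tiles = (max_M + m_blocks * 16 - 1) // (m_blocks * 16)  # ceil over 16-row blocks
    mn_tiles = num_experts * m_tiles * n_tiles
    return m_blocks, tile_n, mn_tiles

# moe_gu-style shape: tiny M per expert, N=512, 8 experts
print(pick_tiling(8, 512, 8))   # (1, 64, 64): TILE_N=64 doubles the N-tiles
print(pick_tiling(40, 512, 8))  # (3, 128, 32): larger M keeps TILE_N=128
```

The `mn_tiles`-sized `tile_counters` buffer exists so that when K is split, each MN tile can count its finished K-chunks with an atomic increment; the block that observes the final count is the one that writes the reduced fp32 workspace back to `C_concat`.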
