Skip to content

Commit 6e76b52

Browse files
TimDettmers and claude
committed
Optimize SM100 MoE pipeline: remove .item() syncs, add weighted gather, vectorize bias
- Remove GPU-CPU .item() syncs from _forward_batched hot path: derive total_tokens from x.shape[0], cache max_M after first call - Add fused weighted gather path (cmoe_weighted_gather_bf16) with optional token_ids/gating_weights kwargs for top-k MoE routing - Pre-allocate persistent buffers for gather workspace and scale swizzle constants in _batched_cache - Vectorize bias addition: replace per-expert .item() loop with repeat_interleave in _forward_grouped, broadcast add in _forward_batched - Add TestWeightedGather test class (shape, correctness, bias coverage) All 26 tests pass on B200. Benchmarks show ~3x speedup over BF16 bmm at decode sizes with 128 experts. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 9863ff9 commit 6e76b52

File tree

2 files changed

+263
-52
lines changed

2 files changed

+263
-52
lines changed

bitsandbytes/nn/modules.py

Lines changed: 126 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -852,7 +852,15 @@ def _quantize_weights(self):
852852
requires_grad=False,
853853
)
854854

855-
def forward(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> torch.Tensor:
855+
def forward(
856+
self,
857+
x: torch.Tensor,
858+
expert_offsets: torch.Tensor,
859+
*,
860+
token_ids: Optional[torch.Tensor] = None,
861+
gating_weights: Optional[torch.Tensor] = None,
862+
num_dest_tokens: Optional[int] = None,
863+
) -> torch.Tensor:
856864
"""Run NVFP4 GEMM across all experts.
857865
858866
Uses batched GEMM on SM_100 (datacenter Blackwell) or grouped GEMM
@@ -864,17 +872,30 @@ def forward(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> torch.Tensor
864872
expert_offsets: Cumulative token offsets [num_experts + 1], int32.
865873
expert_offsets[i] is the starting token index for expert i.
866874
expert_offsets[-1] = total_tokens.
875+
token_ids: Optional mapping from assignment index to output token index
876+
[total_tokens] (int32). Required for weighted gather.
877+
gating_weights: Optional per-assignment gating weights [total_tokens] (float32).
878+
Required for weighted gather.
879+
num_dest_tokens: Number of unique destination tokens in the output.
880+
Required when token_ids and gating_weights are provided.
867881
868882
Returns:
869-
Output tensor [total_tokens, N] with expert results in the same token order.
883+
If token_ids and gating_weights are provided:
884+
Weighted output tensor [num_dest_tokens, N] with fused gather + weight + sum.
885+
Otherwise:
886+
Output tensor [total_tokens, N] with per-assignment expert results.
870887
"""
871888
if not self._quantized:
872889
self._quantize_weights()
873890

874891
major, _ = torch.cuda.get_device_capability(x.device)
875892
from bitsandbytes.cextension import lib
876893
if major == 10 and hasattr(lib, "cgemm_nvfp4_moe_sm100_init"):
877-
return self._forward_batched(x, expert_offsets)
894+
return self._forward_batched(
895+
x, expert_offsets,
896+
token_ids=token_ids, gating_weights=gating_weights,
897+
num_dest_tokens=num_dest_tokens,
898+
)
878899
return self._forward_grouped(x, expert_offsets)
879900

880901
def _forward_grouped(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> torch.Tensor:
@@ -899,27 +920,35 @@ def _forward_grouped(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
899920
)
900921

901922
if self.bias is not None:
902-
for i in range(self.num_experts):
903-
start = expert_offsets[i].item()
904-
end = expert_offsets[i + 1].item()
905-
if end > start:
906-
out[start:end] = out[start:end] + self.bias[i].to(out.dtype)
923+
expert_offsets_i32 = expert_offsets.to(torch.int32)
924+
tokens_per_expert = expert_offsets_i32[1:] - expert_offsets_i32[:-1]
925+
bias_expanded = torch.repeat_interleave(self.bias, tokens_per_expert, dim=0)
926+
out = out + bias_expanded.to(out.dtype)
907927

908928
return out.to(inp_dtype)
909929

910-
def _forward_batched(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> torch.Tensor:
930+
def _forward_batched(
931+
self,
932+
x: torch.Tensor,
933+
expert_offsets: torch.Tensor,
934+
*,
935+
token_ids: Optional[torch.Tensor] = None,
936+
gating_weights: Optional[torch.Tensor] = None,
937+
num_dest_tokens: Optional[int] = None,
938+
) -> torch.Tensor:
911939
"""Batched GEMM path (SM_100 datacenter Blackwell).
912940
913941
Pipeline with init/run split for CUDA graph compatibility:
914-
1. abs().max() — compute tensor scale (device-side)
915-
2. quantize_nvfp4_raw — quantize all tokens in one launch
916-
3. cmoe_scatter_nvfp4 — FP4 data → persistent padded buffer
942+
1. abs().max() — compute tensor scale (device-side)
943+
2. quantize_nvfp4_raw — quantize all tokens in one launch
944+
3. cmoe_scatter_nvfp4 — FP4 data → persistent padded buffer
917945
4. scale_to_blocked_batched — scales → persistent swizzled buffer
918-
5. batched GEMM run() — init-if-needed, then just run(stream)
919-
6. moe_gather_bf16 — padded per-expert → concatenated output
946+
5. batched GEMM run() — init-if-needed, then just run(stream)
947+
6. gather — weighted or unweighted depending on args
920948
921-
All persistent buffers (A, SFA, D, alpha) are cached in the module
922-
so their addresses are stable for the CUTLASS init/run split.
949+
All persistent buffers (A, SFA, D, alpha, gather workspace) are cached
950+
in the module so their addresses are stable for the CUTLASS init/run split.
951+
No .item() GPU-CPU sync on the common (decode) path.
923952
"""
924953
import ctypes as ct
925954

@@ -928,21 +957,29 @@ def _forward_batched(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
928957
from bitsandbytes.functional import (
929958
_get_tensor_stream,
930959
get_ptr,
931-
moe_gather_bf16,
932960
quantize_nvfp4_raw,
933-
scale_to_blocked_batched,
934961
)
935962

936963
inp_dtype = x.dtype
937964
N, K = self.output_features, self.input_features
938965
num_experts = self.num_experts
966+
total_tokens = x.shape[0] # CPU int, no GPU sync
967+
use_weighted = token_ids is not None and gating_weights is not None
968+
dev = x.device
939969

940970
expert_offsets_i32 = expert_offsets.to(torch.int32)
941971
tokens_per_expert = expert_offsets_i32[1:] - expert_offsets_i32[:-1]
942-
# .item() for shape computation — needed for tensor allocation
943-
raw_max_M = tokens_per_expert.max().item()
944-
max_M = ((raw_max_M + 127) // 128) * 128
945-
total_tokens = expert_offsets_i32[-1].item()
972+
973+
# Determine max_M without GPU sync on common path.
974+
# If cache exists and allocated_max_M >= total_tokens (upper bound on
975+
# any single expert's count), the buffers are guaranteed sufficient.
976+
if (hasattr(self, "_batched_cache")
977+
and total_tokens <= self._batched_cache.get("allocated_max_M", 0)):
978+
max_M = self._batched_cache["allocated_max_M"]
979+
else:
980+
# First call or total_tokens exceeds allocation: sync once
981+
raw_max_M = tokens_per_expert.max().item()
982+
max_M = ((raw_max_M + 127) // 128) * 128
946983

947984
x_2d = x.reshape(-1, K).to(torch.bfloat16).contiguous()
948985

@@ -956,7 +993,6 @@ def _forward_batched(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
956993
# 3. Ensure persistent cached buffers exist (stable pointers for init/run)
957994
cache_key = (max_M, N, K, num_experts)
958995
if not hasattr(self, "_batched_cache") or self._batched_cache.get("key") != cache_key:
959-
dev = x.device
960996
W = K // 16
961997
n_col_blocks = (W + 3) // 4
962998
n_row_blocks = (max_M + 127) // 128
@@ -965,13 +1001,32 @@ def _forward_batched(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
9651001

9661002
self._batched_cache = {
9671003
"key": cache_key,
1004+
"allocated_max_M": max_M,
9681005
"A_batched": torch.empty(num_experts * max_M * (K // 2), dtype=torch.uint8, device=dev),
9691006
"SFA_batched": torch.zeros(sfa_total, dtype=torch.uint8, device=dev),
9701007
"D_out": torch.empty(num_experts * max_M, N, dtype=torch.bfloat16, device=dev),
9711008
"alpha_dev": torch.empty(1, dtype=torch.float32, device=dev),
1009+
# Pre-computed constants for scale swizzle
1010+
"sfa_per_expert": sfa_per_expert,
1011+
"n_row_blocks": n_row_blocks,
1012+
"W": W,
1013+
"expert_out_offsets": torch.arange(
1014+
num_experts, dtype=torch.int32, device=dev,
1015+
) * sfa_per_expert,
9721016
}
9731017
cache = self._batched_cache
9741018

1019+
# Ensure weighted gather buffers exist if needed
1020+
if use_weighted and num_dest_tokens is not None:
1021+
if cache.get("gather_num_dest") != num_dest_tokens:
1022+
cache["gather_workspace"] = torch.empty(
1023+
num_dest_tokens * N, dtype=torch.float32, device=dev,
1024+
)
1025+
cache["gather_output"] = torch.empty(
1026+
num_dest_tokens, N, dtype=torch.bfloat16, device=dev,
1027+
)
1028+
cache["gather_num_dest"] = num_dest_tokens
1029+
9751030
stream = _get_tensor_stream(x_2d)
9761031

9771032
# 4. Scatter FP4 data into persistent buffer
@@ -986,29 +1041,16 @@ def _forward_batched(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
9861041
)
9871042

9881043
# 5. Swizzle scales per-expert into persistent buffer
989-
W = K // 16
990-
n_col_blocks = (W + 3) // 4
991-
n_row_blocks = (max_M + 127) // 128
992-
sfa_per_expert = n_row_blocks * n_col_blocks * 512
993-
sfa_total = num_experts * sfa_per_expert
994-
995-
expert_row_offsets = expert_offsets_i32[:-1]
996-
expert_M_dev = tokens_per_expert.to(torch.int32)
997-
expert_out_offsets = torch.arange(
998-
num_experts, dtype=torch.int32, device=x.device,
999-
) * sfa_per_expert
1000-
1001-
# Zero persistent SFA buffer, then swizzle into it
10021044
cache["SFA_batched"].zero_()
10031045
lib.cscale_to_blocked_batched(
10041046
get_ptr(scales_all),
10051047
get_ptr(cache["SFA_batched"]),
1006-
get_ptr(expert_row_offsets),
1007-
get_ptr(expert_M_dev),
1008-
get_ptr(expert_out_offsets),
1009-
ct.c_int(W),
1048+
get_ptr(expert_offsets_i32[:-1]),
1049+
get_ptr(tokens_per_expert),
1050+
get_ptr(cache["expert_out_offsets"]),
1051+
ct.c_int(cache["W"]),
10101052
ct.c_int(num_experts),
1011-
ct.c_int(n_row_blocks),
1053+
ct.c_int(cache["n_row_blocks"]),
10121054
stream,
10131055
)
10141056

@@ -1028,18 +1070,50 @@ def _forward_batched(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
10281070
max_M, N, K, num_experts,
10291071
)
10301072

1031-
# 8. Gather: padded per-expert BF16 → concatenated output
1032-
out = moe_gather_bf16(
1033-
cache["D_out"].view(-1), expert_offsets_i32, max_M, N, num_experts, total_tokens,
1034-
)
1035-
out = out.view(total_tokens, N)
1036-
1073+
# 8. Add bias to GEMM output (before gather, included in weighted sum)
10371074
if self.bias is not None:
1038-
for i in range(num_experts):
1039-
start = expert_offsets_i32[i].item()
1040-
end = expert_offsets_i32[i + 1].item()
1041-
if end > start:
1042-
out[start:end] = out[start:end] + self.bias[i].to(out.dtype)
1075+
D_out_3d = cache["D_out"].view(num_experts, max_M, N)
1076+
D_out_3d += self.bias.unsqueeze(1).to(D_out_3d.dtype)
1077+
1078+
# 9. Gather: padded per-expert → output
1079+
if use_weighted and num_dest_tokens is not None:
1080+
# Derive expert_ids and slot_ids from expert_offsets (all on GPU)
1081+
expert_ids = torch.repeat_interleave(
1082+
torch.arange(num_experts, device=dev, dtype=torch.int32),
1083+
tokens_per_expert,
1084+
)
1085+
starts_expanded = torch.repeat_interleave(
1086+
expert_offsets_i32[:-1], tokens_per_expert,
1087+
)
1088+
slot_ids = (
1089+
torch.arange(total_tokens, device=dev, dtype=torch.int32)
1090+
- starts_expanded
1091+
)
1092+
1093+
# Fused weighted gather: gather + weight + FP32 accumulate + BF16 convert
1094+
lib.cmoe_weighted_gather_bf16(
1095+
get_ptr(cache["D_out"]),
1096+
get_ptr(cache["gather_output"]),
1097+
get_ptr(cache["gather_workspace"]),
1098+
get_ptr(token_ids.to(torch.int32)),
1099+
get_ptr(expert_ids),
1100+
get_ptr(slot_ids),
1101+
get_ptr(gating_weights.to(torch.float32)),
1102+
ct.c_int(total_tokens),
1103+
ct.c_int(num_dest_tokens),
1104+
ct.c_int(max_M),
1105+
ct.c_int(N),
1106+
stream,
1107+
)
1108+
out = cache["gather_output"]
1109+
else:
1110+
# Unweighted gather (backwards compatible path)
1111+
from bitsandbytes.functional import moe_gather_bf16
1112+
out = moe_gather_bf16(
1113+
cache["D_out"].view(-1), expert_offsets_i32,
1114+
max_M, N, num_experts, total_tokens,
1115+
)
1116+
out = out.view(total_tokens, N)
10431117

10441118
return out.to(inp_dtype)
10451119

0 commit comments

Comments
 (0)