Skip to content

Commit 54c9204

Browse files
TimDettmers and claude committed
Rewrite _forward_batched with 6-kernel pipeline, zero compute-path sync
Replaces the old Python for-loop pipeline (~19 kernel launches + .item() sync) with a streamlined 6-kernel pipeline: 1. abs().max() — stays on GPU as device tensor 2. quantize_nvfp4_raw — all tokens in one launch 3. moe_scatter_nvfp4 — FP4 concat → padded per-expert 4. scale_to_blocked_batched — row-major → swizzled scales 5. gemm_nvfp4_moe — batched GEMM with device-side alpha 6. moe_gather_bf16 — padded per-expert → concat No .item() in the compute path. Shape-related .item() for max_M/total_tokens is kept (required for tensor allocation). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b0f1d30 commit 54c9204

File tree

1 file changed

+38
-38
lines changed

1 file changed

+38
-38
lines changed

bitsandbytes/nn/modules.py

Lines changed: 38 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -910,66 +910,66 @@ def _forward_grouped(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
910910
def _forward_batched(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> torch.Tensor:
911911
"""Batched GEMM path (SM_100 datacenter Blackwell).
912912
913-
Scatters tokens into padded (num_experts, max_M, K) layout, quantizes
914-
per-expert, runs a single batched GEMM kernel, then gathers results.
913+
6-kernel pipeline with zero .item() in the compute path:
914+
1. abs().max() — PyTorch reduction, stays on GPU
915+
2. quantize_nvfp4_raw — quantize all tokens in one launch
916+
3. moe_scatter_nvfp4 — FP4 concat → padded per-expert
917+
4. scale_to_blocked_batched — row-major → swizzled per-expert
918+
5. gemm_nvfp4_moe — batched GEMM with device-side alpha
919+
6. moe_gather_bf16 — padded per-expert → concat
915920
"""
916-
from bitsandbytes.functional import gemm_nvfp4_moe, quantize_nvfp4
921+
from bitsandbytes.functional import (
922+
gemm_nvfp4_moe,
923+
moe_gather_bf16,
924+
moe_scatter_nvfp4,
925+
quantize_nvfp4_raw,
926+
scale_to_blocked_batched,
927+
)
917928

918929
inp_dtype = x.dtype
919930
N, K = self.output_features, self.input_features
920931
num_experts = self.num_experts
921932

922933
expert_offsets_i32 = expert_offsets.to(torch.int32)
923934
tokens_per_expert = expert_offsets_i32[1:] - expert_offsets_i32[:-1]
935+
# .item() for shape computation — needed for tensor allocation
924936
raw_max_M = tokens_per_expert.max().item()
925-
# Pad to multiple of 128 for CUTLASS tile efficiency
926937
max_M = ((raw_max_M + 127) // 128) * 128
938+
total_tokens = expert_offsets_i32[-1].item()
927939

928940
x_2d = x.reshape(-1, K).to(torch.bfloat16).contiguous()
929941

930-
# Shared tensor scale across all experts (matches grouped GEMM behavior)
931-
act_tensor_scale = x_2d.abs().max().item()
932-
933-
# Quantize per-expert with shared tensor scale
934-
all_act_packed = []
935-
all_act_scales = []
936-
937-
for i in range(num_experts):
938-
start = expert_offsets_i32[i].item()
939-
end = expert_offsets_i32[i + 1].item()
940-
n_tok = end - start
941-
942-
act_padded = torch.zeros(max_M, K, dtype=torch.bfloat16, device=x.device)
943-
if n_tok > 0:
944-
act_padded[:n_tok] = x_2d[start:end]
942+
# 1. Compute tensor scale on GPU (no .item(), stays as device tensor)
943+
act_tensor_scale_dev = x_2d.abs().max()
944+
global_scale_dev = (1.0 / act_tensor_scale_dev).to(torch.float32)
945945

946-
act_packed, act_state = quantize_nvfp4(act_padded, tensor_scale=act_tensor_scale)
947-
all_act_packed.append(act_packed)
948-
all_act_scales.append(act_state.block_scales_blocked)
946+
# 2. Quantize ALL concatenated tokens in one launch
947+
packed_all, scales_all = quantize_nvfp4_raw(x_2d, global_scale_dev)
949948

950-
A_batched = torch.cat(all_act_packed)
951-
SFA_batched = torch.cat(all_act_scales)
949+
# 3. Scatter: FP4 data from concatenated to padded per-expert layout
950+
packed_batched = moe_scatter_nvfp4(
951+
packed_all, expert_offsets_i32, max_M, K, num_experts,
952+
)
952953

953-
# Run batched GEMM (alpha is a device tensor for graph safety)
954-
alpha_dev = torch.tensor(
955-
[act_tensor_scale * self.weight_tensor_scale],
956-
dtype=torch.float32, device=x.device,
954+
# 4. Swizzle scales: row-major → per-expert batched CUTLASS layout
955+
sfa_batched = scale_to_blocked_batched(
956+
scales_all, expert_offsets_i32, max_M, K, num_experts,
957957
)
958+
959+
# 5. Batched GEMM with device-side alpha (no .item() sync)
960+
alpha_dev = (act_tensor_scale_dev * self.weight_tensor_scale).to(torch.float32)
958961
D = gemm_nvfp4_moe(
959-
A_batched, SFA_batched, alpha_dev,
962+
packed_batched, sfa_batched, alpha_dev,
960963
self.weight_packed, self.weight_scales_batched,
961964
max_M, N, K, num_experts,
962965
)
963966

964-
# Gather results: D is (num_experts, max_M, N)
965-
total_tokens = expert_offsets_i32[-1].item()
966-
out = torch.empty(total_tokens, N, dtype=torch.bfloat16, device=x.device)
967-
for i in range(num_experts):
968-
start = expert_offsets_i32[i].item()
969-
end = expert_offsets_i32[i + 1].item()
970-
n_tok = end - start
971-
if n_tok > 0:
972-
out[start:end] = D[i, :n_tok]
967+
# 6. Gather: padded per-expert BF16 → concatenated output
968+
D_flat = D.view(-1).contiguous()
969+
out = moe_gather_bf16(
970+
D_flat, expert_offsets_i32, max_M, N, num_experts, total_tokens,
971+
)
972+
out = out.view(total_tokens, N)
973973

974974
if self.bias is not None:
975975
for i in range(num_experts):

0 commit comments

Comments (0)