Skip to content

Commit 90aecac

Browse files
TimDettmers and claude
committed
Update _forward_batched to use init/run split with persistent buffers
Replace the old 6-kernel pipeline with persistent buffer management for the CUTLASS init/run split. All buffers (A_batched, SFA_batched, D_out, alpha_dev) are cached in the module to ensure stable pointers across forward passes. Uses C functions directly for scatter and scale swizzle to write into pre-allocated buffers. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 55a71d1 commit 90aecac

File tree

1 file changed

+86
-23
lines changed

1 file changed

+86
-23
lines changed

bitsandbytes/nn/modules.py

Lines changed: 86 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -910,18 +910,25 @@ def _forward_grouped(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
910910
def _forward_batched(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> torch.Tensor:
911911
"""Batched GEMM path (SM_100 datacenter Blackwell).
912912
913-
6-kernel pipeline with zero .item() in the compute path:
914-
1. abs().max() — PyTorch reduction, stays on GPU
915-
2. quantize_nvfp4_raw — quantize all tokens in one launch
916-
3. moe_scatter_nvfp4 — FP4 concat → padded per-expert
917-
4. scale_to_blocked_batched — row-major → swizzled per-expert
918-
5. gemm_nvfp4_moe — batched GEMM with device-side alpha
919-
6. moe_gather_bf16 — padded per-expert → concat
913+
Pipeline with init/run split for CUDA graph compatibility:
914+
1. abs().max() — compute tensor scale (device-side)
915+
2. quantize_nvfp4_raw — quantize all tokens in one launch
916+
3. cmoe_scatter_nvfp4 — FP4 data → persistent padded buffer
917+
4. scale_to_blocked_batched — scales → persistent swizzled buffer
918+
5. batched GEMM run() — init-if-needed, then just run(stream)
919+
6. moe_gather_bf16 — padded per-expert → concatenated output
920+
921+
All persistent buffers (A, SFA, D, alpha) are cached in the module
922+
so their addresses are stable for the CUTLASS init/run split.
920923
"""
924+
import ctypes as ct
925+
926+
from bitsandbytes.backends.cuda.ops import _gemm_nvfp4_batched_moe_sm100_raw
927+
from bitsandbytes.cextension import lib
921928
from bitsandbytes.functional import (
922-
gemm_nvfp4_moe,
929+
_get_tensor_stream,
930+
get_ptr,
923931
moe_gather_bf16,
924-
moe_scatter_nvfp4,
925932
quantize_nvfp4_raw,
926933
scale_to_blocked_batched,
927934
)
@@ -946,28 +953,84 @@ def _forward_batched(self, x: torch.Tensor, expert_offsets: torch.Tensor) -> tor
946953
# 2. Quantize ALL concatenated tokens in one launch
947954
packed_all, scales_all = quantize_nvfp4_raw(x_2d, global_scale_dev)
948955

949-
# 3. Scatter: FP4 data from concatenated to padded per-expert layout
950-
packed_batched = moe_scatter_nvfp4(
951-
packed_all, expert_offsets_i32, max_M, K, num_experts,
956+
# 3. Ensure persistent cached buffers exist (stable pointers for init/run)
957+
cache_key = (max_M, N, K, num_experts)
958+
if not hasattr(self, "_batched_cache") or self._batched_cache.get("key") != cache_key:
959+
dev = x.device
960+
W = K // 16
961+
n_col_blocks = (W + 3) // 4
962+
n_row_blocks = (max_M + 127) // 128
963+
sfa_per_expert = n_row_blocks * n_col_blocks * 512
964+
sfa_total = num_experts * sfa_per_expert
965+
966+
self._batched_cache = {
967+
"key": cache_key,
968+
"A_batched": torch.empty(num_experts * max_M * (K // 2), dtype=torch.uint8, device=dev),
969+
"SFA_batched": torch.zeros(sfa_total, dtype=torch.uint8, device=dev),
970+
"D_out": torch.empty(num_experts * max_M, N, dtype=torch.bfloat16, device=dev),
971+
"alpha_dev": torch.empty(1, dtype=torch.float32, device=dev),
972+
}
973+
cache = self._batched_cache
974+
975+
stream = _get_tensor_stream(x_2d)
976+
977+
# 4. Scatter FP4 data into persistent buffer
978+
lib.cmoe_scatter_nvfp4(
979+
get_ptr(packed_all),
980+
get_ptr(cache["A_batched"]),
981+
get_ptr(expert_offsets_i32),
982+
ct.c_int(max_M),
983+
ct.c_int(K),
984+
ct.c_int(num_experts),
985+
stream,
952986
)
953987

954-
# 4. Swizzle scales: row-major → per-expert batched CUTLASS layout
955-
sfa_batched = scale_to_blocked_batched(
956-
scales_all, expert_offsets_i32, max_M, K, num_experts,
988+
# 5. Swizzle scales per-expert into persistent buffer
989+
W = K // 16
990+
n_col_blocks = (W + 3) // 4
991+
n_row_blocks = (max_M + 127) // 128
992+
sfa_per_expert = n_row_blocks * n_col_blocks * 512
993+
sfa_total = num_experts * sfa_per_expert
994+
995+
expert_row_offsets = expert_offsets_i32[:-1]
996+
expert_M_dev = tokens_per_expert.to(torch.int32)
997+
expert_out_offsets = torch.arange(
998+
num_experts, dtype=torch.int32, device=x.device,
999+
) * sfa_per_expert
1000+
1001+
# Zero persistent SFA buffer, then swizzle into it
1002+
cache["SFA_batched"].zero_()
1003+
lib.cscale_to_blocked_batched(
1004+
get_ptr(scales_all),
1005+
get_ptr(cache["SFA_batched"]),
1006+
get_ptr(expert_row_offsets),
1007+
get_ptr(expert_M_dev),
1008+
get_ptr(expert_out_offsets),
1009+
ct.c_int(W),
1010+
ct.c_int(num_experts),
1011+
ct.c_int(n_row_blocks),
1012+
stream,
9571013
)
9581014

959-
# 5. Batched GEMM with device-side alpha (no .item() sync)
960-
alpha_dev = (act_tensor_scale_dev * self.weight_tensor_scale).to(torch.float32)
961-
D = gemm_nvfp4_moe(
962-
packed_batched, sfa_batched, alpha_dev,
963-
self.weight_packed, self.weight_scales_batched,
1015+
# 6. Set alpha (device-side, no .item() sync)
1016+
cache["alpha_dev"].copy_(
1017+
(act_tensor_scale_dev * self.weight_tensor_scale).to(torch.float32).reshape(1)
1018+
)
1019+
1020+
# 7. Batched GEMM (init-if-needed, then just run(stream))
1021+
_gemm_nvfp4_batched_moe_sm100_raw(
1022+
cache["A_batched"],
1023+
self.weight_packed,
1024+
cache["SFA_batched"],
1025+
self.weight_scales_batched,
1026+
cache["D_out"],
1027+
cache["alpha_dev"],
9641028
max_M, N, K, num_experts,
9651029
)
9661030

967-
# 6. Gather: padded per-expert BF16 → concatenated output
968-
D_flat = D.view(-1).contiguous()
1031+
# 8. Gather: padded per-expert BF16 → concatenated output
9691032
out = moe_gather_bf16(
970-
D_flat, expert_offsets_i32, max_M, N, num_experts, total_tokens,
1033+
cache["D_out"].view(-1), expert_offsets_i32, max_M, N, num_experts, total_tokens,
9711034
)
9721035
out = out.view(total_tokens, N)
9731036

0 commit comments

Comments (0)