Skip to content

Commit ad65089

Browse files
TimDettmers and claude
committed
Add Python bindings and tests for batched NVFP4 MoE GEMM (SM120)
Wire up the CUTLASS batched MoE kernel (fixed-padding, init/run split) through the torch library op system. Replaces the old variable-offset grouped SM120 path with NotImplementedError directing to the new API. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ac36026 commit ad65089

File tree

4 files changed

+444
-43
lines changed

4 files changed

+444
-43
lines changed

bitsandbytes/_ops.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,3 +559,36 @@ def _(
559559
# total_tokens = number of rows in A_concat = A_concat.numel() / (K/2)
560560
total_tokens = A_concat.numel() // (K // 2)
561561
return torch.empty(total_tokens, N, dtype=torch.bfloat16, device=A_concat.device)
562+
563+
564+
# Batched NVFP4 GEMM for MoE inference (fixed-padding, CUDA-graph-compatible)
565+
# All experts compute max_M rows; padded rows produce ignored output.
566+
# A_batched: [num_experts * max_M * K/2] packed FP4 activations (flat)
567+
# B_all: [num_experts * N * K/2] packed FP4 weights (flat)
568+
# SFA_batched: swizzled activation scales (CUTLASS block-scaled layout)
569+
# SFB_all: swizzled weight scales (CUTLASS block-scaled layout)
570+
# alpha: [1] float32 device tensor = A_tensor_scale * B_tensor_scale
571+
torch.library.define(
572+
"bitsandbytes::gemm_nvfp4_batched_moe",
573+
"(Tensor A_batched, Tensor B_all, Tensor SFA_batched, Tensor SFB_all, "
574+
"Tensor alpha, int max_M, int N, int K, int num_experts) -> Tensor",
575+
)
576+
577+
578+
@register_fake("bitsandbytes::gemm_nvfp4_batched_moe")
def _(
    A_batched: torch.Tensor,
    B_all: torch.Tensor,
    SFA_batched: torch.Tensor,
    SFB_all: torch.Tensor,
    alpha: torch.Tensor,
    max_M: int,
    N: int,
    K: int,
    num_experts: int,
) -> torch.Tensor:
    """Fake (meta) impl: validate size args and report output shape/dtype only."""
    for size_arg in (max_M, N, K, num_experts):
        torch._check_is_size(size_arg)
    # Output covers all experts' padded rows: (num_experts * max_M, N) in BF16.
    return torch.empty(num_experts * max_M, N, dtype=torch.bfloat16, device=A_batched.device)

bitsandbytes/backends/cuda/ops.py

Lines changed: 115 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,43 +1111,100 @@ def _(
11111111
)
11121112

11131113

# Batched NVFP4 GEMM for MoE inference (SM_120, CUDA-graph-compatible)
#
# Fixed-padding approach: every expert computes max_M rows; the caller discards
# the output of padded rows. Uses the CUTLASS batched GEMM with an init/run
# split so that `run` can be captured into a CUDA graph.
#
# Cache: remembers the last init configuration. On a hit, init is skipped and
# only `run` is invoked; on a miss, init is re-issued.
_batched_moe_sm120_cache: Optional[dict] = None
_batched_moe_restype_set = False


def _ensure_batched_moe_restypes():
    """Declare ctypes return types for the SM120 batched-MoE C entry points (once)."""
    global _batched_moe_restype_set
    if _batched_moe_restype_set:
        return
    # Size queries return size_t; init/run return an int status code.
    for size_fn in (
        lib.cgemm_nvfp4_moe_sm120_sfa_size,
        lib.cgemm_nvfp4_moe_sm120_sfb_size,
        lib.cgemm_nvfp4_moe_sm120_workspace_size,
        lib.cgemm_nvfp4_moe_sm120_sfa_size_per_expert,
        lib.cgemm_nvfp4_moe_sm120_sfb_size_per_expert,
    ):
        size_fn.restype = ct.c_size_t
    lib.cgemm_nvfp4_moe_sm120_init.restype = ct.c_int
    lib.cgemm_nvfp4_moe_sm120_run.restype = ct.c_int
    _batched_moe_restype_set = True
1139+
def _batched_moe_sm120_init_if_needed(
    A_batched: torch.Tensor,
    B_all: torch.Tensor,
    SFA_batched: torch.Tensor,
    SFB_all: torch.Tensor,
    D_out: torch.Tensor,
    alpha: torch.Tensor,
    max_M: int,
    N: int,
    K: int,
    num_experts: int,
    stream: int,
) -> None:
    """Call cgemm_nvfp4_moe_sm120_init if the configuration changed, else skip.

    Raises:
        RuntimeError: if the C init entry point returns a nonzero status.

    The C-side init receives raw device pointers for every operand, and
    ``cgemm_nvfp4_moe_sm120_run`` takes only a stream — so run reuses whatever
    pointers were captured at init time. The cache key therefore MUST include
    the operand addresses, not just the problem shape: a shape-only hit would
    let run write to a stale D_out (the CUDA kernel wrapper allocates a fresh
    D_out on every call) and read stale inputs.
    """
    global _batched_moe_sm120_cache
    _ensure_batched_moe_restypes()

    # Key on both the problem shape and the device addresses captured by init.
    # NOTE(review): for CUDA-graph replay the caller should reuse static
    # buffers, which keeps the addresses (and thus the key) stable.
    cache_key = (
        N, K, max_M, num_experts,
        A_batched.data_ptr(), B_all.data_ptr(),
        SFA_batched.data_ptr(), SFB_all.data_ptr(),
        D_out.data_ptr(), alpha.data_ptr(),
    )

    if (_batched_moe_sm120_cache is not None
            and _batched_moe_sm120_cache["key"] == cache_key):
        return

    ws_size = lib.cgemm_nvfp4_moe_sm120_workspace_size(
        ct.c_int(N), ct.c_int(max_M), ct.c_int(K), ct.c_int(num_experts),
    )
    # Allocate at least 1 byte so get_ptr never sees an empty tensor.
    workspace = torch.empty(max(ws_size, 1), dtype=torch.uint8, device=A_batched.device)

    ret = lib.cgemm_nvfp4_moe_sm120_init(
        ct.c_int(N), ct.c_int(max_M), ct.c_int(K), ct.c_int(num_experts),
        get_ptr(A_batched), get_ptr(B_all),
        get_ptr(SFA_batched), get_ptr(SFB_all),
        get_ptr(D_out), get_ptr(alpha),
        get_ptr(workspace), ct.c_size_t(ws_size), stream,
    )
    if ret != 0:
        raise RuntimeError(f"cgemm_nvfp4_moe_sm120_init failed with code {ret}")

    _batched_moe_sm120_cache = {
        "key": cache_key,
        "workspace": workspace,  # keep the workspace alive while this config is cached
    }
1182+
1183+
def _gemm_nvfp4_batched_moe_raw(
    A_batched: torch.Tensor,
    B_all: torch.Tensor,
    SFA_batched: torch.Tensor,
    SFB_all: torch.Tensor,
    D_out: torch.Tensor,
    alpha: torch.Tensor,
    max_M: int,
    N: int,
    K: int,
    num_experts: int,
) -> None:
    """Raw batched MoE NVFP4 GEMM: init-if-needed, then run.

    All buffers must be pre-allocated. D_out must be BF16 of shape
    (num_experts * max_M, N). alpha must be a float32 device tensor of
    shape (1,) containing A_scale * B_scale.

    Raises:
        RuntimeError: if the C run entry point returns a nonzero status.
    """
    stream = _get_tensor_stream(A_batched)
    _batched_moe_sm120_init_if_needed(
        A_batched, B_all, SFA_batched, SFB_all, D_out, alpha,
        max_M, N, K, num_experts, stream,
    )
    run_status = lib.cgemm_nvfp4_moe_sm120_run(stream)
    if run_status != 0:
        raise RuntimeError(f"cgemm_nvfp4_moe_sm120_run failed with code {run_status}")

11521209

11531210
# Cached state for grouped SM_100 GEMM
@@ -1304,23 +1361,38 @@ def _(
13041361
A_tensor_scale, B_tensor_scale, N, K, num_experts,
13051362
)
13061363

1307-
# SM_120 (consumer Blackwell): use hand-written grouped kernel
1308-
# SM_120 expects globally-swizzled SFA, so swizzle the row-major input
1309-
total_tokens = A_concat.numel() // (K // 2)
1310-
scale_W = K // 16
1311-
SFA_blocked = torch.ops.bitsandbytes.scale_to_blocked(SFA_rowmajor, total_tokens, scale_W)
1364+
# SM_120 (consumer Blackwell): deprecated grouped path.
1365+
# Use gemm_nvfp4_batched_moe (fixed-padding) instead.
1366+
raise NotImplementedError(
1367+
"SM_120 grouped (variable-offset) NVFP4 MoE GEMM has been removed. "
1368+
"Use bitsandbytes::gemm_nvfp4_batched_moe with fixed-padding instead."
1369+
)
13121370

1313-
num_n_tiles = (N + 127) // 128
13141371

1315-
with _cuda_device_of(A_concat):
1316-
D_concat = torch.empty(total_tokens, N, dtype=torch.bfloat16, device=A_concat.device)
1317-
total_tiles = cumul_m_tiles[-1].item() * num_n_tiles
1372+
@register_kernel("bitsandbytes::gemm_nvfp4_batched_moe", "cuda")
def _(
    A_batched: torch.Tensor,
    B_all: torch.Tensor,
    SFA_batched: torch.Tensor,
    SFB_all: torch.Tensor,
    alpha: torch.Tensor,
    max_M: int,
    N: int,
    K: int,
    num_experts: int,
) -> torch.Tensor:
    """Batched NVFP4 GEMM for MoE: all experts compute max_M rows.

    A_batched: flat packed FP4 activations, (num_experts * max_M * K/2) bytes.
    B_all: flat packed FP4 weights, (num_experts * N * K/2) bytes.
    SFA_batched: pre-swizzled activation scales (CUTLASS block-scaled layout).
    SFB_all: pre-swizzled weight scales (CUTLASS block-scaled layout).
    alpha: float32 device tensor [1], = A_tensor_scale * B_tensor_scale.
    """
    out_rows = num_experts * max_M
    with _cuda_device_of(A_batched):
        # BF16 output for all (padded) rows; the caller slices away padding.
        D_out = torch.empty(out_rows, N, dtype=torch.bfloat16, device=A_batched.device)
        _gemm_nvfp4_batched_moe_raw(
            A_batched, B_all, SFA_batched, SFB_all, D_out, alpha,
            max_M, N, K, num_experts,
        )
    return D_out

bitsandbytes/functional.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,6 +1414,48 @@ def gemm_nvfp4_grouped(
14141414
)
14151415

14161416

1417+
def gemm_nvfp4_batched_moe(
    A_data: torch.Tensor,
    A_scales: torch.Tensor,
    A_tensor_scale: float,
    B_data_all: torch.Tensor,
    B_scales_all: torch.Tensor,
    B_tensor_scale: float,
    max_M: int,
    N: int,
    K: int,
    num_experts: int,
) -> torch.Tensor:
    """Batched NVFP4 GEMM for MoE with fixed padding (CUDA-graph-compatible).

    All experts compute max_M rows. Padded rows produce ignored output that the
    caller discards after gathering results.

    Args:
        A_data: Packed FP4 activations, (num_experts, max_M, K/2) flat uint8 tensor.
        A_scales: Pre-swizzled activation scales (CUTLASS block-scaled layout).
        A_tensor_scale: Tensor-level scale for activations (float).
        B_data_all: Packed FP4 weights, (num_experts, N, K/2) flat uint8 tensor.
        B_scales_all: Pre-swizzled weight scales (CUTLASS block-scaled layout).
        B_tensor_scale: Tensor-level scale for weights (float).
        max_M: Maximum tokens per expert (all experts compute this many rows).
        N: Output dimension per expert.
        K: Input/hidden dimension per expert.
        num_experts: Number of experts.

    Returns:
        BF16 output of shape (num_experts * max_M, N). Caller should reshape to
        (num_experts, max_M, N) and slice to actual token counts per expert.
    """
    # Combined tensor-level scale, as a 1-element float32 tensor on A's device.
    combined_scale = A_tensor_scale * B_tensor_scale
    alpha = torch.full((1,), combined_scale, dtype=torch.float32, device=A_data.device)
    return torch.ops.bitsandbytes.gemm_nvfp4_batched_moe(
        A_data, B_data_all, A_scales, B_scales_all, alpha,
        max_M, N, K, num_experts,
    )
1457+
1458+
14171459
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
14181460
def quantize(
14191461
A: Tensor,

0 commit comments

Comments
 (0)