Skip to content

Commit 55a71d1

Browse files
TimDettmers and claude committed
Add weighted gather op and restructure MoE GEMM Python interface
- Add moe_weighted_gather_bf16 torch op schema and kernel registration - Add _batched_moe_sm100_init_if_needed() with dimension-based caching - Add _gemm_nvfp4_batched_moe_sm100_raw() init+run wrapper - Update gemm_nvfp4_moe kernel to use new init signature with data ptrs Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 95ec556 commit 55a71d1

File tree

2 files changed

+138
-31
lines changed

2 files changed

+138
-31
lines changed

bitsandbytes/_ops.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,3 +672,30 @@ def _(
672672
torch._check_is_size(K)
673673
torch._check_is_size(num_experts)
674674
return torch.empty(num_experts, max_M, N, dtype=torch.bfloat16, device=A_batched.device)
675+
676+
677+
# MoE weighted gather: fused gather + scale by gating weight + FP32 accumulate + BF16 convert.
# Two-phase: atomicAdd into FP32 workspace, then convert to BF16.
# workspace_fp32 is a caller-managed scratch buffer (persistent for CUDA graphs).
# The op writes its result into the caller-provided `output_bf16` and returns that
# same tensor (see the fake and CUDA implementations).
torch.library.define(
    "bitsandbytes::moe_weighted_gather_bf16",
    "(Tensor D_batched, Tensor output_bf16, Tensor workspace_fp32, "
    "Tensor token_ids, Tensor expert_ids, Tensor slot_ids, Tensor weights, "
    "int num_tokens, int max_M, int N) -> Tensor",
)
686+
687+
688+
@register_fake("bitsandbytes::moe_weighted_gather_bf16")
def _(
    D_batched: torch.Tensor,
    output_bf16: torch.Tensor,
    workspace_fp32: torch.Tensor,
    token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    slot_ids: torch.Tensor,
    weights: torch.Tensor,
    num_tokens: int,
    max_M: int,
    N: int,
) -> torch.Tensor:
    # Fake (meta) implementation: the real kernel writes in place into
    # `output_bf16` and returns it, so shape/dtype inference simply passes
    # the caller's buffer through unchanged.
    return output_bf16

bitsandbytes/backends/cuda/ops.py

Lines changed: 111 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,7 +1330,7 @@ def _(
13301330

13311331
# Cached state for batched SM_100 MoE GEMM.
# _moe_batched_restype_set: set to True once _ensure_moe_batched_restype() has run.
# _moe_batched_sm100_cache: last init configuration ("key") plus its workspace
#   buffer, kept alive here so the C-side state's workspace is not freed;
#   None until the first init.
_moe_batched_restype_set = False
_moe_batched_sm100_cache: Optional[dict] = None
13341334

13351335

13361336
def _ensure_moe_batched_restype():
@@ -1346,50 +1346,130 @@ def _ensure_moe_batched_restype():
13461346
_moe_batched_restype_set = True
13471347

13481348

def _batched_moe_sm100_init_if_needed(
    A_batched: torch.Tensor,
    B_all: torch.Tensor,
    SFA_batched: torch.Tensor,
    SFB_all: torch.Tensor,
    D_out: torch.Tensor,
    alpha: torch.Tensor,
    max_M: int,
    N: int,
    K: int,
    num_experts: int,
    stream: int,
) -> None:
    """Call cgemm_nvfp4_moe_sm100_init if the configuration changed, else skip.

    The C-side init captures the device pointers of every buffer passed here,
    and ``cgemm_nvfp4_moe_sm100_run`` later receives only a stream — so a cached
    init is valid only while *both* the problem dimensions and the buffer
    addresses are unchanged. The cache key therefore includes the data pointers:
    keying on dimensions alone would silently reuse stale addresses whenever a
    caller allocates fresh buffers of the same shape (as the gemm_nvfp4_moe
    kernel does for D_out on every call).

    Args:
        A_batched, B_all, SFA_batched, SFB_all: packed NVFP4 operands and their
            scale factors (opaque to this wrapper; passed to C by pointer).
        D_out: pre-allocated BF16 output buffer.
        alpha: float32 device scalar tensor.
        max_M, N, K, num_experts: problem dimensions.
        stream: CUDA stream handle forwarded to the C init.

    Raises:
        RuntimeError: if cgemm_nvfp4_moe_sm100_init returns a nonzero status.
    """
    global _moe_batched_sm100_cache
    _ensure_moe_batched_restype()

    # Pointers are part of the key: init bakes these addresses into its
    # internal state, and run() has no way to refresh them afterwards.
    cache_key = (
        N, K, max_M, num_experts,
        A_batched.data_ptr(), B_all.data_ptr(),
        SFA_batched.data_ptr(), SFB_all.data_ptr(),
        D_out.data_ptr(), alpha.data_ptr(),
    )

    if (_moe_batched_sm100_cache is not None
            and _moe_batched_sm100_cache["key"] == cache_key):
        return

    ws_size = lib.cgemm_nvfp4_moe_sm100_workspace_size(
        ct.c_int(N), ct.c_int(max_M), ct.c_int(K), ct.c_int(num_experts),
    )
    # max(ws_size, 1) guarantees a real allocation (and data pointer) even
    # when the reported workspace size is zero.
    workspace = torch.empty(max(ws_size, 1), dtype=torch.uint8, device=A_batched.device)

    ret = lib.cgemm_nvfp4_moe_sm100_init(
        ct.c_int(N), ct.c_int(max_M), ct.c_int(K), ct.c_int(num_experts),
        get_ptr(A_batched), get_ptr(B_all),
        get_ptr(SFA_batched), get_ptr(SFB_all),
        get_ptr(D_out), get_ptr(alpha),
        get_ptr(workspace), ct.c_size_t(ws_size), stream,
    )
    if ret != 0:
        raise RuntimeError(f"cgemm_nvfp4_moe_sm100_init failed with code {ret}")

    _moe_batched_sm100_cache = {
        "key": cache_key,
        "workspace": workspace,  # prevent GC while the C state references it
    }
13821391

1383-
D_out = torch.empty(num_experts * max_M * N, dtype=torch.bfloat16, device=A_batched.device)
13841392

1385-
ret = lib.cgemm_nvfp4_moe_sm100_run(
1386-
get_ptr(A_batched), get_ptr(B_batched),
1387-
get_ptr(SFA), get_ptr(SFB),
1388-
get_ptr(D_out),
1389-
get_ptr(alpha_dev),
1390-
_get_tensor_stream(A_batched),
1393+
def _gemm_nvfp4_batched_moe_sm100_raw(
1394+
A_batched: torch.Tensor,
1395+
B_all: torch.Tensor,
1396+
SFA_batched: torch.Tensor,
1397+
SFB_all: torch.Tensor,
1398+
D_out: torch.Tensor,
1399+
alpha: torch.Tensor,
1400+
max_M: int,
1401+
N: int,
1402+
K: int,
1403+
num_experts: int,
1404+
) -> None:
1405+
"""Raw batched MoE NVFP4 GEMM — init-if-needed then run.
1406+
1407+
All buffers must be pre-allocated. D_out must be BF16 of shape (num_experts * max_M, N).
1408+
alpha must be a float32 device tensor of shape (1,) containing A_scale * B_scale.
1409+
"""
1410+
stream = _get_tensor_stream(A_batched)
1411+
_batched_moe_sm100_init_if_needed(
1412+
A_batched, B_all, SFA_batched, SFB_all, D_out, alpha,
1413+
max_M, N, K, num_experts, stream,
13911414
)
1415+
ret = lib.cgemm_nvfp4_moe_sm100_run(stream)
13921416
if ret != 0:
1393-
raise RuntimeError(f"cgemm_nvfp4_moe_sm100_run failed: {ret}")
1417+
raise RuntimeError(f"cgemm_nvfp4_moe_sm100_run failed with code {ret}")
1418+
13941419

@register_kernel("bitsandbytes::gemm_nvfp4_moe", "cuda")
def _(
    A_batched: torch.Tensor,
    B_batched: torch.Tensor,
    SFA: torch.Tensor,
    SFB: torch.Tensor,
    alpha: torch.Tensor,
    max_M: int,
    N: int,
    K: int,
    num_experts: int,
) -> torch.Tensor:
    """CUDA implementation of gemm_nvfp4_moe: allocate the BF16 output and run
    the raw batched SM_100 MoE GEMM, returning a (num_experts, max_M, N) view.
    """
    with _cuda_device_of(A_batched):
        # NOTE(review): D_out is freshly allocated on every call; correctness
        # relies on _batched_moe_sm100_init_if_needed re-initializing the
        # C-side state whenever this buffer's address changes — confirm its
        # cache key accounts for buffer pointers, not just dimensions.
        D_out = torch.empty(num_experts * max_M, N, dtype=torch.bfloat16, device=A_batched.device)
        _gemm_nvfp4_batched_moe_sm100_raw(
            A_batched, B_batched, SFA, SFB, D_out, alpha,
            max_M, N, K, num_experts,
        )
    return D_out.view(num_experts, max_M, N)
1439+
1440+
1441+
@register_kernel("bitsandbytes::moe_weighted_gather_bf16", "cuda")
def _(
    D_batched: torch.Tensor,
    output_bf16: torch.Tensor,
    workspace_fp32: torch.Tensor,
    token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    slot_ids: torch.Tensor,
    weights: torch.Tensor,
    num_tokens: int,
    max_M: int,
    N: int,
) -> torch.Tensor:
    """Fused gather + weight + FP32 accumulate + BF16 convert.

    Internally launches: memset(workspace) -> atomicAdd gather -> FP32->BF16 convert.
    All three operations on the same stream, capturable in a CUDA graph.
    """
    # One entry per routed (token, expert) assignment; token_ids, expert_ids,
    # slot_ids and weights are presumably parallel arrays of this length —
    # TODO confirm against the C kernel's contract.
    total_assignments = token_ids.shape[0]
    with _cuda_device_of(D_batched):
        # NOTE(review): unlike the GEMM init/run wrappers above, the return
        # value of this C call (if any) is not checked — confirm
        # cmoe_weighted_gather_bf16's status/restype convention.
        lib.cmoe_weighted_gather_bf16(
            get_ptr(D_batched),
            get_ptr(output_bf16),
            get_ptr(workspace_fp32),
            get_ptr(token_ids),
            get_ptr(expert_ids),
            get_ptr(slot_ids),
            get_ptr(weights),
            ct.c_int(total_assignments),
            ct.c_int(num_tokens),
            ct.c_int(max_M),
            ct.c_int(N),
            _get_tensor_stream(D_batched),
        )
    # The C kernel writes the final BF16 result into output_bf16 in place.
    return output_bf16

0 commit comments

Comments
 (0)