Skip to content

Commit cec86d7

Browse files
TimDettmers and claude
committed
Remove dead SM_100 grouped GEMM code, superseded by batched MoE GEMM
The grouped GEMM SM_100 path was never reachable in practice — _forward_batched routes SM_100 to the batched kernel, and the fused variant had linking issues. Removes the C source, CMake entry, and Python wiring (cache, dispatch, SM_100 branch). SM_120 grouped kernel path is untouched. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c872065 commit cec86d7

3 files changed

Lines changed: 0 additions & 706 deletions

File tree

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -300,7 +300,6 @@ if(BUILD_CUDA)
300300
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.8" AND EXISTS "${CMAKE_SOURCE_DIR}/third_party/cutlass/include")
301301
set(_NVFP4_SM100_SOURCES
302302
csrc/qutlass/gemm_nvfp4_sm100.cu
303-
csrc/qutlass/gemm_nvfp4_grouped_sm100.cu
304303
csrc/qutlass/gemm_nvfp4_moe_sm100.cu
305304
)
306305

bitsandbytes/backends/cuda/ops.py

Lines changed: 0 additions & 134 deletions
Original file line number | Diff line number | Diff line change
@@ -1150,129 +1150,6 @@ def _gemm_nvfp4_grouped_raw(
11501150
)
11511151

11521152

1153-
# Cached state for grouped SM_100 GEMM
1154-
_grouped_restype_set = False
1155-
1156-
# Cached buffers for the fused C dispatch (keyed by (N, K, num_experts),
1157-
# sized for worst-case routing so the cache always hits after first call)
1158-
_grouped_fused_cache: Optional[dict] = None
1159-
1160-
1161-
def _get_fused_buffers(
1162-
total_tokens: int, N: int, K: int, num_experts: int, device: torch.device,
1163-
) -> dict:
1164-
"""Get or grow cached device buffers for the fused C dispatch.
1165-
1166-
Buffers are sized for worst-case token routing (all tokens to one expert),
1167-
keyed on (N, K, num_experts). Grows if total_tokens exceeds the cached size.
1168-
"""
1169-
global _grouped_fused_cache, _grouped_restype_set
1170-
1171-
if not _grouped_restype_set:
1172-
lib.cgemm_nvfp4_grouped_sm100_meta_size.restype = ct.c_size_t
1173-
lib.cgemm_nvfp4_grouped_sm100_workspace_size.restype = ct.c_size_t
1174-
_grouped_restype_set = True
1175-
1176-
if (_grouped_fused_cache is not None
1177-
and _grouped_fused_cache["N"] == N
1178-
and _grouped_fused_cache["K"] == K
1179-
and _grouped_fused_cache["num_experts"] == num_experts
1180-
and _grouped_fused_cache["max_tokens"] >= total_tokens):
1181-
return _grouped_fused_cache
1182-
1183-
scale_W = K // 16
1184-
n_col_blocks = (scale_W + 3) // 4
1185-
1186-
# Worst-case SFA output: each expert adds at most 1 extra 128-row block
1187-
max_row_blocks = (total_tokens + 127) // 128 + num_experts
1188-
sfa_out_bytes = max_row_blocks * n_col_blocks * 512
1189-
1190-
sfa_swizzle_out = torch.empty(max(sfa_out_bytes, 1), dtype=torch.uint8, device=device)
1191-
sfa_swizzle_meta = torch.empty(3 * num_experts * 4, dtype=torch.uint8, device=device)
1192-
1193-
meta_size = lib.cgemm_nvfp4_grouped_sm100_meta_size(ct.c_int(num_experts))
1194-
gemm_meta_buf = torch.empty(meta_size, dtype=torch.uint8, device=device)
1195-
1196-
# Worst-case workspace: all tokens routed to a single expert
1197-
M_arr = (ct.c_int * num_experts)(*([0] * num_experts))
1198-
M_arr[0] = total_tokens
1199-
ws_size = lib.cgemm_nvfp4_grouped_sm100_workspace_size(
1200-
M_arr, ct.c_int(N), ct.c_int(K), ct.c_int(num_experts),
1201-
)
1202-
workspace_buf = torch.empty(max(ws_size, 1), dtype=torch.uint8, device=device)
1203-
1204-
_grouped_fused_cache = {
1205-
"N": N, "K": K, "num_experts": num_experts, "max_tokens": total_tokens,
1206-
"sfa_swizzle_out": sfa_swizzle_out,
1207-
"sfa_swizzle_meta": sfa_swizzle_meta,
1208-
"gemm_meta_buf": gemm_meta_buf,
1209-
"workspace_buf": workspace_buf,
1210-
"ws_size": ws_size,
1211-
}
1212-
return _grouped_fused_cache
1213-
1214-
1215-
def _gemm_nvfp4_grouped_sm100(
1216-
A_concat: torch.Tensor,
1217-
B_all: torch.Tensor,
1218-
SFA_rowmajor: torch.Tensor,
1219-
SFB_all: torch.Tensor,
1220-
offsets_host: tuple[int, ...],
1221-
A_tensor_scale: float,
1222-
B_tensor_scale: float,
1223-
N: int,
1224-
K: int,
1225-
num_experts: int,
1226-
) -> torch.Tensor:
1227-
"""SM_100 grouped NVFP4 GEMM using fused C dispatch.
1228-
1229-
Single ctypes call handles SFA swizzle + CUTLASS grouped GEMM.
1230-
All metadata computation and pointer building happens in C.
1231-
Python only allocates output and passes pre-cached buffers.
1232-
1233-
SFB_all is already per-expert swizzled (each expert was independently
1234-
quantized by quantize_nvfp4, which swizzles each expert's scales
1235-
separately). No conversion needed.
1236-
1237-
offsets_host: host-side tuple of cumulative token offsets (num_experts + 1 ints).
1238-
"""
1239-
device = A_concat.device
1240-
total_tokens = offsets_host[-1]
1241-
1242-
# Get or grow cached buffers (keyed on N, K, num_experts — always hits
1243-
# after first call unless total_tokens grows)
1244-
buf = _get_fused_buffers(total_tokens, N, K, num_experts, device)
1245-
1246-
# Output (BF16 — CUTLASS accumulates in FP32, epilogue outputs BF16)
1247-
D_concat = torch.empty(total_tokens, N, dtype=torch.bfloat16, device=device)
1248-
1249-
# Build host offsets ctypes array (per-call, ~1μs for 9 ints)
1250-
host_offsets_arr = (ct.c_int * (num_experts + 1))(*offsets_host)
1251-
1252-
# Single fused C call: SFA swizzle + metadata build + GEMM launch
1253-
# SFB_all is passed directly — already per-expert swizzled from quantize_nvfp4
1254-
lib.cgemm_nvfp4_grouped_sm100_fused(
1255-
get_ptr(A_concat),
1256-
get_ptr(B_all),
1257-
get_ptr(SFA_rowmajor),
1258-
get_ptr(SFB_all),
1259-
get_ptr(D_concat),
1260-
host_offsets_arr,
1261-
ct.c_int(N),
1262-
ct.c_int(K),
1263-
ct.c_int(num_experts),
1264-
ct.c_float(A_tensor_scale * B_tensor_scale),
1265-
get_ptr(buf["sfa_swizzle_out"]),
1266-
get_ptr(buf["sfa_swizzle_meta"]),
1267-
get_ptr(buf["gemm_meta_buf"]),
1268-
get_ptr(buf["workspace_buf"]),
1269-
ct.c_size_t(buf["ws_size"]),
1270-
_get_tensor_stream(A_concat),
1271-
)
1272-
1273-
return D_concat
1274-
1275-
12761153
@register_kernel("bitsandbytes::gemm_nvfp4_grouped", "cuda")
12771154
def _(
12781155
A_concat: torch.Tensor,
@@ -1293,17 +1170,6 @@ def _(
12931170
SFB_all: per-expert swizzled weight scales (each expert independently swizzled
12941171
by quantize_nvfp4, then concatenated).
12951172
"""
1296-
# SM_100 (datacenter Blackwell): use CUTLASS grouped GEMM
1297-
major, _ = torch.cuda.get_device_capability(A_concat.device)
1298-
if major == 10 and hasattr(lib, "cgemm_nvfp4_grouped_cutlass_sm100"):
1299-
# Convert device offsets to host tuple (cheap for small arrays,
1300-
# but callers should migrate to passing host offsets directly)
1301-
offsets_host = tuple(expert_offsets.tolist())
1302-
return _gemm_nvfp4_grouped_sm100(
1303-
A_concat, B_all, SFA_rowmajor, SFB_all, offsets_host,
1304-
A_tensor_scale, B_tensor_scale, N, K, num_experts,
1305-
)
1306-
13071173
# SM_120 (consumer Blackwell): use hand-written grouped kernel
13081174
# SM_120 expects globally-swizzled SFA, so swizzle the row-major input
13091175
total_tokens = A_concat.numel() // (K // 2)

0 commit comments

Comments (0)