Commit 30138f8

TimDettmers and claude committed

feat: Add M-based dispatch for NVFP4 GEMM (hand-written for M<64)

Dispatches gemm_nvfp4 between:
- M < 64: hand-written kernel (mma.sync, auto split-K, BF16 output)
- M >= 64: CUTLASS SM_120 GEMM (BF16 output)

The hand-written kernel uses flat row-major scales and doesn't fold tensor scales into the epilogue, so they're applied after the GEMM.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

1 parent 52fe6af · commit 30138f8
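The dispatch rule described in the commit message can be sketched in plain Python. This is a hypothetical illustration: only `_GEMM_HW_M_THRESHOLD` comes from the diff below; `pick_kernel`, its `has_hw_gemm` flag, and the return strings are illustrative stand-ins for the real branch inside `gemm_nvfp4`.

```python
# Hypothetical sketch of the M-based dispatch this commit adds.
# Only _GEMM_HW_M_THRESHOLD comes from the diff; pick_kernel and the
# availability flag are illustrative stand-ins for the real dispatch
# inside gemm_nvfp4.
_GEMM_HW_M_THRESHOLD = 64

def pick_kernel(m: int, has_hw_gemm: bool = True) -> str:
    """Return which kernel the dispatcher would select for a given M."""
    if m < _GEMM_HW_M_THRESHOLD and has_hw_gemm:
        return "hand-written"  # mma.sync + auto split-K: decode-sized M
    return "cutlass"           # CUTLASS SM_120 GEMM: prefill / large batches

print(pick_kernel(1))    # decode step (M=1) -> hand-written
print(pick_kernel(128))  # large batch      -> cutlass
```

The threshold reflects that small-M (decode) shapes underutilize a tile-sized CUTLASS kernel, while large-M shapes benefit from it.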

File tree

1 file changed: +34 -1 lines changed

bitsandbytes/functional.py

Lines changed: 34 additions & 1 deletion

```diff
@@ -1247,6 +1247,15 @@ def dequantize_nvfp4(
     return out.reshape(quant_state.shape)
 
 
+# Dispatch threshold: use hand-written GEMM for small M (decode), CUTLASS for large M
+_GEMM_HW_M_THRESHOLD = 64
+
+
+def _has_hw_gemm() -> bool:
+    """Check if hand-written NVFP4 GEMM is available (SM_120+ builds only)."""
+    return hasattr(lib, "cgemm_nvfp4_bf16")
+
+
 def gemm_nvfp4(
     A_data: torch.Tensor,
     A_state: NVFP4QuantState,
@@ -1255,6 +1264,10 @@ def gemm_nvfp4(
 ) -> torch.Tensor:
     """NVFP4 GEMM: compute A @ B^T using block-scaled FP4 inputs.
 
+    Dispatches between two kernels based on M:
+    - M < 64: hand-written kernel (mma.sync + auto split-K, BF16 output)
+    - M >= 64: CUTLASS SM_120 GEMM (BF16 output)
+
     Args:
         A_data: Packed FP4 data for A (M*K/2 bytes).
         A_state: Quantization state for A (M x K).
@@ -1268,7 +1281,27 @@ def gemm_nvfp4(
     K = A_state.shape[1]
     N = B_state.shape[0]
 
-    # Use pre-swizzled scales for CUTLASS GEMM (computed at quantization time)
+    if M < _GEMM_HW_M_THRESHOLD and _has_hw_gemm() and A_data.is_cuda:
+        # Hand-written kernel: flat (non-swizzled) scales, BF16 output
+        from bitsandbytes.backends.cuda.ops import _gemm_nvfp4_hw_bf16_raw
+
+        D_out = torch.empty(M, N, dtype=torch.bfloat16, device=A_data.device)
+        workspace = torch.empty(M, N, dtype=torch.float32, device=A_data.device)
+        _gemm_nvfp4_hw_bf16_raw(
+            A_data,
+            B_data,
+            A_state.block_scales,
+            B_state.block_scales,
+            D_out,
+            workspace,
+            M,
+            N,
+            K,
+        )
+        # Apply tensor scales and convert to FP32 for API compatibility
+        return D_out.float() * (A_state.tensor_scale * B_state.tensor_scale)
+
+    # CUTLASS: pre-swizzled scales, BF16 output
     A_scales = A_state.block_scales_blocked if A_state.block_scales_blocked is not None else A_state.block_scales
     B_scales = B_state.block_scales_blocked if B_state.block_scales_blocked is not None else B_state.block_scales
 
```

0 commit comments