Skip to content

Commit 240d9af

Browse files
TimDettmersclaude
and committed
feat: Add _raw wrappers for hand-written NVFP4 GEMM kernel
Exposes _gemm_nvfp4_hw_raw and _gemm_nvfp4_hw_splitk_raw for CUDA-graph-safe benchmarking. These are zero-allocation wrappers around cgemm_nvfp4 and cgemm_nvfp4_splitk that retrieve the stream at call time (needed for graph capture). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c8e31ed commit 240d9af

File tree

1 file changed

+60
-0
lines changed
  • bitsandbytes/backends/cuda

1 file changed

+60
-0
lines changed

bitsandbytes/backends/cuda/ops.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,6 +1046,66 @@ def _(scales: torch.Tensor, H: int, W: int) -> torch.Tensor:
10461046
return out
10471047

10481048

1049+
# Hand-written NVFP4 GEMM (SM_120+)
#
# Uses mma.sync.aligned.block_scale instructions for small-M decode.
# Expects flat (non-swizzled) row-major scales. Output is FP32.
# Uses automatic split-K when tile count is low relative to SM count.
def _gemm_nvfp4_hw_raw(
    A_packed: torch.Tensor,
    B_packed: torch.Tensor,
    A_scales: torch.Tensor,
    B_scales: torch.Tensor,
    D_out: torch.Tensor,
    M: int,
    N: int,
    K: int,
) -> None:
    """Dispatch the hand-written NVFP4 GEMM kernel with zero allocations.

    Safe for CUDA graph capture: every buffer is caller-provided and the
    stream is resolved at call time. ``D_out`` must be a pre-allocated FP32
    tensor of shape (M, N); scales are flat row-major (not swizzled).
    Split-K is selected automatically inside the kernel and uses
    cudaMemsetAsync, so the launch remains graph-capturable.
    """
    # Resolving the stream here, at call time rather than setup time, is
    # what makes this wrapper usable under CUDA graph capture.
    stream = _get_tensor_stream(A_packed)
    buffers = (A_packed, B_packed, A_scales, B_scales, D_out)
    dims = (ct.c_int(M), ct.c_int(N), ct.c_int(K))
    lib.cgemm_nvfp4(*map(get_ptr, buffers), *dims, stream)
1081+
1082+
1083+
def _gemm_nvfp4_hw_splitk_raw(
    A_packed: torch.Tensor,
    B_packed: torch.Tensor,
    A_scales: torch.Tensor,
    B_scales: torch.Tensor,
    D_out: torch.Tensor,
    M: int,
    N: int,
    K: int,
    split_k: int,
) -> None:
    """Hand-written NVFP4 GEMM with a caller-chosen split-K factor.

    Same contract as :func:`_gemm_nvfp4_hw_raw` — pre-allocated buffers
    only, stream resolved at call time, CUDA-graph-safe — except the
    split-K factor is supplied explicitly via ``split_k`` instead of
    being picked by the kernel.
    """
    pointers = [
        get_ptr(t) for t in (A_packed, B_packed, A_scales, B_scales, D_out)
    ]
    lib.cgemm_nvfp4_splitk(
        *pointers,
        ct.c_int(M),
        ct.c_int(N),
        ct.c_int(K),
        ct.c_int(split_k),
        _get_tensor_stream(A_packed),
    )
1107+
1108+
10491109
# NVFP4 GEMM (CUTLASS-based)
10501110
#
10511111
# Expects pre-swizzled scales in CUTLASS block-scaled layout (computed at

0 commit comments

Comments
 (0)