Skip to content

Commit 6d5dac1

Browse files
TimDettmers and claude committed
perf: Add graph-safe _impl variants for fused quantize and GEMM
Extract core logic into _fused_quantize_nvfp4_impl and _gemm_nvfp4_impl that accept optional pre-allocated output buffers. When provided, zero allocations occur, making them safe for CUDA graph capture. The existing registered ops remain unchanged for convenience. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 53fec13 commit 6d5dac1

File tree

1 file changed

+88
-32
lines changed
  • bitsandbytes/backends/cuda

1 file changed

+88
-32
lines changed

bitsandbytes/backends/cuda/ops.py

Lines changed: 88 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -936,15 +936,25 @@ def _get_rotation_matrix(device: torch.device) -> torch.Tensor:
936936
return _rotation_matrices[device]
937937

938938

939-
@register_kernel("bitsandbytes::cutlass_fused_quantize_nvfp4", "cuda")
940-
def _(
939+
def _fused_quantize_nvfp4_impl(
941940
A: torch.Tensor,
942941
tensor_scale: float,
942+
packed_out: Optional[torch.Tensor] = None,
943+
scales_out: Optional[torch.Tensor] = None,
944+
global_scale_buf: Optional[torch.Tensor] = None,
943945
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
944-
"""CUTLASS-based fused quantize with randomized Hadamard rotation.
945-
946-
The CUTLASS kernel requires M to be a multiple of 128. We pad here
947-
and trim the output to maintain a transparent API.
946+
"""Core CUTLASS fused quantize implementation.
947+
948+
When output buffers are provided, no allocations occur — safe for CUDA
949+
graph capture. When None, buffers are allocated (convenient but not
950+
graph-safe).
951+
952+
Args:
953+
A: BF16 input, numel must be divisible by 16.
954+
tensor_scale: Global tensor scale.
955+
packed_out: Pre-allocated uint8 output (padded_M * 8 bytes). None to allocate.
956+
scales_out: Pre-allocated uint8 scales (padded_M bytes). None to allocate.
957+
global_scale_buf: Pre-allocated float32 scalar buffer. None to allocate.
948958
"""
949959
A = A.contiguous()
950960
n = A.numel()
@@ -954,10 +964,8 @@ def _(
954964
lambda: f"CUTLASS fused quantize requires bfloat16, got {A.dtype}",
955965
)
956966

957-
# Reshape to 2D: (M, K) where K is the last dimension
958-
# The fused quantize GEMM treats each group of 16 elements as one "row"
959-
K = 16 # NVFP4 group size = GEMM K dimension
960-
N = 16 # B matrix is 16x16
967+
K = 16
968+
N = 16
961969
orig_M = n // K
962970
padded_M = ((orig_M + 127) // 128) * 128
963971

@@ -970,26 +978,26 @@ def _(
970978
else:
971979
A_flat = A
972980

973-
# Compute global_scale = 1/tensor_scale (QuTLASS convention)
974-
global_scale = torch.tensor(
975-
[1.0 / tensor_scale if tensor_scale > 0 else 0.0],
976-
dtype=torch.float32,
977-
device=A.device,
978-
)
979-
980-
# Allocate output buffers (padded size)
981-
packed_padded = torch.zeros(padded_M * K // 2, dtype=torch.uint8, device=A.device)
981+
# Use pre-allocated buffers or allocate new ones
982+
if global_scale_buf is not None:
983+
global_scale_buf.fill_(1.0 / tensor_scale if tensor_scale > 0 else 0.0)
984+
global_scale = global_scale_buf
985+
else:
986+
global_scale = torch.tensor(
987+
[1.0 / tensor_scale if tensor_scale > 0 else 0.0],
988+
dtype=torch.float32,
989+
device=A.device,
990+
)
982991

983-
# Scale output: one E4M3 scale per 16-element block = padded_M scales
984-
# QuTLASS outputs as (padded_M, 1) but we flatten
985-
scales_padded = torch.zeros(padded_M, dtype=torch.uint8, device=A.device)
992+
packed_padded = (
993+
packed_out if packed_out is not None else torch.zeros(padded_M * K // 2, dtype=torch.uint8, device=A.device)
994+
)
995+
scales_padded = scales_out if scales_out is not None else torch.zeros(padded_M, dtype=torch.uint8, device=A.device)
986996

987-
# Get the cached randomized Hadamard rotation matrix for this device
988997
B = _get_rotation_matrix(A.device)
989998

990999
with _cuda_device_of(A):
991-
fn = lib.cfused_quantize_nvfp4_absmax
992-
fn(
1000+
lib.cfused_quantize_nvfp4_absmax(
9931001
get_ptr(A_flat),
9941002
get_ptr(B),
9951003
get_ptr(packed_padded),
@@ -1001,14 +1009,22 @@ def _(
10011009
_get_tensor_stream(A),
10021010
)
10031011

1004-
# Trim to original size
10051012
packed = packed_padded[: orig_M * K // 2] if padded_M != orig_M else packed_padded
10061013
block_scales = scales_padded[:orig_M] if padded_M != orig_M else scales_padded
10071014

10081015
ts_out = torch.tensor([tensor_scale], dtype=torch.float32, device=A.device)
10091016
return packed, block_scales, ts_out
10101017

10111018

1019+
@register_kernel("bitsandbytes::cutlass_fused_quantize_nvfp4", "cuda")
1020+
def _(
1021+
A: torch.Tensor,
1022+
tensor_scale: float,
1023+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1024+
"""CUTLASS-based fused quantize with randomized Hadamard rotation."""
1025+
return _fused_quantize_nvfp4_impl(A, tensor_scale)
1026+
1027+
10121028
# Scale reordering for CUTLASS block-scaled GEMM
10131029
@register_kernel("bitsandbytes::scale_to_blocked", "cuda")
10141030
def _(scales: torch.Tensor, H: int, W: int) -> torch.Tensor:
@@ -1038,8 +1054,7 @@ def _(scales: torch.Tensor, H: int, W: int) -> torch.Tensor:
10381054
# quantization time by scale_to_blocked). Tensor scales are folded into
10391055
# the CUTLASS epilogue alpha. Output is BF16, converted to FP32 for
10401056
# API compatibility.
1041-
@register_kernel("bitsandbytes::gemm_nvfp4", "cuda")
1042-
def _(
1057+
def _gemm_nvfp4_impl(
10431058
A_packed: torch.Tensor,
10441059
B_packed: torch.Tensor,
10451060
A_scales: torch.Tensor,
@@ -1049,12 +1064,27 @@ def _(
10491064
M: int,
10501065
N: int,
10511066
K: int,
1067+
D_out: Optional[torch.Tensor] = None,
1068+
alpha_buf: Optional[torch.Tensor] = None,
10521069
) -> torch.Tensor:
1070+
"""Core NVFP4 GEMM implementation.
1071+
1072+
When D_out and alpha_buf are provided, no allocations occur — safe for
1073+
CUDA graph capture. When None, buffers are allocated.
1074+
1075+
Args:
1076+
D_out: Pre-allocated BF16 output (M, N). None to allocate.
1077+
alpha_buf: Pre-allocated float32 scalar buffer. None to allocate.
1078+
"""
10531079
with _cuda_device_of(A_packed):
1054-
# A_scales and B_scales are already in CUTLASS block-scaled layout
1055-
# (pre-computed at quantization time by scale_to_blocked)
1056-
alpha = torch.tensor([A_tensor_scale * B_tensor_scale], dtype=torch.float32, device=A_packed.device)
1057-
D_out = torch.empty(M, N, dtype=torch.bfloat16, device=A_packed.device)
1080+
if alpha_buf is not None:
1081+
alpha_buf.fill_(A_tensor_scale * B_tensor_scale)
1082+
alpha = alpha_buf
1083+
else:
1084+
alpha = torch.tensor([A_tensor_scale * B_tensor_scale], dtype=torch.float32, device=A_packed.device)
1085+
1086+
if D_out is None:
1087+
D_out = torch.empty(M, N, dtype=torch.bfloat16, device=A_packed.device)
10581088

10591089
lib.cgemm_nvfp4_cutlass(
10601090
get_ptr(A_packed),
@@ -1070,3 +1100,29 @@ def _(
10701100
)
10711101

10721102
return D_out.float()
1103+
1104+
1105+
@register_kernel("bitsandbytes::gemm_nvfp4", "cuda")
1106+
def _(
1107+
A_packed: torch.Tensor,
1108+
B_packed: torch.Tensor,
1109+
A_scales: torch.Tensor,
1110+
B_scales: torch.Tensor,
1111+
A_tensor_scale: float,
1112+
B_tensor_scale: float,
1113+
M: int,
1114+
N: int,
1115+
K: int,
1116+
) -> torch.Tensor:
1117+
"""NVFP4 GEMM: A @ B^T with block-scaled FP4 inputs."""
1118+
return _gemm_nvfp4_impl(
1119+
A_packed,
1120+
B_packed,
1121+
A_scales,
1122+
B_scales,
1123+
A_tensor_scale,
1124+
B_tensor_scale,
1125+
M,
1126+
N,
1127+
K,
1128+
)

0 commit comments

Comments (0)