Skip to content

Commit 45645cf

Browse files
TimDettmers and claude
committed
feat: Add NVFP4 Python API (torch.library ops, backend dispatch, functional)
Adds complete Python integration for NVFP4: 1. torch.library op definitions (_ops.py): - quantize_nvfp4, dequantize_nvfp4 - hadamard_rotate_nvfp4, fused_hadamard_quantize_nvfp4 - gemm_nvfp4 With register_fake implementations for torch.compile compatibility. 2. CUDA backend dispatch (backends/cuda/ops.py): - All ops dispatch to the C library via ctypes - GEMM applies tensor scales to the raw kernel output 3. Functional API (functional.py): - NVFP4QuantState class (packed_data, block_scales, tensor_scale) - quantize_nvfp4(A, tensor_scale, rotate) -> (packed, state) - dequantize_nvfp4(packed, state) -> tensor - gemm_nvfp4(A_data, A_state, B_data, B_state) -> tensor Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 47f4af3 commit 45645cf

File tree

3 files changed

+343
-0
lines changed

3 files changed

+343
-0
lines changed

bitsandbytes/_ops.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,3 +431,90 @@ def _(
431431
qmap2.dtype == absmax2.dtype == torch.float32,
432432
lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
433433
)
434+
435+
436+
# NVFP4 quantization (E2M1 with two-level scaling: E4M3 block scales + FP32 tensor scale)
437+
torch.library.define(
438+
"bitsandbytes::quantize_nvfp4",
439+
"(Tensor A, float? tensor_scale) -> (Tensor, Tensor, Tensor)",
440+
)
441+
442+
443+
@register_fake("bitsandbytes::quantize_nvfp4")
444+
def _(A: torch.Tensor, tensor_scale: Optional[float] = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
445+
n = A.numel()
446+
torch._check(n % 16 == 0, lambda: f"NVFP4 requires numel divisible by 16, got {n}")
447+
packed = torch.empty(n // 2, dtype=torch.uint8, device=A.device)
448+
block_scales = torch.empty(n // 16, dtype=torch.uint8, device=A.device)
449+
ts_out = torch.empty(1, dtype=torch.float32, device=A.device)
450+
return packed, block_scales, ts_out
451+
452+
453+
# NVFP4 dequantization
454+
torch.library.define(
455+
"bitsandbytes::dequantize_nvfp4",
456+
"(Tensor packed, Tensor block_scales, float tensor_scale, int numel, ScalarType dtype) -> Tensor",
457+
)
458+
459+
460+
@register_fake("bitsandbytes::dequantize_nvfp4")
461+
def _(
462+
packed: torch.Tensor, block_scales: torch.Tensor, tensor_scale: float, numel: int, dtype: torch.dtype
463+
) -> torch.Tensor:
464+
return torch.empty(numel, dtype=dtype, device=packed.device)
465+
466+
467+
# NVFP4 Hadamard rotation (in-place)
torch.library.define(
    "bitsandbytes::hadamard_rotate_nvfp4",
    "(Tensor(a!) A) -> ()",
)


@register_fake("bitsandbytes::hadamard_rotate_nvfp4")
def _(A: torch.Tensor) -> None:
    # In-place op: nothing to allocate for tracing, only validate the size.
    n = A.numel()
    torch._check(
        n % 16 == 0,
        lambda: f"Hadamard rotation requires numel divisible by 16, got {n}",
    )
478+
479+
480+
# Fused Hadamard rotation + NVFP4 quantize
481+
torch.library.define(
482+
"bitsandbytes::fused_hadamard_quantize_nvfp4",
483+
"(Tensor A, float? tensor_scale) -> (Tensor, Tensor, Tensor)",
484+
)
485+
486+
487+
@register_fake("bitsandbytes::fused_hadamard_quantize_nvfp4")
488+
def _(A: torch.Tensor, tensor_scale: Optional[float] = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
489+
n = A.numel()
490+
torch._check(n % 16 == 0, lambda: f"NVFP4 requires numel divisible by 16, got {n}")
491+
packed = torch.empty(n // 2, dtype=torch.uint8, device=A.device)
492+
block_scales = torch.empty(n // 16, dtype=torch.uint8, device=A.device)
493+
ts_out = torch.empty(1, dtype=torch.float32, device=A.device)
494+
return packed, block_scales, ts_out
495+
496+
497+
# NVFP4 GEMM (A @ B^T with block-scaled FP4 inputs)
498+
torch.library.define(
499+
"bitsandbytes::gemm_nvfp4",
500+
"(Tensor A_packed, Tensor B_packed, Tensor A_scales, Tensor B_scales, "
501+
"float A_tensor_scale, float B_tensor_scale, int M, int N, int K) -> Tensor",
502+
)
503+
504+
505+
@register_fake("bitsandbytes::gemm_nvfp4")
506+
def _(
507+
A_packed: torch.Tensor,
508+
B_packed: torch.Tensor,
509+
A_scales: torch.Tensor,
510+
B_scales: torch.Tensor,
511+
A_tensor_scale: float,
512+
B_tensor_scale: float,
513+
M: int,
514+
N: int,
515+
K: int,
516+
) -> torch.Tensor:
517+
torch._check_is_size(M)
518+
torch._check_is_size(N)
519+
torch._check_is_size(K)
520+
return torch.empty(M, N, dtype=torch.float32, device=A_packed.device)

bitsandbytes/backends/cuda/ops.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,3 +772,127 @@ def _optimizer_update_8bit_blockwise_impl(
772772

773773
register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl)
774774
register_kernel("bitsandbytes::optimizer_update_32bit", "cuda")(_optimizer_update_32bit_impl)
775+
776+
777+
# NVFP4 quantization
@register_kernel("bitsandbytes::quantize_nvfp4", "cuda")
def _(A: torch.Tensor, tensor_scale: Optional[float] = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """CUDA implementation of bitsandbytes::quantize_nvfp4.

    Packs ``A`` into FP4 (two values per byte) with one E4M3 block-scale byte
    per 16 elements, dispatching to the dtype-specific C kernel.

    Returns:
        ``(packed, block_scales, ts_out)`` where ``ts_out`` is the tensor
        scale actually used, as a 1-element float32 tensor on ``A``'s device.
    """
    A = A.contiguous()  # the C kernels assume a dense, contiguous buffer
    n = A.numel()
    torch._check(n % 16 == 0, lambda: f"NVFP4 requires numel divisible by 16, got {n}")
    torch._check(
        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
        lambda: f"NVFP4 quantization requires float16/bfloat16/float32, got {A.dtype}",
    )

    # Default tensor scale is the global absolute maximum of the input.
    # NOTE(review): .item() forces a device sync; an all-zero input yields
    # scale 0.0 — confirm the C kernel tolerates that without dividing by zero.
    if tensor_scale is None:
        tensor_scale = A.abs().max().item()

    packed = torch.zeros(n // 2, dtype=torch.uint8, device=A.device)  # 2 FP4 values per byte
    block_scales = torch.zeros(n // 16, dtype=torch.uint8, device=A.device)  # 1 E4M3 byte per 16 elements

    with _cuda_device_of(A):
        if A.dtype == torch.float16:
            lib.cquantize_nvfp4_fp16(get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n))
        elif A.dtype == torch.bfloat16:
            lib.cquantize_nvfp4_bf16(get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n))
        else:
            lib.cquantize_nvfp4_fp32(get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n))

    # Echo the scale back as a tensor so callers can store it in quant state.
    ts_out = torch.tensor([tensor_scale], dtype=torch.float32, device=A.device)
    return packed, block_scales, ts_out
804+
805+
806+
# NVFP4 dequantization
@register_kernel("bitsandbytes::dequantize_nvfp4", "cuda")
def _(
    packed: torch.Tensor, block_scales: torch.Tensor, tensor_scale: float, numel: int, dtype: torch.dtype
) -> torch.Tensor:
    """CUDA implementation of bitsandbytes::dequantize_nvfp4.

    Expands ``numel`` FP4 values (packed two per byte) back to ``dtype``,
    applying the E4M3 block scales and the global ``tensor_scale``.
    Returns a flat tensor; the caller is responsible for reshaping.
    """
    packed = packed.contiguous()
    block_scales = block_scales.contiguous()
    output = torch.zeros(numel, dtype=dtype, device=packed.device)

    with _cuda_device_of(packed):
        # NOTE(review): the trailing ct.c_void_p(0) looks like an optional
        # stream/workspace argument — confirm against the C declaration.
        if dtype == torch.float16:
            lib.cdequantize_nvfp4_fp16(get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), get_ptr(output), ct.c_int(numel), ct.c_void_p(0))
        elif dtype == torch.bfloat16:
            lib.cdequantize_nvfp4_bf16(get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), get_ptr(output), ct.c_int(numel), ct.c_void_p(0))
        else:
            lib.cdequantize_nvfp4_fp32(get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), get_ptr(output), ct.c_int(numel), ct.c_void_p(0))

    return output
824+
825+
826+
# NVFP4 Hadamard rotation (in-place)
@register_kernel("bitsandbytes::hadamard_rotate_nvfp4", "cuda")
def _(A: torch.Tensor) -> None:
    """Apply the 16-element Hadamard rotation to ``A`` in place on the GPU."""
    work = A.contiguous()
    n = work.numel()
    torch._check(n % 16 == 0, lambda: f"Hadamard rotation requires numel divisible by 16, got {n}")

    # Select the dtype-specific C kernel; any dtype other than fp16/bf16
    # takes the fp32 entry point, matching the quantize dispatch.
    if work.dtype == torch.float16:
        kernel = lib.chadamard_rotate16_fp16
    elif work.dtype == torch.bfloat16:
        kernel = lib.chadamard_rotate16_bf16
    else:
        kernel = lib.chadamard_rotate16_fp32

    with _cuda_device_of(work):
        kernel(get_ptr(work), ct.c_int(n))

    # contiguous() copied the data when A was non-contiguous, so the rotated
    # result must be written back into the caller's tensor.
    if not A.is_contiguous():
        A.copy_(work)
843+
844+
845+
# Fused Hadamard rotation + NVFP4 quantize
@register_kernel("bitsandbytes::fused_hadamard_quantize_nvfp4", "cuda")
def _(A: torch.Tensor, tensor_scale: Optional[float] = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """CUDA implementation of bitsandbytes::fused_hadamard_quantize_nvfp4.

    Rotates ``A`` with the 16-element Hadamard transform and quantizes the
    result to NVFP4 in one fused kernel. ``A`` itself is left unmodified.

    Returns:
        ``(packed, block_scales, ts_out)`` where ``ts_out`` is the tensor
        scale actually used, as a 1-element float32 tensor on ``A``'s device.
    """
    A = A.contiguous()  # the C kernels assume a dense, contiguous buffer
    n = A.numel()
    torch._check(n % 16 == 0, lambda: f"NVFP4 requires numel divisible by 16, got {n}")
    # Validate dtype explicitly (consistent with the quantize_nvfp4 kernel);
    # otherwise unsupported dtypes would silently hit the fp32 branch below.
    torch._check(
        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
        lambda: f"NVFP4 quantization requires float16/bfloat16/float32, got {A.dtype}",
    )

    if tensor_scale is None:
        # Compute scale on rotated data: rotate a throwaway copy so the fused
        # kernel still sees the raw input. Costs one extra rotation pass and
        # a device sync via .item().
        A_copy = A.clone()
        torch.ops.bitsandbytes.hadamard_rotate_nvfp4(A_copy)
        tensor_scale = A_copy.abs().max().item()

    packed = torch.zeros(n // 2, dtype=torch.uint8, device=A.device)  # 2 FP4 values per byte
    block_scales = torch.zeros(n // 16, dtype=torch.uint8, device=A.device)  # 1 E4M3 byte per 16 elements

    with _cuda_device_of(A):
        if A.dtype == torch.float16:
            lib.cfused_hadamard_quantize_nvfp4_fp16(get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n))
        elif A.dtype == torch.bfloat16:
            lib.cfused_hadamard_quantize_nvfp4_bf16(get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n))
        else:
            lib.cfused_hadamard_quantize_nvfp4_fp32(get_ptr(A), get_ptr(packed), get_ptr(block_scales), ct.c_float(tensor_scale), ct.c_int(n))

    # Echo the scale back as a tensor so callers can store it in quant state.
    ts_out = torch.tensor([tensor_scale], dtype=torch.float32, device=A.device)
    return packed, block_scales, ts_out
871+
872+
873+
# NVFP4 GEMM
@register_kernel("bitsandbytes::gemm_nvfp4", "cuda")
def _(
    A_packed: torch.Tensor,
    B_packed: torch.Tensor,
    A_scales: torch.Tensor,
    B_scales: torch.Tensor,
    A_tensor_scale: float,
    B_tensor_scale: float,
    M: int,
    N: int,
    K: int,
) -> torch.Tensor:
    """CUDA implementation of bitsandbytes::gemm_nvfp4 (D = A @ B^T).

    The C kernel consumes the packed FP4 operands with their E4M3 block
    scales and writes an M x N float32 result; the global tensor scales are
    applied afterwards on the Python side.
    """
    D_out = torch.zeros(M, N, dtype=torch.float32, device=A_packed.device)

    with _cuda_device_of(A_packed):
        lib.cgemm_nvfp4(
            get_ptr(A_packed), get_ptr(B_packed),
            get_ptr(A_scales), get_ptr(B_scales),
            get_ptr(D_out),
            ct.c_int(M), ct.c_int(N), ct.c_int(K),
        )

    # Apply tensor scales (the GEMM kernel operates on raw quantized values)
    D_out.mul_(A_tensor_scale * B_tensor_scale)
    return D_out

bitsandbytes/functional.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,6 +1077,138 @@ def dequantize_4bit(
10771077
return out
10781078

10791079

1080+
# ---------------------------------------------------------------------------
1081+
# NVFP4 (E2M1) quantization with two-level scaling
1082+
# ---------------------------------------------------------------------------
1083+
1084+
1085+
class NVFP4QuantState:
    """Quantization state for NVFP4 (E2M1 format with block scales).

    Stores the quantized data, E4M3 block scales (per 16 elements),
    FP32 tensor scale, and metadata needed for dequantization.
    """

    def __init__(
        self,
        packed_data: torch.Tensor,
        block_scales: torch.Tensor,
        tensor_scale: float,
        shape: tuple,
        dtype: torch.dtype,
        rotated: bool = False,
    ):
        # uint8 tensor holding two FP4 values per byte
        self.packed_data = packed_data
        # per-16-element block scales (E4M3, stored as uint8)
        self.block_scales = block_scales
        # global FP32 scale applied on top of the block scales
        self.tensor_scale = tensor_scale
        # original (unflattened) tensor shape, restored on dequantize
        self.shape = shape
        # original floating-point dtype
        self.dtype = dtype
        # whether a Hadamard rotation was applied before quantization
        self.rotated = rotated

    def __repr__(self) -> str:
        # Summarize metadata only; the data tensors may be large.
        return (
            f"{type(self).__name__}(shape={tuple(self.shape)}, dtype={self.dtype}, "
            f"tensor_scale={self.tensor_scale}, rotated={self.rotated})"
        )

    def to(self, device):
        """Return a copy of this state with its tensors moved to ``device``."""
        return NVFP4QuantState(
            packed_data=self.packed_data.to(device),
            block_scales=self.block_scales.to(device),
            tensor_scale=self.tensor_scale,
            shape=self.shape,
            dtype=self.dtype,
            rotated=self.rotated,
        )
1117+
1118+
1119+
def quantize_nvfp4(
    A: torch.Tensor,
    tensor_scale: Optional[float] = None,
    rotate: bool = False,
) -> tuple[torch.Tensor, NVFP4QuantState]:
    """Quantize a tensor to NVFP4 (E2M1) format.

    Args:
        A: Input tensor (float16, bfloat16, or float32). Must have numel divisible by 16.
        tensor_scale: Optional pre-computed tensor scale. If None, computed as abs(max(A)).
        rotate: If True, apply Hadamard rotation before quantization (fused kernel).

    Returns:
        Tuple of (packed_data, NVFP4QuantState).
    """
    orig_shape = A.shape
    orig_dtype = A.dtype
    flat = A.reshape(-1).contiguous()

    # Either the fused rotate+quantize kernel or plain quantization.
    quant_op = (
        torch.ops.bitsandbytes.fused_hadamard_quantize_nvfp4
        if rotate
        else torch.ops.bitsandbytes.quantize_nvfp4
    )
    packed, block_scales, ts = quant_op(flat, tensor_scale)

    return packed, NVFP4QuantState(
        packed_data=packed,
        block_scales=block_scales,
        tensor_scale=ts.item(),
        shape=orig_shape,
        dtype=orig_dtype,
        rotated=rotate,
    )
1152+
1153+
1154+
def dequantize_nvfp4(
    packed_data: torch.Tensor,
    quant_state: NVFP4QuantState,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Dequantize NVFP4 packed data back to floating point.

    Args:
        packed_data: Packed FP4 data (uint8, 2 values per byte).
        quant_state: Quantization state from quantize_nvfp4.
        out_dtype: Output dtype. Defaults to the original dtype.

    Returns:
        Dequantized tensor with the original shape.
    """
    target_dtype = quant_state.dtype if out_dtype is None else out_dtype

    # Element count of the original tensor, recovered from its stored shape.
    total = 1
    for dim in quant_state.shape:
        total *= dim

    decoded = torch.ops.bitsandbytes.dequantize_nvfp4(
        packed_data, quant_state.block_scales, quant_state.tensor_scale, total, target_dtype
    )

    if quant_state.rotated:
        # Apply inverse Hadamard rotation
        torch.ops.bitsandbytes.hadamard_rotate_nvfp4(decoded)

    return decoded.reshape(quant_state.shape)
1183+
1184+
1185+
def gemm_nvfp4(
    A_data: torch.Tensor,
    A_state: "NVFP4QuantState",
    B_data: torch.Tensor,
    B_state: "NVFP4QuantState",
) -> torch.Tensor:
    """NVFP4 GEMM: compute A @ B^T using block-scaled FP4 inputs.

    Args:
        A_data: Packed FP4 data for A (M*K/2 bytes).
        A_state: Quantization state for A (M x K).
        B_data: Packed FP4 data for B (N*K/2 bytes, stored as N rows of K).
        B_state: Quantization state for B (N x K).

    Returns:
        Output tensor of shape (M, N) in float32 with tensor scales applied.

    Raises:
        ValueError: If either state is not 2D, or the inner (K) dimensions
            of A and B disagree.
    """
    # Validate shapes here rather than letting mismatched M/N/K reach the
    # raw CUDA kernel, which cannot check them.
    if len(A_state.shape) != 2 or len(B_state.shape) != 2:
        raise ValueError(
            f"gemm_nvfp4 expects 2D operands, got A shape {tuple(A_state.shape)} "
            f"and B shape {tuple(B_state.shape)}"
        )
    M, K = A_state.shape
    N, B_K = B_state.shape
    if K != B_K:
        raise ValueError(f"Inner dimensions must match: A is {M}x{K}, B is {N}x{B_K}")

    # NOTE(review): mixing a rotated A with an unrotated B (or vice versa)
    # likely yields incorrect results — confirm whether that should be
    # rejected here as well.
    return torch.ops.bitsandbytes.gemm_nvfp4(
        A_data, B_data, A_state.block_scales, B_state.block_scales,
        A_state.tensor_scale, B_state.tensor_scale, M, N, K,
    )
1210+
1211+
10801212
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
10811213
def quantize(
10821214
A: Tensor,

0 commit comments

Comments
 (0)