Skip to content

Commit f3916d4

Browse files
TimDettmers and claude committed
feat: Add LinearNVFP4 module for Blackwell GPU inference
Implements LinearNVFP4(nn.Linear) that quantizes weights to NVFP4 on first forward pass and uses the block-scaled MMA for inference. Features: - Lazy weight quantization (on first forward) - Optional Hadamard rotation (rotate=True) - Activation quantization in the forward pass - NVFP4 GEMM via hardware MMA instruction - Automatic input reshape for batched inputs Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 45645cf commit f3916d4

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

bitsandbytes/nn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Linear8bitLt,
1414
LinearFP4,
1515
LinearNF4,
16+
LinearNVFP4,
1617
OutlierAwareLinear,
1718
Params4bit,
1819
StableEmbedding,

bitsandbytes/nn/modules.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,79 @@ def __init__(
672672
)
673673

674674

675+
class LinearNVFP4(nn.Linear):
    """NVFP4 (E2M1) quantized linear layer for Blackwell GPUs (SM_120).

    Quantizes weights to NVFP4 on the first forward pass and uses the
    hardware block-scaled MMA instruction for inference. Supports optional
    Hadamard rotation for improved accuracy.

    Args:
        input_features: Number of input features.
        output_features: Number of output features.
        bias: Whether to use bias. Defaults to True.
        rotate: Apply Hadamard rotation before quantization. Defaults to False.
        device: Device for initialization.
    """

    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        rotate=False,
        device=None,
    ):
        super().__init__(input_features, output_features, bias, device)
        self.rotate = rotate
        # Quantized state is filled in lazily by _quantize_weight().
        self.weight_quantized = False
        self.weight_packed = None
        self.weight_state = None

    def _quantize_weight(self):
        """Quantize the weight tensor to NVFP4 and free the fp weight.

        Imported lazily to avoid a circular import with
        bitsandbytes.functional at module load time.
        """
        from bitsandbytes.functional import quantize_nvfp4

        # Weight is (out_features, in_features) = (N, K) in GEMM terms.
        w = self.weight.data.float().contiguous()
        packed, state = quantize_nvfp4(w, rotate=self.rotate)
        self.weight_packed = packed
        self.weight_state = state
        self.weight_quantized = True
        # Free the original weight to save memory; an empty placeholder keeps
        # the nn.Linear attribute contract (self.weight still exists) intact.
        self.weight = nn.Parameter(torch.empty(0, device=w.device, dtype=w.dtype), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Quantized forward pass: returns ``x @ W^T + b`` via NVFP4 GEMM.

        Accepts inputs of shape ``(*, in_features)``; leading dimensions are
        flattened for the 2-D GEMM and restored on the output. The result is
        cast back to the input dtype.
        """
        # Lazily quantize the weight on the first call.
        if not self.weight_quantized:
            self._quantize_weight()

        from bitsandbytes.functional import gemm_nvfp4, quantize_nvfp4

        inp_dtype = x.dtype
        input_shape = x.shape

        # Flatten leading dims: (*, K) -> (M, K).
        x_2d = x.reshape(-1, input_shape[-1]).float().contiguous()

        # Quantize activations to NVFP4 with the same rotation setting as
        # the weight so the (optionally rotated) GEMM stays consistent.
        x_packed, x_state = quantize_nvfp4(x_2d, rotate=self.rotate)

        # NVFP4 GEMM: (M, K) @ (N, K)^T -> (M, N).
        out = gemm_nvfp4(x_packed, x_state, self.weight_packed, self.weight_state)

        # Restore leading dims: (M, N) -> (*, N). Use out_features from
        # nn.Linear rather than reaching into the opaque quantization state.
        out = out.reshape(*input_shape[:-1], self.out_features)

        if self.bias is not None:
            out = out + self.bias.to(out.dtype)

        # Return in the caller's original dtype.
        return out.to(inp_dtype)
675748
class Int8Params(torch.nn.Parameter):
676749
def __new__(
677750
cls,

0 commit comments

Comments
 (0)