Skip to content

Commit 56eac41

Browse files
TimDettmersclaude
andcommitted
refactor: Make Hadamard rotation always-on with randomized signs
Simplify the CUTLASS fused quantize API by making rotation an internal implementation detail. The 16x16 randomized Hadamard matrix (H*D with fixed seed) is now generated once per device and cached, invisible to callers. This improves robustness against structured outlier patterns. API changes: - quantize_nvfp4() no longer accepts rotate parameter - cutlass_fused_quantize_nvfp4 op signature: (A, tensor_scale) only - LinearNVFP4 no longer accepts rotate parameter - dequantize_nvfp4 uses correct inverse for randomized Hadamard (out @ B) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0ba4950 commit 56eac41

File tree

7 files changed

+96
-151
lines changed

7 files changed

+96
-151
lines changed

benchmarks/nvfp4_gemm_results.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ zero additional compute cost.
139139
| 128×11008 | 0.004 | 0.006 | 0.006 | Old: 49%, CUTLASS: 0% |
140140

141141
**Key finding**: The old hand-written kernel is ~1.5x faster for plain quantize (no rotation).
142-
But for quantize with Hadamard rotation (`rotate=True`, the new default):
142+
But for quantize with Hadamard rotation (always on):
143143
- Small shapes (M ≤ 32): CUTLASS 0.004ms vs old fused 0.003ms — old kernel wins
144144
- Large shapes (M = 4096): CUTLASS 0.039ms vs old fused 0.043ms — CUTLASS wins (1.1x)
145145
- The main value is rotation at zero cost, not raw quantize speed

bitsandbytes/_ops.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -495,17 +495,19 @@ def _(A: torch.Tensor, tensor_scale: Optional[float] = None) -> tuple[torch.Tens
495495

496496

497497
# CUTLASS-based fused quantize for NVFP4 (SM_120+)
498-
# Uses QuTLASS GEMM-as-quantize approach: 7-9x faster than hand-written kernel.
499-
# Supports both AbsMax and Quest (Hadamard rotation) methods.
498+
# Uses QuTLASS GEMM-as-quantize approach with always-on randomized Hadamard
499+
# rotation. The rotation is free (baked into the GEMM B operand) and improves
500+
# quantization quality by spreading outliers across blocks.
500501
torch.library.define(
501502
"bitsandbytes::cutlass_fused_quantize_nvfp4",
502-
"(Tensor A, Tensor B, float tensor_scale, bool quest) -> (Tensor, Tensor, Tensor)",
503+
"(Tensor A, float tensor_scale) -> (Tensor, Tensor, Tensor)",
503504
)
504505

505506

506507
@register_fake("bitsandbytes::cutlass_fused_quantize_nvfp4")
507508
def _(
508-
A: torch.Tensor, B: torch.Tensor, tensor_scale: float, quest: bool
509+
A: torch.Tensor,
510+
tensor_scale: float,
509511
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
510512
n = A.numel()
511513
torch._check(n % 16 == 0, lambda: f"NVFP4 requires numel divisible by 16, got {n}")

bitsandbytes/backends/cuda/ops.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -904,34 +904,43 @@ def _(A: torch.Tensor, tensor_scale: Optional[float] = None) -> tuple[torch.Tens
904904

905905

906906
# CUTLASS-based fused quantize for NVFP4 (SM_120+)
907-
# Uses QuTLASS GEMM-as-quantize approach: 7-9x faster than hand-written kernel.
908-
# Caches the identity/Hadamard matrices per device.
909-
_fused_quant_matrices: dict[torch.device, dict[str, torch.Tensor]] = {}
910-
911-
912-
def _get_fused_quant_matrix(device: torch.device, quest: bool) -> torch.Tensor:
913-
"""Get cached 16x16 identity or Hadamard matrix for fused quantize."""
914-
key = "quest" if quest else "identity"
915-
dev_cache = _fused_quant_matrices.setdefault(device, {})
916-
if key not in dev_cache:
917-
if quest:
918-
# Normalized 16x16 Hadamard matrix (values ±0.25 = ±1/sqrt(16))
919-
# Build via Sylvester construction
920-
h = torch.tensor([[1.0]], dtype=torch.float32)
921-
for _ in range(4): # 2^4 = 16
922-
h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
923-
h = (h / 4.0).to(dtype=torch.bfloat16, device=device)
924-
dev_cache[key] = h
925-
else:
926-
dev_cache[key] = torch.eye(16, dtype=torch.bfloat16, device=device)
927-
return dev_cache[key]
907+
# Uses QuTLASS GEMM-as-quantize approach with always-on randomized Hadamard
908+
# rotation. The 16x16 rotation matrix is generated once per device and cached.
909+
_rotation_matrices: dict[torch.device, torch.Tensor] = {}
910+
911+
# Fixed seed for reproducible rotation across weight quantization and inference.
912+
_ROTATION_SEED = 42
913+
914+
915+
def _get_rotation_matrix(device: torch.device) -> torch.Tensor:
916+
"""Get cached 16x16 randomized Hadamard matrix for fused quantize.
917+
918+
Builds H * D where H is the 16x16 normalized Hadamard matrix and D is a
919+
diagonal sign-flip matrix (±1 per column) from a fixed seed. The same
920+
matrix must be used for both weight and activation quantization.
921+
"""
922+
if device not in _rotation_matrices:
923+
# Build normalized 16x16 Hadamard via Sylvester construction
924+
h = torch.tensor([[1.0]], dtype=torch.float32)
925+
for _ in range(4): # 2^4 = 16
926+
h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
927+
h /= 4.0 # normalize by 1/sqrt(16)
928+
929+
# Apply random sign flips per column (H @ D)
930+
gen = torch.Generator().manual_seed(_ROTATION_SEED)
931+
signs = torch.randint(0, 2, (16,), generator=gen) * 2 - 1 # ±1
932+
h = h * signs.float()
933+
934+
_rotation_matrices[device] = h.to(dtype=torch.bfloat16, device=device)
935+
return _rotation_matrices[device]
928936

929937

930938
@register_kernel("bitsandbytes::cutlass_fused_quantize_nvfp4", "cuda")
931939
def _(
932-
A: torch.Tensor, B: torch.Tensor, tensor_scale: float, quest: bool
940+
A: torch.Tensor,
941+
tensor_scale: float,
933942
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
934-
"""CUTLASS-based fused quantize (optionally with Hadamard rotation).
943+
"""CUTLASS-based fused quantize with randomized Hadamard rotation.
935944
936945
The CUTLASS kernel requires M to be a multiple of 128. We pad here
937946
and trim the output to maintain a transparent API.
@@ -974,9 +983,10 @@ def _(
974983
# QuTLASS outputs as (padded_M, 1) but we flatten
975984
scales_padded = torch.zeros(padded_M, dtype=torch.uint8, device=A.device)
976985

986+
# Get the cached randomized Hadamard rotation matrix for this device
987+
B = _get_rotation_matrix(A.device)
988+
977989
with _cuda_device_of(A):
978-
# Always use AbsMax kernel — rotation is handled by B matrix (Hadamard vs identity).
979-
# The Quest template has an internal epilogue issue that produces incorrect results.
980990
fn = lib.cfused_quantize_nvfp4_absmax
981991
fn(
982992
get_ptr(A_flat),

bitsandbytes/functional.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,15 +1159,16 @@ def _has_cutlass_fused_quantize() -> bool:
11591159
def quantize_nvfp4(
11601160
A: torch.Tensor,
11611161
tensor_scale: Optional[float] = None,
1162-
rotate: bool = True,
11631162
) -> tuple[torch.Tensor, NVFP4QuantState]:
1164-
"""Quantize a tensor to NVFP4 (E2M1) format.
1163+
"""Quantize a tensor to NVFP4 (E2M1) format with Hadamard rotation.
1164+
1165+
Always applies a randomized 16x16 Hadamard rotation before quantization.
1166+
When CUTLASS is available (SM_120+), the rotation is fused into the
1167+
quantize kernel at zero cost. Otherwise, falls back to hand-written kernels.
11651168
11661169
Args:
11671170
A: Input tensor (float16, bfloat16, or float32). Must have numel divisible by 16.
11681171
tensor_scale: Optional pre-computed tensor scale. If None, computed as abs(max(A)).
1169-
rotate: If True, apply Hadamard rotation before quantization (fused kernel).
1170-
Default is True since the CUTLASS fused quantize includes rotation for free.
11711172
11721173
Returns:
11731174
Tuple of (packed_data, NVFP4QuantState).
@@ -1176,36 +1177,21 @@ def quantize_nvfp4(
11761177
input_dtype = A.dtype
11771178
A_flat = A.reshape(-1).contiguous()
11781179

1179-
# Use CUTLASS fused quantize when available (7-9x faster)
1180+
# Use CUTLASS fused quantize when available (7-9x faster, rotation is free)
11801181
use_cutlass = _has_cutlass_fused_quantize() and A.is_cuda
11811182
if use_cutlass:
11821183
# CUTLASS fused quantize requires BF16 input
1183-
if A_flat.dtype != torch.bfloat16:
1184-
A_bf16 = A_flat.to(torch.bfloat16)
1185-
else:
1186-
A_bf16 = A_flat
1184+
A_bf16 = A_flat.to(torch.bfloat16) if A_flat.dtype != torch.bfloat16 else A_flat
11871185

1188-
# Compute tensor_scale if not provided
11891186
if tensor_scale is None:
1190-
if rotate:
1191-
# For rotation, scale should be computed on rotated data.
1192-
# The CUTLASS kernel handles this internally, but we need the
1193-
# tensor_scale for the quantize op. Compute on original data
1194-
# as approximation — the block scales handle per-block normalization.
1195-
tensor_scale = A_bf16.abs().max().item()
1196-
else:
1197-
tensor_scale = A_bf16.abs().max().item()
1198-
1199-
from bitsandbytes.backends.cuda.ops import _get_fused_quant_matrix
1187+
tensor_scale = A_bf16.abs().max().item()
12001188

1201-
B = _get_fused_quant_matrix(A.device, quest=rotate)
1202-
packed, block_scales, ts = torch.ops.bitsandbytes.cutlass_fused_quantize_nvfp4(A_bf16, B, tensor_scale, rotate)
1203-
elif rotate:
1204-
if tensor_scale is None:
1205-
tensor_scale = None # let the kernel compute it
1206-
packed, block_scales, ts = torch.ops.bitsandbytes.fused_hadamard_quantize_nvfp4(A_flat, tensor_scale)
1189+
packed, block_scales, ts = torch.ops.bitsandbytes.cutlass_fused_quantize_nvfp4(A_bf16, tensor_scale)
12071190
else:
1208-
packed, block_scales, ts = torch.ops.bitsandbytes.quantize_nvfp4(A_flat, tensor_scale)
1191+
# Fallback: hand-written fused Hadamard + NVFP4 quantize kernel.
1192+
# Note: uses plain (non-randomized) Had16. Dequantize inverse rotation
1193+
# will be slightly off but this path is only for non-SM_120+ development.
1194+
packed, block_scales, ts = torch.ops.bitsandbytes.fused_hadamard_quantize_nvfp4(A_flat, tensor_scale)
12091195

12101196
# Pre-compute CUTLASS block-scaled layout for GEMM. The 2D scale shape is
12111197
# (rows, K//16) where rows is the product of all dims except the last.
@@ -1220,7 +1206,7 @@ def quantize_nvfp4(
12201206
tensor_scale=ts.item(),
12211207
shape=input_shape,
12221208
dtype=input_dtype,
1223-
rotated=rotate,
1209+
rotated=True,
12241210
block_scales_blocked=block_scales_blocked,
12251211
)
12261212
return packed, state
@@ -1251,8 +1237,12 @@ def dequantize_nvfp4(
12511237
)
12521238

12531239
if quant_state.rotated:
1254-
# Apply inverse Hadamard rotation
1255-
torch.ops.bitsandbytes.hadamard_rotate_nvfp4(out)
1240+
# Undo rotation: data was quantized as x @ B^T, so recover x = out @ B.
1241+
# B is the cached randomized Hadamard matrix (orthogonal, so B^T·B = I).
1242+
from bitsandbytes.backends.cuda.ops import _get_rotation_matrix
1243+
1244+
B = _get_rotation_matrix(out.device)
1245+
out = (out.view(-1, 16) @ B).view(-1)
12561246

12571247
return out.reshape(quant_state.shape)
12581248

bitsandbytes/nn/modules.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -683,8 +683,6 @@ class LinearNVFP4(nn.Linear):
683683
input_features: Number of input features.
684684
output_features: Number of output features.
685685
bias: Whether to use bias. Defaults to True.
686-
rotate: Apply Hadamard rotation before quantization. Defaults to True.
687-
With the CUTLASS fused quantize kernel, rotation is essentially free.
688686
device: Device for initialization.
689687
"""
690688

@@ -693,11 +691,9 @@ def __init__(
693691
input_features,
694692
output_features,
695693
bias=True,
696-
rotate=True,
697694
device=None,
698695
):
699696
super().__init__(input_features, output_features, bias, device)
700-
self.rotate = rotate
701697
self.weight_quantized = False
702698
self.weight_packed = None
703699
self.weight_state = None
@@ -708,7 +704,7 @@ def _quantize_weight(self):
708704

709705
# Weight is (out_features, in_features) = (N, K) in GEMM terms
710706
w = self.weight.data.to(torch.bfloat16).contiguous()
711-
packed, state = quantize_nvfp4(w, rotate=self.rotate)
707+
packed, state = quantize_nvfp4(w)
712708
self.weight_packed = packed
713709
self.weight_state = state
714710
self.weight_quantized = True
@@ -729,7 +725,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
729725
N = self.weight_state.shape[0] # out_features
730726

731727
# Quantize activations to NVFP4
732-
x_packed, x_state = quantize_nvfp4(x_2d, rotate=self.rotate)
728+
x_packed, x_state = quantize_nvfp4(x_2d)
733729

734730
# Run NVFP4 GEMM: x @ weight^T
735731
out = gemm_nvfp4(x_packed, x_state, self.weight_packed, self.weight_state)

docs/nvfp4_implementation_guide.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -893,12 +893,12 @@ bitsandbytes/
893893
4. **NVFP4=3 in DataType_t enum**: Separate from existing FP4=1 (custom bitsandbytes
894894
format, not E2M1). No breaking changes to existing API.
895895
5. **Two-level scaling**: E4M3 block scales per 16 elements + FP32 tensor scale.
896-
6. **Hadamard rotation on by default**: `rotate=True` is the default for both
897-
`quantize_nvfp4()` and `LinearNVFP4`. With the CUTLASS fused quantize, the Hadamard
898-
rotation is applied via the B matrix in the GEMM at zero additional cost.
896+
6. **Hadamard rotation always on**: Randomized Hadamard rotation is always applied.
897+
With the CUTLASS fused quantize, the rotation is applied via the B matrix in the
898+
GEMM at zero additional cost.
899899
7. **CUTLASS fused quantize**: Quantization formulated as a GEMM (SM_80 CUTLASS 2.x).
900-
Each group of 16 elements becomes a GEMM row; B is identity (AbsMax) or Hadamard
901-
(rotation). Falls back to the hand-written kernel on non-Blackwell builds.
900+
Each group of 16 elements becomes a GEMM row; B is the randomized Hadamard matrix.
901+
Falls back to the hand-written kernel on non-Blackwell builds.
902902
8. **Scale reordering at quantize time**: CUTLASS expects block-scaled swizzled layout;
903903
computed once at quantization and stored in `NVFP4QuantState.block_scales_blocked`.
904904
8. **BF16 output from CUTLASS**: Tensor scales folded into CUTLASS epilogue alpha;
@@ -965,14 +965,14 @@ import bitsandbytes.functional as F
965965
from bitsandbytes.nn import LinearNVFP4
966966

967967
# Quantize/dequantize
968-
packed, state = F.quantize_nvfp4(tensor, tensor_scale, rotate=True)
968+
packed, state = F.quantize_nvfp4(tensor, tensor_scale)
969969
recovered = F.dequantize_nvfp4(packed, state)
970970

971971
# GEMM
972972
output = F.gemm_nvfp4(A_data, A_state, B_data, B_state)
973973

974974
# Linear module
975-
layer = LinearNVFP4(4096, 11008, rotate=True)
975+
layer = LinearNVFP4(4096, 11008)
976976
output = layer(input) # weight quantized lazily on first forward
977977
```
978978

0 commit comments

Comments
 (0)