Commit 53fec13

TimDettmers and claude committed
fix: Correct dequant inverse for CUTLASS GEMM convention
The CUTLASS fused quantize GEMM computes A @ R (no transpose on the rotation matrix R). The dequant inverse must therefore apply R^T, not R. This was masked with the plain Hadamard (which is symmetric, H = H^T) but broke with the randomized Hadamard (R ≠ R^T).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 56eac41 commit 53fec13
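
The symmetry argument in the commit message can be checked numerically. The following is a minimal sketch in dependency-free Python, not the library's code: the helper names (`hadamard`, `matmul`, `transpose`, `max_abs_diff`) and the sign pattern are hypothetical, and the quantization step itself is not modeled, so only the rotation and its inverse are exercised.

```python
import random

N = 16


def hadamard(n):
    """Normalized n x n Hadamard matrix via Sylvester construction (n a power of 2)."""
    h = [[1.0]]
    while len(h) < n:
        # H_2k = [[H_k, H_k], [H_k, -H_k]]
        h = [row + row for row in h] + [row + [-v for v in row] for row in h]
    scale = n ** -0.5
    return [[v * scale for v in row] for row in h]


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def max_abs_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


H = hadamard(N)

# Hypothetical sign pattern (the real one comes from the library's fixed seed).
# The first two signs are forced to differ so that R is guaranteed non-symmetric.
rng = random.Random(0)
signs = [rng.choice((-1.0, 1.0)) for _ in range(N)]
signs[0], signs[1] = 1.0, -1.0

# R = H @ diag(signs): column j of H is scaled by signs[j].
R = [[H[i][j] * signs[j] for j in range(N)] for i in range(N)]

# Test input: one one-hot row plus a few random rows, each a 16-wide block.
x = [[1.0] + [0.0] * (N - 1)]
x += [[rng.uniform(-1.0, 1.0) for _ in range(N)] for _ in range(3)]

# Plain Hadamard: H is symmetric and orthogonal, so applying H twice is the identity.
err_plain = max_abs_diff(matmul(matmul(x, H), H), x)

# Randomized Hadamard: the forward pass is x @ R.
rotated = matmul(x, R)
err_wrong = max_abs_diff(matmul(rotated, R), x)             # old inverse: @ R
err_right = max_abs_diff(matmul(rotated, transpose(R)), x)  # fixed inverse: @ R^T

print("plain Hadamard, inverse via H:  ", err_plain)
print("randomized, inverse via R:      ", err_wrong)
print("randomized, inverse via R^T:    ", err_right)
```

The first and third errors should sit at floating-point noise, while the second is clearly nonzero, which is the failure the commit describes.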

2 files changed: +8 -7 lines changed

bitsandbytes/backends/cuda/ops.py

Lines changed: 4 additions & 3 deletions
@@ -915,9 +915,10 @@ def _(A: torch.Tensor, tensor_scale: Optional[float] = None) -> tuple[torch.Tens
 def _get_rotation_matrix(device: torch.device) -> torch.Tensor:
     """Get cached 16x16 randomized Hadamard matrix for fused quantize.
 
-    Builds H * D where H is the 16x16 normalized Hadamard matrix and D is a
-    diagonal sign-flip matrix (±1 per column) from a fixed seed. The same
-    matrix must be used for both weight and activation quantization.
+    Builds R = H * D where H is the 16x16 normalized Hadamard matrix and D is
+    a diagonal sign-flip matrix (±1 per column) from a fixed seed. The CUTLASS
+    GEMM computes ``A @ R`` (no transpose), so dequant must apply ``@ R^T``.
+    The same matrix must be used for both weight and activation quantization.
     """
     if device not in _rotation_matrices:
         # Build normalized 16x16 Hadamard via Sylvester construction

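The docstring's two claims, that `R^T` undoes the rotation and that weights and activations must share one matrix, both follow from orthogonality. As a sketch for a single 16-wide block, assuming the downstream GEMM forms the product of the rotated activations with the transposed rotated weights (an assumption; the diff does not show the GEMM itself), and with `R_1`, `R_2` as hypothetical mismatched rotations:

```latex
\begin{aligned}
R R^{\top} &= H D D^{\top} H^{\top} = H H^{\top} = I
  && \text{since } D^{2} = I \text{ and } H \text{ is orthogonal}, \\
(A R)(W R)^{\top} &= A \, R R^{\top} W^{\top} = A W^{\top}
  && \text{same } R \text{ on both operands: the rotation cancels}, \\
(A R_1)(W R_2)^{\top} &= A \, R_1 R_2^{\top} W^{\top} \neq A W^{\top}
  && \text{in general when } R_1 \neq R_2 .
\end{aligned}
```
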
bitsandbytes/functional.py

Lines changed: 4 additions & 4 deletions
@@ -1237,12 +1237,12 @@ def dequantize_nvfp4(
     )
 
     if quant_state.rotated:
-        # Undo rotation: data was quantized as x @ B^T, so recover x = out @ B.
-        # B is the cached randomized Hadamard matrix (orthogonal, so B^T·B = I).
+        # Undo rotation: the CUTLASS GEMM computes x @ R (no transpose on R),
+        # so dequant gives approx x @ R. To recover x, multiply by R^{-1} = R^T.
         from bitsandbytes.backends.cuda.ops import _get_rotation_matrix
 
-        B = _get_rotation_matrix(out.device)
-        out = (out.view(-1, 16) @ B).view(-1)
+        R = _get_rotation_matrix(out.device)
+        out = (out.view(-1, 16) @ R.T).view(-1)
 
     return out.reshape(quant_state.shape)
