@@ -1159,15 +1159,16 @@ def _has_cutlass_fused_quantize() -> bool:
11591159def quantize_nvfp4 (
11601160 A : torch .Tensor ,
11611161 tensor_scale : Optional [float ] = None ,
1162- rotate : bool = True ,
11631162) -> tuple [torch .Tensor , NVFP4QuantState ]:
1164- """Quantize a tensor to NVFP4 (E2M1) format.
1163+ """Quantize a tensor to NVFP4 (E2M1) format with Hadamard rotation.
1164+
1165+ Always applies a randomized 16x16 Hadamard rotation before quantization.
1166+ When CUTLASS is available (SM_120+), the rotation is fused into the
1167+ quantize kernel at zero cost. Otherwise, falls back to hand-written kernels.
11651168
11661169 Args:
11671170 A: Input tensor (float16, bfloat16, or float32). Must have numel divisible by 16.
11681171 tensor_scale: Optional pre-computed tensor scale. If None, computed as max(abs(A)).
1169- rotate: If True, apply Hadamard rotation before quantization (fused kernel).
1170- Default is True since the CUTLASS fused quantize includes rotation for free.
11711172
11721173 Returns:
11731174 Tuple of (packed_data, NVFP4QuantState).
@@ -1176,36 +1177,21 @@ def quantize_nvfp4(
11761177 input_dtype = A .dtype
11771178 A_flat = A .reshape (- 1 ).contiguous ()
11781179
1179- # Use CUTLASS fused quantize when available (7-9x faster)
1180+ # Use CUTLASS fused quantize when available (7-9x faster, rotation is free)
11801181 use_cutlass = _has_cutlass_fused_quantize () and A .is_cuda
11811182 if use_cutlass :
11821183 # CUTLASS fused quantize requires BF16 input
1183- if A_flat .dtype != torch .bfloat16 :
1184- A_bf16 = A_flat .to (torch .bfloat16 )
1185- else :
1186- A_bf16 = A_flat
1184+ A_bf16 = A_flat .to (torch .bfloat16 ) if A_flat .dtype != torch .bfloat16 else A_flat
11871185
1188- # Compute tensor_scale if not provided
11891186 if tensor_scale is None :
1190- if rotate :
1191- # For rotation, scale should be computed on rotated data.
1192- # The CUTLASS kernel handles this internally, but we need the
1193- # tensor_scale for the quantize op. Compute on original data
1194- # as approximation — the block scales handle per-block normalization.
1195- tensor_scale = A_bf16 .abs ().max ().item ()
1196- else :
1197- tensor_scale = A_bf16 .abs ().max ().item ()
1198-
1199- from bitsandbytes .backends .cuda .ops import _get_fused_quant_matrix
1187+ tensor_scale = A_bf16 .abs ().max ().item ()
12001188
1201- B = _get_fused_quant_matrix (A .device , quest = rotate )
1202- packed , block_scales , ts = torch .ops .bitsandbytes .cutlass_fused_quantize_nvfp4 (A_bf16 , B , tensor_scale , rotate )
1203- elif rotate :
1204- if tensor_scale is None :
1205- tensor_scale = None # let the kernel compute it
1206- packed , block_scales , ts = torch .ops .bitsandbytes .fused_hadamard_quantize_nvfp4 (A_flat , tensor_scale )
1189+ packed , block_scales , ts = torch .ops .bitsandbytes .cutlass_fused_quantize_nvfp4 (A_bf16 , tensor_scale )
12071190 else :
1208- packed , block_scales , ts = torch .ops .bitsandbytes .quantize_nvfp4 (A_flat , tensor_scale )
1191+ # Fallback: hand-written fused Hadamard + NVFP4 quantize kernel.
1192+ # Note: uses plain (non-randomized) Had16, while dequantize_nvfp4 undoes the rotation with
1193+ # the randomized matrix, so the inverse is not exact; this path is only for non-SM_120+ development.
1194+ packed , block_scales , ts = torch .ops .bitsandbytes .fused_hadamard_quantize_nvfp4 (A_flat , tensor_scale )
12091195
12101196 # Pre-compute CUTLASS block-scaled layout for GEMM. The 2D scale shape is
12111197 # (rows, K//16) where rows is the product of all dims except the last.
@@ -1220,7 +1206,7 @@ def quantize_nvfp4(
12201206 tensor_scale = ts .item (),
12211207 shape = input_shape ,
12221208 dtype = input_dtype ,
1223- rotated = rotate ,
1209+ rotated = True ,
12241210 block_scales_blocked = block_scales_blocked ,
12251211 )
12261212 return packed , state
@@ -1251,8 +1237,12 @@ def dequantize_nvfp4(
12511237 )
12521238
12531239 if quant_state .rotated :
1254- # Apply inverse Hadamard rotation
1255- torch .ops .bitsandbytes .hadamard_rotate_nvfp4 (out )
1240+ # Undo rotation: data was quantized as x @ B^T, so recover x = out @ B.
1241+ # B is the cached randomized Hadamard matrix (orthogonal, so B^T·B = I).
1242+ from bitsandbytes .backends .cuda .ops import _get_rotation_matrix
1243+
1244+ B = _get_rotation_matrix (out .device )
1245+ out = (out .view (- 1 , 16 ) @ B ).view (- 1 )
12561246
12571247 return out .reshape (quant_state .shape )
12581248
0 commit comments