|
| 1 | +"""Tests for NVFP4 (E2M1) quantization kernels. |
| 2 | +
|
| 3 | +Tests the NVFP4 quantize/dequantize, Hadamard rotation, and fused |
| 4 | +rotate+quantize kernels via ctypes calls to the C library. |
| 5 | +""" |
| 6 | + |
import ctypes
import functools
import os

import pytest
import torch
| 12 | + |
| 13 | + |
@functools.cache
def get_lib():
    """Load the bitsandbytes CUDA shared library (cached after the first call).

    Every quantize/dequantize helper in this module needs the library, so the
    handle is memoized to avoid re-running ``dlopen`` on each kernel call.

    Returns:
        ctypes.CDLL: Handle to the loaded shared library.

    Raises:
        RuntimeError: If no matching ``.so`` file is found.
    """
    # tests/<this file> -> repo root -> bitsandbytes/ (where the .so is built).
    lib_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "bitsandbytes")
    # Try cuda131 first (built from nvcc 13.1), fall back to cuda130.
    for suffix in ["cuda131", "cuda130"]:
        lib_path = os.path.join(lib_dir, f"libbitsandbytes_{suffix}.so")
        if os.path.exists(lib_path):
            return ctypes.cdll.LoadLibrary(lib_path)
    raise RuntimeError(f"Could not find bitsandbytes CUDA library in {lib_dir}")
| 23 | + |
| 24 | + |
def quantize_nvfp4(x, tensor_scale=None):
    """Quantize a FP16/BF16/FP32 tensor to NVFP4 using the C kernel.

    Args:
        x: CUDA tensor to quantize. NOTE(review): the raw ``data_ptr()`` is
            handed to the kernel, so ``x`` presumably must be contiguous --
            confirm against the kernel contract.
        tensor_scale: Optional global (second-level) scale. Defaults to the
            max absolute value of ``x``.

    Returns:
        Tuple ``(packed, block_scales, tensor_scale)``: ``packed`` holds two
        E2M1 codes per uint8 byte, ``block_scales`` one uint8 scale per
        16-element block.

    Raises:
        ValueError: If ``x.numel()`` is not divisible by 16 or the dtype is
            not FP16/BF16/FP32.
    """
    lib = get_lib()
    n = x.numel()
    # Raise instead of assert so validation survives `python -O`.
    if n % 16 != 0:
        raise ValueError("NVFP4 requires tensor size divisible by 16")

    if tensor_scale is None:
        tensor_scale = x.abs().max().item()

    # Two 4-bit codes per byte; one block scale per 16 elements.
    packed = torch.zeros(n // 2, dtype=torch.uint8, device=x.device)
    block_scales = torch.zeros(n // 16, dtype=torch.uint8, device=x.device)

    dispatch = {
        torch.float16: "cquantize_nvfp4_fp16",
        torch.bfloat16: "cquantize_nvfp4_bf16",
        torch.float32: "cquantize_nvfp4_fp32",
    }
    if x.dtype not in dispatch:
        raise ValueError(f"Unsupported dtype: {x.dtype}")
    func = getattr(lib, dispatch[x.dtype])

    func(
        ctypes.c_void_p(x.data_ptr()),
        ctypes.c_void_p(packed.data_ptr()),
        ctypes.c_void_p(block_scales.data_ptr()),
        ctypes.c_float(tensor_scale),
        ctypes.c_int(n),
    )
    torch.cuda.synchronize()
    return packed, block_scales, tensor_scale
| 55 | + |
| 56 | + |
def dequantize_nvfp4(packed, block_scales, tensor_scale, n, dtype=torch.float16):
    """Dequantize NVFP4 packed data back to FP16/BF16/FP32."""
    lib = get_lib()
    out = torch.zeros(n, dtype=dtype, device=packed.device)

    # Pick the kernel entry point matching the requested output dtype.
    kernel_names = {
        torch.float16: "cdequantize_nvfp4_fp16",
        torch.bfloat16: "cdequantize_nvfp4_bf16",
        torch.float32: "cdequantize_nvfp4_fp32",
    }
    if dtype not in kernel_names:
        raise ValueError(f"Unsupported dtype: {dtype}")
    kernel = getattr(lib, kernel_names[dtype])

    kernel(
        ctypes.c_void_p(packed.data_ptr()),
        ctypes.c_void_p(block_scales.data_ptr()),
        ctypes.c_float(tensor_scale),
        ctypes.c_void_p(out.data_ptr()),
        ctypes.c_int(n),
        ctypes.c_void_p(0),  # default stream
    )
    torch.cuda.synchronize()
    return out
| 81 | + |
| 82 | + |
def hadamard_rotate16(x):
    """Apply a block-diagonal Had16 rotation to ``x`` in place.

    Args:
        x: CUDA tensor of FP16/BF16/FP32 values; modified in place.
            NOTE(review): the raw ``data_ptr()`` is passed to the kernel, so
            ``x`` presumably must be contiguous -- confirm.

    Raises:
        ValueError: If ``x.numel()`` is not divisible by 16 or the dtype is
            not FP16/BF16/FP32.
    """
    lib = get_lib()
    n = x.numel()
    # Raise instead of assert so validation survives `python -O`.
    if n % 16 != 0:
        raise ValueError("Hadamard rotation requires size divisible by 16")

    dispatch = {
        torch.float16: "chadamard_rotate16_fp16",
        torch.bfloat16: "chadamard_rotate16_bf16",
        torch.float32: "chadamard_rotate16_fp32",
    }
    if x.dtype not in dispatch:
        raise ValueError(f"Unsupported dtype: {x.dtype}")
    func = getattr(lib, dispatch[x.dtype])

    func(ctypes.c_void_p(x.data_ptr()), ctypes.c_int(n))
    torch.cuda.synchronize()
| 100 | + |
| 101 | + |
def fused_hadamard_quantize_nvfp4(x, tensor_scale=None):
    """Fused Hadamard rotation + NVFP4 quantization.

    Args:
        x: CUDA tensor to rotate and quantize. ``x`` itself is not modified;
            the kernel reads it and writes the packed outputs.
        tensor_scale: Optional global scale. When omitted it is computed from
            the *rotated* data (rotation changes the max-abs value), so a
            rotated copy is made first just to measure the scale.

    Returns:
        Tuple ``(packed, block_scales, tensor_scale)`` -- same layout as
        ``quantize_nvfp4``.

    Raises:
        ValueError: If ``x.numel()`` is not divisible by 16 or the dtype is
            not FP16/BF16/FP32.
    """
    lib = get_lib()
    n = x.numel()
    # Raise (with a message) instead of a bare assert: survives `python -O`.
    if n % 16 != 0:
        raise ValueError("NVFP4 requires tensor size divisible by 16")

    if tensor_scale is None:
        # The scale must reflect the rotated data, so rotate a throwaway copy.
        x_copy = x.clone()
        hadamard_rotate16(x_copy)
        tensor_scale = x_copy.abs().max().item()

    packed = torch.zeros(n // 2, dtype=torch.uint8, device=x.device)
    block_scales = torch.zeros(n // 16, dtype=torch.uint8, device=x.device)

    dispatch = {
        torch.float16: "cfused_hadamard_quantize_nvfp4_fp16",
        torch.bfloat16: "cfused_hadamard_quantize_nvfp4_bf16",
        torch.float32: "cfused_hadamard_quantize_nvfp4_fp32",
    }
    if x.dtype not in dispatch:
        raise ValueError(f"Unsupported dtype: {x.dtype}")
    func = getattr(lib, dispatch[x.dtype])

    func(
        ctypes.c_void_p(x.data_ptr()),
        ctypes.c_void_p(packed.data_ptr()),
        ctypes.c_void_p(block_scales.data_ptr()),
        ctypes.c_float(tensor_scale),
        ctypes.c_int(n),
    )
    torch.cuda.synchronize()
    return packed, block_scales, tensor_scale
| 136 | + |
| 137 | + |
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestNVFP4Encoding:
    """Test the E2M1 encoding table and basic quantization."""

    def test_nvfp4_encoding_table(self):
        """Verify all 16 E2M1 codes produce correct values via round-trip."""
        # E2M1 representable magnitudes, positive then negated (incl. -0.0).
        magnitudes = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
        test_vals = magnitudes + [-v for v in magnitudes]
        x = torch.tensor(test_vals, dtype=torch.float16, device="cuda")

        # tensor_scale = 1.0 so block_scale = max(6)/6 = 1.0 (exactly E4M3)
        packed, scales, ts = quantize_nvfp4(x, tensor_scale=1.0)
        decoded = dequantize_nvfp4(packed, scales, ts, len(test_vals)).tolist()

        for i, (inp, out) in enumerate(zip(test_vals, decoded)):
            assert abs(inp - out) < 0.01, f"E2M1 code {i}: expected {inp}, got {out}"

    def test_nvfp4_round_trip_error(self):
        """Verify round-trip error is within expected E2M1 bounds."""
        torch.manual_seed(42)
        n = 1024 * 16  # Multiple of 16
        x = torch.randn(n, dtype=torch.float16, device="cuda")

        packed, scales, ts = quantize_nvfp4(x)
        y = dequantize_nvfp4(packed, scales, ts, n)

        # E2M1 with blocksize 16 on standard normal data should land at a
        # mean absolute error of roughly 0.05-0.10.
        mean_err = torch.abs(x.float() - y.float()).mean().item()
        assert mean_err < 0.15, f"Mean abs error {mean_err:.4f} exceeds bound 0.15"
        assert mean_err > 0.01, f"Mean abs error {mean_err:.4f} suspiciously low"

    def test_nvfp4_two_level_scaling(self):
        """Verify tensor scale + block scale correctly recovers large values."""
        torch.manual_seed(42)
        n = 256
        # Values well outside the native E2M1 range of [-6, 6].
        x = torch.randn(n, dtype=torch.float16, device="cuda") * 100.0

        packed, scales, ts = quantize_nvfp4(x)
        y = dequantize_nvfp4(packed, scales, ts, n)

        # The round-trip must preserve the overall magnitude of the data.
        assert y.abs().max().item() > 50.0, "Two-level scaling failed to preserve large magnitudes"

        # On the clearly-large entries the relative error should be bounded.
        mask = x.abs() > 10.0
        if mask.any():
            diff = (x[mask].float() - y[mask].float()).abs()
            rel_err = (diff / x[mask].abs().float()).mean().item()
            assert rel_err < 0.5, f"Relative error on large values: {rel_err:.4f}"

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
    def test_nvfp4_dtypes(self, dtype):
        """Verify quantization works for FP16 and BF16."""
        torch.manual_seed(42)
        n = 1024
        x = torch.randn(n, dtype=dtype, device="cuda")

        packed, scales, ts = quantize_nvfp4(x)
        y = dequantize_nvfp4(packed, scales, ts, n, dtype=dtype)

        assert y.dtype == dtype
        assert (x.float() - y.float()).abs().mean().item() < 0.15
| 203 | + |
| 204 | + |
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestHadamardRotation:
    """Test the block-diagonal Had16 rotation kernel."""

    def test_hadamard_orthogonality(self):
        """Applying Hadamard twice should return the original (H*H^T = I)."""
        torch.manual_seed(42)
        x = torch.randn(1024, dtype=torch.float16, device="cuda")
        reference = x.clone()

        # Had16 is symmetric and orthogonal, so it is its own inverse.
        hadamard_rotate16(x)
        hadamard_rotate16(x)

        err = (x.float() - reference.float()).abs().max().item()
        assert err < 0.01, f"Double rotation max error {err:.6f} exceeds FP16 tolerance"

    def test_hadamard_reduces_kurtosis(self):
        """Hadamard rotation should make Laplace-distributed data more Gaussian."""
        torch.manual_seed(123)
        n = 4096

        # Difference of two unit exponentials is Laplace (kurtosis ~6).
        e1 = torch.empty(n, device="cuda").exponential_(1.0)
        e2 = torch.empty(n, device="cuda").exponential_(1.0)
        lap = (e1 - e2).half()

        def kurtosis(t):
            centered = t.float() - t.float().mean()
            return (centered**4).mean() / (centered**2).mean() ** 2

        kurt_before = kurtosis(lap).item()

        rotated = lap.clone()
        hadamard_rotate16(rotated)
        kurt_after = kurtosis(rotated).item()

        assert kurt_after < kurt_before, f"Kurtosis increased: {kurt_before:.2f} -> {kurt_after:.2f}"
        # Rotated data should be close to Gaussian (kurtosis 3).
        assert kurt_after < 4.0, f"Post-rotation kurtosis {kurt_after:.2f} too high (expected < 4.0)"

    def test_hadamard_preserves_norm(self):
        """Hadamard rotation should preserve L2 norm (orthogonal transform)."""
        torch.manual_seed(42)
        x = torch.randn(1024, dtype=torch.float32, device="cuda")
        norm_before = x.norm().item()

        hadamard_rotate16(x)
        norm_after = x.norm().item()

        rel_err = abs(norm_before - norm_after) / norm_before
        assert rel_err < 0.001, f"Norm changed by {rel_err:.6f}"
| 260 | + |
| 261 | + |
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestFusedHadamardQuantize:
    """Test the fused Had16 + NVFP4 quantize kernel."""

    def test_fused_matches_sequential(self):
        """Fused kernel output should match sequential rotate+quantize."""
        torch.manual_seed(42)
        x = torch.randn(1024, dtype=torch.float16, device="cuda")

        # Reference path: rotate a copy in place, then quantize it.
        rotated = x.clone()
        hadamard_rotate16(rotated)
        ts = rotated.abs().max().item()
        packed_seq, scales_seq, _ = quantize_nvfp4(rotated, tensor_scale=ts)

        # Fused path: one kernel performs rotation and quantization together.
        packed_fused, scales_fused, _ = fused_hadamard_quantize_nvfp4(x.clone(), tensor_scale=ts)

        assert torch.equal(packed_seq, packed_fused), "Packed data mismatch"
        assert torch.equal(scales_seq, scales_fused), "Block scales mismatch"

    def test_fused_reduces_quantization_error(self):
        """Rotation before quantization should reduce error on Laplace data."""
        torch.manual_seed(42)
        n = 4096

        # Outlier-heavy Laplace data: difference of two unit exponentials.
        e1 = torch.empty(n, device="cuda").exponential_(1.0)
        e2 = torch.empty(n, device="cuda").exponential_(1.0)
        x = (e1 - e2).half()

        def roundtrip_mae(packed, scales, ts, unrotate):
            # Dequantize and (optionally) undo the rotation so the comparison
            # against the original x is apples-to-apples.
            y = dequantize_nvfp4(packed, scales, ts, n)
            if unrotate:
                hadamard_rotate16(y)  # Had16 is its own inverse
            return (x.float() - y.float()).abs().mean().item()

        err_no_rot = roundtrip_mae(*quantize_nvfp4(x), unrotate=False)
        err_rot = roundtrip_mae(*fused_hadamard_quantize_nvfp4(x), unrotate=True)

        assert err_rot < err_no_rot, f"Rotation error {err_rot:.4f} >= no-rotation error {err_no_rot:.4f}"
0 commit comments