Skip to content

Commit cbe89a9

Browse files
TimDettmers and claude committed
test: Add NVFP4 quantization kernel test suite
Tests E2M1 encoding table, round-trip error bounds, two-level scaling, Hadamard orthogonality/norm preservation/kurtosis reduction, and fused rotate+quantize correctness. Uses ctypes to call C kernels directly. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent dd6b88c commit cbe89a9

File tree

1 file changed

+307
-0
lines changed

1 file changed

+307
-0
lines changed

tests/test_nvfp4.py

Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
"""Tests for NVFP4 (E2M1) quantization kernels.
2+
3+
Tests the NVFP4 quantize/dequantize, Hadamard rotation, and fused
4+
rotate+quantize kernels via ctypes calls to the C library.
5+
"""
6+
7+
import ctypes
8+
import os
9+
10+
import pytest
11+
import torch
12+
13+
14+
def get_lib():
    """Load the bitsandbytes CUDA library, caching the loaded handle.

    Looks for a CUDA build of the library inside the ``bitsandbytes``
    package directory that sits next to this tests directory. The handle
    is memoized on the function object so the many helper calls in this
    module do not re-dlopen the shared object each time.

    Returns:
        ctypes.CDLL: handle to the loaded library.

    Raises:
        RuntimeError: if no known library build is found on disk.
    """
    cached = getattr(get_lib, "_lib", None)
    if cached is not None:
        return cached
    lib_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "bitsandbytes")
    # Try cuda131 first (built from nvcc 13.1), fall back to cuda130
    for suffix in ["cuda131", "cuda130"]:
        lib_path = os.path.join(lib_dir, f"libbitsandbytes_{suffix}.so")
        if os.path.exists(lib_path):
            get_lib._lib = ctypes.cdll.LoadLibrary(lib_path)
            return get_lib._lib
    raise RuntimeError(f"Could not find bitsandbytes CUDA library in {lib_dir}")
23+
24+
25+
def quantize_nvfp4(x, tensor_scale=None):
    """Quantize a FP16/BF16/FP32 tensor to NVFP4 using the C kernel.

    Args:
        x: CUDA tensor whose numel is divisible by 16 (the NVFP4 block size).
        tensor_scale: optional global scale; defaults to ``x.abs().max()``.

    Returns:
        Tuple ``(packed, block_scales, tensor_scale)``: packed uint8 codes
        (two E2M1 nibbles per byte), one uint8 scale per 16-element block,
        and the tensor scale actually used.

    Raises:
        ValueError: if the size is not divisible by 16 or the dtype is
            not fp16/bf16/fp32.
    """
    n = x.numel()
    # Validate inputs before touching the C library. The original used
    # `assert`, which is silently stripped under `python -O`.
    if n % 16 != 0:
        raise ValueError("NVFP4 requires tensor size divisible by 16")

    # Table-driven dispatch keeps the three dtype branches in one place.
    kernel_names = {
        torch.float16: "cquantize_nvfp4_fp16",
        torch.bfloat16: "cquantize_nvfp4_bf16",
        torch.float32: "cquantize_nvfp4_fp32",
    }
    if x.dtype not in kernel_names:
        raise ValueError(f"Unsupported dtype: {x.dtype}")

    if tensor_scale is None:
        tensor_scale = x.abs().max().item()

    packed = torch.zeros(n // 2, dtype=torch.uint8, device=x.device)
    block_scales = torch.zeros(n // 16, dtype=torch.uint8, device=x.device)

    lib = get_lib()
    func = getattr(lib, kernel_names[x.dtype])
    func(
        ctypes.c_void_p(x.data_ptr()),
        ctypes.c_void_p(packed.data_ptr()),
        ctypes.c_void_p(block_scales.data_ptr()),
        ctypes.c_float(tensor_scale),
        ctypes.c_int(n),
    )
    # The kernel launch is asynchronous; wait before handing buffers back.
    torch.cuda.synchronize()
    return packed, block_scales, tensor_scale
55+
56+
57+
def dequantize_nvfp4(packed, block_scales, tensor_scale, n, dtype=torch.float16):
    """Dequantize NVFP4 packed data back to FP16/BF16/FP32."""
    lib = get_lib()
    output = torch.zeros(n, dtype=dtype, device=packed.device)

    # Pick the C entry point matching the requested output dtype.
    dispatch = {
        torch.float16: "cdequantize_nvfp4_fp16",
        torch.bfloat16: "cdequantize_nvfp4_bf16",
        torch.float32: "cdequantize_nvfp4_fp32",
    }
    if dtype not in dispatch:
        raise ValueError(f"Unsupported dtype: {dtype}")
    kernel = getattr(lib, dispatch[dtype])

    # NOTE(review): unlike the quantize wrappers, this C function takes a
    # trailing stream argument; NULL selects the default stream.
    kernel(
        ctypes.c_void_p(packed.data_ptr()),
        ctypes.c_void_p(block_scales.data_ptr()),
        ctypes.c_float(tensor_scale),
        ctypes.c_void_p(output.data_ptr()),
        ctypes.c_int(n),
        ctypes.c_void_p(0),  # default stream
    )
    torch.cuda.synchronize()
    return output
81+
82+
83+
def hadamard_rotate16(x):
    """Apply the block-diagonal Had16 rotation to ``x`` in-place.

    The C kernel transforms each consecutive group of 16 elements; the
    test suite relies on the transform being its own inverse (applying
    it twice recovers the original data).

    Raises:
        ValueError: if numel is not divisible by 16 or the dtype is
            not fp16/bf16/fp32.
    """
    n = x.numel()
    # Validate before touching the C library. The original used `assert`,
    # which is stripped under `python -O`.
    if n % 16 != 0:
        raise ValueError("Hadamard rotation requires size divisible by 16")

    kernel_names = {
        torch.float16: "chadamard_rotate16_fp16",
        torch.bfloat16: "chadamard_rotate16_bf16",
        torch.float32: "chadamard_rotate16_fp32",
    }
    if x.dtype not in kernel_names:
        raise ValueError(f"Unsupported dtype: {x.dtype}")

    lib = get_lib()
    func = getattr(lib, kernel_names[x.dtype])
    func(ctypes.c_void_p(x.data_ptr()), ctypes.c_int(n))
    # Kernel launch is asynchronous; block until the in-place update lands.
    torch.cuda.synchronize()
100+
101+
102+
def fused_hadamard_quantize_nvfp4(x, tensor_scale=None):
    """Fused Hadamard rotation + NVFP4 quantization via a single C kernel.

    Args:
        x: CUDA tensor whose numel is divisible by 16.
        tensor_scale: optional global scale. When ``None``, the scale is
            computed from a rotated *copy* of ``x``, since the global max
            must reflect the post-rotation distribution.

    Returns:
        Tuple ``(packed, block_scales, tensor_scale)`` — same layout as
        :func:`quantize_nvfp4`.

    Raises:
        ValueError: if the size is not divisible by 16 or the dtype is
            not fp16/bf16/fp32.
    """
    n = x.numel()
    # Validate inputs up front (the original used a bare `assert`, which
    # is stripped under `python -O` and carried no message).
    if n % 16 != 0:
        raise ValueError("NVFP4 requires tensor size divisible by 16")

    kernel_names = {
        torch.float16: "cfused_hadamard_quantize_nvfp4_fp16",
        torch.bfloat16: "cfused_hadamard_quantize_nvfp4_bf16",
        torch.float32: "cfused_hadamard_quantize_nvfp4_fp32",
    }
    if x.dtype not in kernel_names:
        raise ValueError(f"Unsupported dtype: {x.dtype}")

    if tensor_scale is None:
        # Need the tensor scale of the *rotated* data, so rotate a
        # throwaway copy just to measure its max magnitude.
        x_copy = x.clone()
        hadamard_rotate16(x_copy)
        tensor_scale = x_copy.abs().max().item()

    packed = torch.zeros(n // 2, dtype=torch.uint8, device=x.device)
    block_scales = torch.zeros(n // 16, dtype=torch.uint8, device=x.device)

    lib = get_lib()
    func = getattr(lib, kernel_names[x.dtype])
    func(
        ctypes.c_void_p(x.data_ptr()),
        ctypes.c_void_p(packed.data_ptr()),
        ctypes.c_void_p(block_scales.data_ptr()),
        ctypes.c_float(tensor_scale),
        ctypes.c_int(n),
    )
    torch.cuda.synchronize()
    return packed, block_scales, tensor_scale
136+
137+
138+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestNVFP4Encoding:
    """Test the E2M1 encoding table and basic quantization."""

    def test_nvfp4_encoding_table(self):
        """Verify all 16 E2M1 codes produce correct values via round-trip."""
        # E2M1 representable magnitudes: {0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0}
        magnitudes = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
        test_vals = magnitudes + [-m for m in magnitudes]
        x = torch.tensor(test_vals, dtype=torch.float16, device="cuda")

        # tensor_scale = 1.0 so block_scale = max(6)/6 = 1.0 (exactly E4M3)
        packed, scales, ts = quantize_nvfp4(x, tensor_scale=1.0)
        decoded = dequantize_nvfp4(packed, scales, ts, len(test_vals))

        for i, (inp, out) in enumerate(zip(test_vals, decoded.tolist())):
            assert abs(inp - out) < 0.01, f"E2M1 code {i}: expected {inp}, got {out}"

    def test_nvfp4_round_trip_error(self):
        """Verify round-trip error is within expected E2M1 bounds."""
        torch.manual_seed(42)
        count = 1024 * 16  # Multiple of 16
        data = torch.randn(count, dtype=torch.float16, device="cuda")

        packed, scales, ts = quantize_nvfp4(data)
        restored = dequantize_nvfp4(packed, scales, ts, count)

        mean_err = (data.float() - restored.float()).abs().mean().item()
        # E2M1 with blocksize 16 on standard normal data should have
        # mean absolute error roughly 0.05-0.10
        assert mean_err < 0.15, f"Mean abs error {mean_err:.4f} exceeds bound 0.15"
        assert mean_err > 0.01, f"Mean abs error {mean_err:.4f} suspiciously low"

    def test_nvfp4_two_level_scaling(self):
        """Verify tensor scale + block scale correctly recovers large values."""
        # Create data with values outside [-6, 6]
        torch.manual_seed(42)
        count = 256
        big = torch.randn(count, dtype=torch.float16, device="cuda") * 100.0

        packed, scales, ts = quantize_nvfp4(big)
        restored = dequantize_nvfp4(packed, scales, ts, count)

        # Output should have roughly the same range as input
        assert restored.abs().max().item() > 50.0, "Two-level scaling failed to preserve large magnitudes"

        # Relative error should be bounded
        large = big.abs() > 10.0
        if large.sum() > 0:
            diffs = (big[large].float() - restored[large].float()).abs()
            rel_err = (diffs / big[large].abs().float()).mean().item()
            assert rel_err < 0.5, f"Relative error on large values: {rel_err:.4f}"

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
    def test_nvfp4_dtypes(self, dtype):
        """Verify quantization works for FP16 and BF16."""
        torch.manual_seed(42)
        count = 1024
        data = torch.randn(count, dtype=dtype, device="cuda")

        packed, scales, ts = quantize_nvfp4(data)
        restored = dequantize_nvfp4(packed, scales, ts, count, dtype=dtype)

        assert restored.dtype == dtype
        assert (data.float() - restored.float()).abs().mean().item() < 0.15
203+
204+
205+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestHadamardRotation:
    """Test the block-diagonal Had16 rotation kernel."""

    def test_hadamard_orthogonality(self):
        """Applying Hadamard twice should return the original (H*H^T = I)."""
        torch.manual_seed(42)
        count = 1024
        data = torch.randn(count, dtype=torch.float16, device="cuda")
        reference = data.clone()

        hadamard_rotate16(data)
        hadamard_rotate16(data)

        max_dev = (data.float() - reference.float()).abs().max().item()
        assert max_dev < 0.01, f"Double rotation max error {max_dev:.6f} exceeds FP16 tolerance"

    def test_hadamard_reduces_kurtosis(self):
        """Hadamard rotation should make Laplace-distributed data more Gaussian."""
        torch.manual_seed(123)
        count = 4096

        # Generate Laplace distribution (kurtosis ~6)
        a = torch.empty(count, device="cuda").exponential_(1.0)
        b = torch.empty(count, device="cuda").exponential_(1.0)
        lap = (a - b).half()

        def kurtosis(t):
            # Fourth central moment over squared variance.
            centered = t.float() - t.float().mean()
            return (centered ** 4).mean() / (centered ** 2).mean() ** 2

        kurt_before = kurtosis(lap).item()

        rotated = lap.clone()
        hadamard_rotate16(rotated)
        kurt_after = kurtosis(rotated).item()

        assert kurt_after < kurt_before, f"Kurtosis increased: {kurt_before:.2f} -> {kurt_after:.2f}"
        # After rotation, kurtosis should be closer to 3 (Gaussian)
        assert kurt_after < 4.0, f"Post-rotation kurtosis {kurt_after:.2f} too high (expected < 4.0)"

    def test_hadamard_preserves_norm(self):
        """Hadamard rotation should preserve L2 norm (orthogonal transform)."""
        torch.manual_seed(42)
        count = 1024
        data = torch.randn(count, dtype=torch.float32, device="cuda")
        norm_before = data.norm().item()

        hadamard_rotate16(data)

        rel_err = abs(norm_before - data.norm().item()) / norm_before
        assert rel_err < 0.001, f"Norm changed by {rel_err:.6f}"
260+
261+
262+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestFusedHadamardQuantize:
    """Test the fused Had16 + NVFP4 quantize kernel."""

    def test_fused_matches_sequential(self):
        """Fused kernel output should match sequential rotate+quantize."""
        torch.manual_seed(42)
        count = 1024
        data = torch.randn(count, dtype=torch.float16, device="cuda")

        # Sequential: rotate, then quantize
        rotated = data.clone()
        hadamard_rotate16(rotated)
        scale = rotated.abs().max().item()
        packed_seq, scales_seq, _ = quantize_nvfp4(rotated, tensor_scale=scale)

        # Fused: single kernel
        packed_fused, scales_fused, _ = fused_hadamard_quantize_nvfp4(data.clone(), tensor_scale=scale)

        assert torch.equal(packed_seq, packed_fused), "Packed data mismatch"
        assert torch.equal(scales_seq, scales_fused), "Block scales mismatch"

    def test_fused_reduces_quantization_error(self):
        """Rotation before quantization should reduce error on Laplace data."""
        torch.manual_seed(42)
        count = 4096

        # Laplace-distributed data (outlier-heavy)
        a = torch.empty(count, device="cuda").exponential_(1.0)
        b = torch.empty(count, device="cuda").exponential_(1.0)
        data = (a - b).half()

        # Without rotation
        packed_plain, scales_plain, ts_plain = quantize_nvfp4(data)
        plain = dequantize_nvfp4(packed_plain, scales_plain, ts_plain, count)
        err_no_rot = (data.float() - plain.float()).abs().mean().item()

        # With rotation (fused)
        packed_rot, scales_rot, ts_rot = fused_hadamard_quantize_nvfp4(data)
        restored = dequantize_nvfp4(packed_rot, scales_rot, ts_rot, count)
        # Need to apply rotation to the dequantized output for fair comparison
        hadamard_rotate16(restored)  # Inverse rotation
        err_rot = (data.float() - restored.float()).abs().mean().item()

        # Rotation should reduce error on Laplace data
        assert err_rot < err_no_rot, f"Rotation error {err_rot:.4f} >= no-rotation error {err_no_rot:.4f}"

0 commit comments

Comments
 (0)