Skip to content

Commit 92dc4ee

Browse files
TimDettmers and claude
committed
test: Add NVFP4 GEMM kernel correctness tests
Tests for the simple GEMM kernel: identity scales (all 1s), multi-K-tile accumulation, and random data verification. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 5162511 commit 92dc4ee

File tree

1 file changed

+299
-0
lines changed

1 file changed

+299
-0
lines changed

tests/test_gemm_nvfp4.py

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
"""Test NVFP4 GEMM kernel on SM_120 (Blackwell consumer GPUs).
2+
3+
Tests the block-scaled mma.sync GEMM kernel via ctypes.
4+
"""
5+
6+
import ctypes
7+
import os
8+
9+
import pytest
10+
import torch
11+
12+
13+
def get_lib():
    """Locate and load the compiled bitsandbytes CUDA shared library.

    Searches the ``bitsandbytes`` package directory (sibling of this
    ``tests`` directory) for a CUDA 13.1 build first, then CUDA 13.0.

    Returns:
        A ``ctypes.CDLL`` handle to the loaded library.

    Raises:
        RuntimeError: if no matching shared library file exists.
    """
    tests_dir = os.path.dirname(os.path.abspath(__file__))
    lib_dir = os.path.join(os.path.dirname(tests_dir), "bitsandbytes")
    # Prefer the newer CUDA toolkit build when both are present.
    candidates = (
        os.path.join(lib_dir, f"libbitsandbytes_{suffix}.so")
        for suffix in ("cuda131", "cuda130")
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            return ctypes.cdll.LoadLibrary(candidate)
    raise RuntimeError(f"Could not find bitsandbytes CUDA library in {lib_dir}")
21+
22+
23+
# E2M1 representable magnitudes (unsigned)
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def float_to_e2m1(x):
    """Quantize a float to nearest E2M1 value (magnitude only)."""
    sign = 1 if x >= 0 else -1
    magnitude = abs(x)
    # Midpoints between consecutive E2M1 magnitudes; a value sitting
    # exactly on a midpoint rounds to the larger magnitude.
    midpoints = (0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0)
    idx = 0
    while idx < len(midpoints) and magnitude >= midpoints[idx]:
        idx += 1
    return E2M1_VALUES[idx] * sign
36+
37+
38+
def float_to_e4m3(x):
    """Quantize a positive float to UE4M3 (unsigned E4M3, bias=7).

    Returns:
        (code, actual): the 7-bit encoding ``(e_biased << 3) | mantissa``
        and the exact float that encoding decodes back to.

    Non-positive inputs map to ``(0, 0.0)``.  Inputs above the format
    maximum are clamped to 448.  Inputs below 2**-6 round up to the
    smallest normal value (this reference skips E4M3 subnormal codes,
    so the biased exponent is always >= 1).
    """
    import math

    if x <= 0:
        return 0, 0.0
    # Clamp to max representable value (~448)
    x = min(x, 448.0)
    # Exponent range with bias=7 is [-6, 8]
    e = max(-6, min(8, math.floor(math.log2(x))))
    # 3-bit mantissa: round the fractional part to the nearest eighth.
    m_int = round((x / (2.0 ** e) - 1.0) * 8)
    if m_int < 0:
        # x below 2**-6 (exponent was clamped up): saturate to smallest normal.
        m_int = 0
    elif m_int > 7:
        # Mantissa rounded up past 1.875: carry into the exponent so that
        # e.g. 1.99 encodes as 2.0 rather than being clamped to 1.875.
        # x <= 448 guarantees no carry can occur at e == 8, but clamp anyway.
        m_int = 0
        e = min(e + 1, 8)
    e_biased = e + 7  # always in [1, 15]: no subnormal encodings emitted
    code = (e_biased << 3) | m_int
    # Decode to get actual value
    actual = (1.0 + m_int / 8.0) * (2.0 ** (e_biased - 7))
    return code, actual
65+
66+
67+
def quantize_tensor_reference(x_flat):
    """Reference quantization: float tensor -> packed FP4 + block scales + tensor scale.

    Returns (packed_bytes, block_scale_bytes, tensor_scale) in the format
    expected by the GEMM kernel.
    """
    n = len(x_flat)
    assert n % 16 == 0

    # Global per-tensor scale is the overall absmax; guard all-zero input.
    tensor_scale = max(abs(v) for v in x_flat)
    if tensor_scale == 0:
        tensor_scale = 1.0

    packed_bytes = []
    scale_codes = []

    for start in range(0, n, 16):
        # Normalize this 16-element block by the tensor scale.
        normalized = [v / tensor_scale for v in x_flat[start:start + 16]]

        # Per-block scale maps the block absmax onto 6.0 (max E2M1 value).
        block_max = max(abs(v) for v in normalized) or 1e-10
        scale_code, scale_actual = float_to_e4m3(block_max / 6.0)
        scale_codes.append(scale_code)
        if scale_actual == 0:
            scale_actual = 1e-10

        # Quantize each element: sign in bit 3, magnitude index in bits 0-2.
        codes = []
        for v in normalized:
            qval = float_to_e2m1(v / scale_actual)
            mag = abs(qval)
            idx = E2M1_VALUES.index(mag) if mag in E2M1_VALUES else 0
            codes.append(idx | (0x8 if qval < 0 else 0))

        # Pack two 4-bit codes per byte: even index -> low nibble, odd -> high.
        for lo, hi in zip(codes[0::2], codes[1::2]):
            packed_bytes.append((lo & 0xF) | ((hi & 0xF) << 4))

    return packed_bytes, scale_codes, tensor_scale
120+
121+
122+
def dequantize_reference(packed, block_scales, tensor_scale, M, K):
    """Reference dequantization for verification."""
    values = []
    for i in range(M * K):
        # Unpack the 4-bit code: even element sits in the low nibble,
        # odd element in the high nibble of the same byte.
        shift = 4 * (i % 2)
        code = (packed[i // 2] >> shift) & 0xF

        sign = -1.0 if code & 0x8 else 1.0
        mag = E2M1_VALUES[code & 0x7]

        # Decode the UE4M3 block scale (bias 7; exponent 0 is subnormal).
        sf_code = block_scales[i // 16]
        sf_exp = (sf_code >> 3) & 0xF
        sf_man = sf_code & 0x7
        if sf_exp == 0:
            sf_val = sf_man / 8.0 * (2.0 ** -6)
        else:
            sf_val = (1.0 + sf_man / 8.0) * (2.0 ** (sf_exp - 7))

        values.append(sign * mag * sf_val * tensor_scale)
    return values
150+
151+
152+
def prepare_gemm_inputs(M, N, K, seed=42):
    """Create random FP4-quantized inputs for GEMM testing.

    Returns CUDA tensors ready for the GEMM kernel, plus reference
    dequantized matrices for verification.
    """
    import random
    random.seed(seed)

    # Random Gaussian operands; B is stored transposed (N x K), as the
    # kernel expects.  Generation order matters for seed reproducibility.
    A_flat = [random.gauss(0, 1) for _ in range(M * K)]
    B_flat = [random.gauss(0, 1) for _ in range(N * K)]

    # Quantize both operands with the reference quantizer.
    A_packed, A_sf, A_ts = quantize_tensor_reference(A_flat)
    B_packed, B_sf, B_ts = quantize_tensor_reference(B_flat)

    # Dequantize and compute the float32 reference product:
    # A is M x K; B^T is N x K, so transpose back to K x N for matmul.
    A_ref = torch.tensor(dequantize_reference(A_packed, A_sf, A_ts, M, K),
                         dtype=torch.float32).reshape(M, K)
    B_ref = torch.tensor(dequantize_reference(B_packed, B_sf, B_ts, N, K),
                         dtype=torch.float32).reshape(N, K).T  # K x N
    D_ref = A_ref @ B_ref  # M x N

    def _cuda_u8(vals):
        # Upload packed bytes / scale codes as uint8 device tensors.
        return torch.tensor(vals, dtype=torch.uint8, device="cuda")

    A_data = _cuda_u8(A_packed)
    B_data = _cuda_u8(B_packed)
    A_scales = _cuda_u8(A_sf)
    B_scales = _cuda_u8(B_sf)

    return A_data, B_data, A_scales, B_scales, A_ts, B_ts, D_ref
187+
188+
189+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestGemmNVFP4:
    """Test NVFP4 GEMM kernel correctness."""

    @staticmethod
    def _call_kernel(lib, A_data, B_data, A_scales, B_scales, D_out, M, N, K):
        """Invoke the cgemm_nvfp4 C entry point via ctypes and synchronize.

        All tensors must already live on the CUDA device; D_out (M x N,
        float32) is written in place by the kernel.
        """
        lib.cgemm_nvfp4(
            ctypes.c_void_p(A_data.data_ptr()),
            ctypes.c_void_p(B_data.data_ptr()),
            ctypes.c_void_p(A_scales.data_ptr()),
            ctypes.c_void_p(B_scales.data_ptr()),
            ctypes.c_void_p(D_out.data_ptr()),
            ctypes.c_int(M),
            ctypes.c_int(N),
            ctypes.c_int(K),
        )
        torch.cuda.synchronize()

    @staticmethod
    def _ones_inputs(M, N, K):
        """Build all-ones FP4 operands with block scales of exactly 1.0.

        E2M1 1.0 has code 0b0010 = 2, so each packed byte (two nibbles)
        is 0x22.  Scale 1.0 in UE4M3 is exponent 7 (bias 7 -> 2^0),
        mantissa 0: code (7 << 3) | 0 = 0x38.
        """
        A_packed = torch.full((M * K // 2,), 0x22, dtype=torch.uint8, device="cuda")
        B_packed = torch.full((N * K // 2,), 0x22, dtype=torch.uint8, device="cuda")
        A_scales = torch.full((M * (K // 16),), 0x38, dtype=torch.uint8, device="cuda")
        B_scales = torch.full((N * (K // 16),), 0x38, dtype=torch.uint8, device="cuda")
        return A_packed, B_packed, A_scales, B_scales

    def _run_gemm(self, M, N, K, seed=42):
        """Run the GEMM kernel on random quantized data; return (output, reference)."""
        lib = get_lib()
        assert hasattr(lib, "cgemm_nvfp4"), "cgemm_nvfp4 symbol not found in library"

        A_data, B_data, A_scales, B_scales, A_ts, B_ts, D_ref = prepare_gemm_inputs(M, N, K, seed)
        D_out = torch.zeros(M, N, dtype=torch.float32, device="cuda")
        self._call_kernel(lib, A_data, B_data, A_scales, B_scales, D_out, M, N, K)
        return D_out.cpu(), D_ref

    def _run_ones_gemm(self, M, N, K):
        """Run the kernel on all-ones inputs; every output should equal K."""
        lib = get_lib()
        A_packed, B_packed, A_scales, B_scales = self._ones_inputs(M, N, K)
        D_out = torch.zeros(M, N, dtype=torch.float32, device="cuda")
        self._call_kernel(lib, A_packed, B_packed, A_scales, B_scales, D_out, M, N, K)
        return D_out.cpu()

    def test_gemm_nvfp4_minimal(self):
        """Test 16x8x64 (single MMA tile)."""
        D_out, D_ref = self._run_gemm(16, 8, 64)
        print(f"Output[0:4, 0:4]:\n{D_out[0:4, 0:4]}")
        print(f"Reference[0:4, 0:4]:\n{D_ref[0:4, 0:4]}")
        # Just check it runs and produces finite values
        assert torch.isfinite(D_out).all(), "Output contains non-finite values"
        # Check rough magnitude match (within 10x)
        if D_ref.abs().max() > 0:
            ratio = D_out.abs().max() / D_ref.abs().max()
            print(f"Max magnitude ratio (out/ref): {ratio:.3f}")

    def test_gemm_nvfp4_identity_scales(self):
        """Test with all-ones data and scale=1 to verify basic MMA correctness."""
        M, N, K = 16, 8, 64
        D_cpu = self._run_ones_gemm(M, N, K)
        # Each output element = sum of K products: 1.0 * 1.0 * K = 64
        expected = float(K)
        print(f"Identity test output:\n{D_cpu}")
        assert torch.allclose(D_cpu, torch.full((M, N), expected)), (
            f"Expected all {expected}, got min={D_cpu.min():.1f} max={D_cpu.max():.1f}"
        )

    def test_gemm_nvfp4_multi_k_tiles(self):
        """Test with K > 64 to verify K-loop accumulation."""
        M, N, K = 16, 8, 128  # 2 k-tiles
        D_cpu = self._run_ones_gemm(M, N, K)
        expected = float(K)  # 1.0 * 1.0 * K
        print(f"Multi-K test output (expect {expected}):\n{D_cpu[0, :]}")
        assert torch.allclose(D_cpu, torch.full((M, N), expected)), (
            f"Expected all {expected}, got min={D_cpu.min():.1f} max={D_cpu.max():.1f}"
        )
296+
297+
298+
# Allow running this file directly (`python test_gemm_nvfp4.py`): runs the
# tests verbosely (-v) without capturing output (-s) so the debug prints show.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])

0 commit comments

Comments
 (0)