|
| 1 | +"""Test NVFP4 GEMM kernel on SM_120 (Blackwell consumer GPUs). |
| 2 | +
|
| 3 | +Tests the block-scaled mma.sync GEMM kernel via ctypes. |
| 4 | +""" |
| 5 | + |
| 6 | +import ctypes |
| 7 | +import os |
| 8 | + |
| 9 | +import pytest |
| 10 | +import torch |
| 11 | + |
| 12 | + |
def get_lib():
    """Locate and load the bitsandbytes CUDA shared library.

    Looks for a versioned .so next to this test's parent directory,
    preferring the newest CUDA toolkit build.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    lib_dir = os.path.join(os.path.dirname(here), "bitsandbytes")
    candidates = [
        os.path.join(lib_dir, f"libbitsandbytes_{suffix}.so")
        for suffix in ("cuda131", "cuda130")
    ]
    for candidate in candidates:
        if os.path.exists(candidate):
            return ctypes.cdll.LoadLibrary(candidate)
    raise RuntimeError(f"Could not find bitsandbytes CUDA library in {lib_dir}")
| 21 | + |
| 22 | + |
# E2M1 representable magnitudes (unsigned). The list index is the 3-bit
# FP4 magnitude code; the sign travels separately in bit 3 of each nibble.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def float_to_e2m1(x):
    """Round a float to the nearest E2M1-representable value, keeping sign."""
    sign = -1 if x < 0 else 1
    ax = abs(x)
    # Each cutoff is the midpoint between two consecutive E2M1 magnitudes;
    # the first cutoff that exceeds |x| selects the value below it.
    cutoffs = (
        (0.25, 0.0), (0.75, 0.5), (1.25, 1.0), (1.75, 1.5),
        (2.5, 2.0), (3.5, 3.0), (5.0, 4.0),
    )
    for cutoff, value in cutoffs:
        if ax < cutoff:
            return sign * value
    return sign * 6.0  # saturate at the E2M1 maximum
| 36 | + |
| 37 | + |
def float_to_e4m3(x):
    """Quantize a positive float to UE4M3 (unsigned E4M3, bias=7).

    Args:
        x: Value to quantize; non-positive inputs map to zero.

    Returns:
        (code, actual): the 7-bit encoding and the exactly-representable
        value it decodes to, matching the scale decode in
        ``dequantize_reference`` (subnormal when the exponent field is 0).
    """
    import math

    if x <= 0:
        return 0, 0.0
    # Clamp to the largest representable magnitude: (1 + 6/8) * 2^8 = 448.
    x = min(x, 448.0)

    e = math.floor(math.log2(x))
    if e < -6:
        # Subnormal range: value = m/8 * 2^-6, exponent field stays 0.
        # (The previous version clamped e to -6 and never emitted these
        # codes, rounding tiny scales up to 2^-6.)
        m_int = round(x * (2.0 ** 6) * 8)
        if m_int == 0:
            return 0, 0.0
        if m_int >= 8:
            # Rounded up into the smallest normal value, 2^-6.
            return 0x08, 2.0 ** -6
        return m_int, m_int / 8.0 * (2.0 ** -6)

    e = min(e, 8)  # max exponent with bias=7 is 8
    m_int = round((x / (2.0 ** e) - 1.0) * 8)
    if m_int == 8:
        # Mantissa rounded past 1.875: carry into the exponent so values
        # just below a power of two round to it instead of down to 1.875.
        # (x <= 448 guarantees e <= 7 here, so e + 1 stays in range.)
        m_int = 0
        e += 1
    code = ((e + 7) << 3) | m_int
    actual = (1.0 + m_int / 8.0) * (2.0 ** e)
    return code, actual
| 65 | + |
| 66 | + |
def quantize_tensor_reference(x_flat):
    """Reference quantization: float tensor -> packed FP4 + block scales + tensor scale.

    Returns (packed_bytes, block_scale_bytes, tensor_scale) in the format
    expected by the GEMM kernel.
    """
    total = len(x_flat)
    assert total % 16 == 0

    # Global scale is the tensor-wide absmax (1.0 for an all-zero tensor).
    tensor_scale = max(abs(v) for v in x_flat)
    if tensor_scale == 0:
        tensor_scale = 1.0

    packed_bytes = []
    scale_codes = []

    for start in range(0, total, 16):
        # One 16-element block, pre-divided by the tensor scale.
        block = [v / tensor_scale for v in x_flat[start:start + 16]]
        absmax = max(abs(v) for v in block)
        if absmax == 0:
            absmax = 1e-10

        # The per-block scale maps the block absmax onto 6.0, E2M1's max.
        code, decoded = float_to_e4m3(absmax / 6.0)
        scale_codes.append(code)
        if decoded == 0:
            decoded = 1e-10

        nibbles = []
        for v in block:
            q = float_to_e2m1(v / decoded)
            mag = abs(q)
            # Bits 0-2 index the magnitude table; bit 3 is the sign.
            idx = E2M1_VALUES.index(mag) if mag in E2M1_VALUES else 0
            nibbles.append((idx | 0x8) if q < 0 else idx)

        # Two FP4 codes per byte: even index -> low nibble, odd -> high.
        packed_bytes.extend(
            (nibbles[j] & 0xF) | ((nibbles[j + 1] & 0xF) << 4)
            for j in range(0, 16, 2)
        )

    return packed_bytes, scale_codes, tensor_scale
| 120 | + |
| 121 | + |
def dequantize_reference(packed, block_scales, tensor_scale, M, K):
    """Reference dequantization for verification.

    Args:
        packed: FP4 codes, two per byte (low nibble = even element index).
        block_scales: one UE4M3 code per 16-element block.
        tensor_scale: global float scale applied on top of the block scales.
        M, K: logical matrix shape; M*K elements are decoded row-major.

    Returns:
        List of M*K dequantized floats.
    """
    # E2M1 magnitudes indexed by the 3-bit magnitude code.
    e2m1 = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)

    # Decode each UE4M3 block scale once, instead of once per element
    # (the previous version re-decoded the same scale 16 times per block).
    decoded_scales = []
    for sf_code in block_scales:
        sf_e = (sf_code >> 3) & 0xF
        sf_m = sf_code & 0x7
        if sf_e == 0:
            # Subnormal: no implicit leading 1, fixed 2^-6 exponent.
            decoded_scales.append(sf_m / 8.0 * (2.0 ** -6))
        else:
            decoded_scales.append((1.0 + sf_m / 8.0) * (2.0 ** (sf_e - 7)))

    result = []
    for i in range(M * K):
        byte_val = packed[i // 2]
        # Even element index -> low nibble, odd -> high nibble.
        code = (byte_val >> 4) & 0xF if i % 2 else byte_val & 0xF
        sign = -1.0 if (code & 0x8) else 1.0
        mag = e2m1[code & 0x7]
        result.append(sign * mag * decoded_scales[i // 16] * tensor_scale)
    return result
| 150 | + |
| 151 | + |
def prepare_gemm_inputs(M, N, K, seed=42):
    """Create random FP4-quantized inputs for GEMM testing.

    Returns CUDA tensors ready for the GEMM kernel, plus reference
    dequantized matrices for verification.
    """
    import random
    random.seed(seed)

    # A is M x K; B is stored transposed as N x K.
    a_vals = [random.gauss(0, 1) for _ in range(M * K)]
    b_vals = [random.gauss(0, 1) for _ in range(N * K)]

    a_packed, a_sf, a_ts = quantize_tensor_reference(a_vals)
    b_packed, b_sf, b_ts = quantize_tensor_reference(b_vals)

    # Round-trip through the reference dequantizer so D_ref reflects
    # exactly what the kernel should compute from the quantized data.
    a_deq = dequantize_reference(a_packed, a_sf, a_ts, M, K)
    b_deq = dequantize_reference(b_packed, b_sf, b_ts, N, K)

    A_ref = torch.tensor(a_deq, dtype=torch.float32).reshape(M, K)
    B_ref = torch.tensor(b_deq, dtype=torch.float32).reshape(N, K).T  # K x N
    D_ref = A_ref @ B_ref  # M x N reference output

    def to_cuda(seq):
        return torch.tensor(seq, dtype=torch.uint8, device="cuda")

    return (
        to_cuda(a_packed),
        to_cuda(b_packed),
        to_cuda(a_sf),
        to_cuda(b_sf),
        a_ts,
        b_ts,
        D_ref,
    )
| 187 | + |
| 188 | + |
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestGemmNVFP4:
    """Test NVFP4 GEMM kernel correctness."""

    # FP4 E2M1 code for 1.0 is 0b0010; packed twice per byte: (2) | (2 << 4).
    _ONES_BYTE = 0x22
    # UE4M3 code for scale 1.0: exponent field 7 (bias 7 -> 2^0), mantissa 0.
    _UNIT_SCALE = 0x38

    def _call_kernel(self, lib, A_data, B_data, A_scales, B_scales, M, N, K):
        """Invoke cgemm_nvfp4 and return the (M, N) fp32 result on the CPU.

        Shared by all tests — the ctypes call was previously duplicated
        in three places.
        """
        D_out = torch.zeros(M, N, dtype=torch.float32, device="cuda")
        lib.cgemm_nvfp4(
            ctypes.c_void_p(A_data.data_ptr()),
            ctypes.c_void_p(B_data.data_ptr()),
            ctypes.c_void_p(A_scales.data_ptr()),
            ctypes.c_void_p(B_scales.data_ptr()),
            ctypes.c_void_p(D_out.data_ptr()),
            ctypes.c_int(M),
            ctypes.c_int(N),
            ctypes.c_int(K),
        )
        torch.cuda.synchronize()
        return D_out.cpu()

    def _run_gemm(self, M, N, K, seed=42):
        """Run the GEMM kernel on random quantized inputs; return (output, reference)."""
        lib = get_lib()
        assert hasattr(lib, "cgemm_nvfp4"), "cgemm_nvfp4 symbol not found in library"
        A_data, B_data, A_scales, B_scales, _A_ts, _B_ts, D_ref = prepare_gemm_inputs(M, N, K, seed)
        D_out = self._call_kernel(lib, A_data, B_data, A_scales, B_scales, M, N, K)
        return D_out, D_ref

    def _run_all_ones(self, M, N, K):
        """Run the kernel with every element = 1.0 and every block scale = 1.0.

        Each output element should then equal K exactly (sum of K products
        of 1.0 * 1.0).
        """
        lib = get_lib()
        A_packed = torch.full((M * K // 2,), self._ONES_BYTE, dtype=torch.uint8, device="cuda")
        B_packed = torch.full((N * K // 2,), self._ONES_BYTE, dtype=torch.uint8, device="cuda")
        A_scales = torch.full((M * (K // 16),), self._UNIT_SCALE, dtype=torch.uint8, device="cuda")
        B_scales = torch.full((N * (K // 16),), self._UNIT_SCALE, dtype=torch.uint8, device="cuda")
        return self._call_kernel(lib, A_packed, B_packed, A_scales, B_scales, M, N, K)

    def test_gemm_nvfp4_minimal(self):
        """Test 16x8x64 (single MMA tile)."""
        D_out, D_ref = self._run_gemm(16, 8, 64)
        print(f"Output[0:4, 0:4]:\n{D_out[0:4, 0:4]}")
        print(f"Reference[0:4, 0:4]:\n{D_ref[0:4, 0:4]}")
        # Just check it runs and produces finite values.
        assert torch.isfinite(D_out).all(), "Output contains non-finite values"
        # Report rough magnitude match (within 10x) for manual inspection.
        if D_ref.abs().max() > 0:
            ratio = D_out.abs().max() / D_ref.abs().max()
            print(f"Max magnitude ratio (out/ref): {ratio:.3f}")

    def test_gemm_nvfp4_identity_scales(self):
        """Test with all-ones data and scale=1 to verify basic MMA correctness."""
        M, N, K = 16, 8, 64
        D_cpu = self._run_all_ones(M, N, K)
        expected = 64.0  # 1.0 * 1.0 summed over K = 64
        print(f"Identity test output:\n{D_cpu}")
        assert torch.allclose(D_cpu, torch.full((M, N), expected)), (
            f"Expected all {expected}, got min={D_cpu.min():.1f} max={D_cpu.max():.1f}"
        )

    def test_gemm_nvfp4_multi_k_tiles(self):
        """Test with K > 64 to verify K-loop accumulation."""
        M, N, K = 16, 8, 128  # 2 k-tiles
        D_cpu = self._run_all_ones(M, N, K)
        expected = float(K)  # 1.0 * 1.0 * K
        print(f"Multi-K test output (expect {expected}):\n{D_cpu[0, :]}")
        assert torch.allclose(D_cpu, torch.full((M, N), expected)), (
            f"Expected all {expected}, got min={D_cpu.min():.1f} max={D_cpu.max():.1f}"
        )
| 296 | + |
| 297 | + |
if __name__ == "__main__":
    # Allow running this file directly: -v for verbose test names,
    # -s so the kernel-output prints are not captured.
    pytest.main([__file__, "-v", "-s"])
0 commit comments