Skip to content

Commit 47f4af3

Browse files
TimDettmersclaude
andcommitted
test: Add comprehensive GEMM correctness tests for various shapes
Adds medium (128x128x128), large (256x256x256), non-aligned (48x24x64, 32x8x192, 80x40x64), and tall/skinny (1x128x64, 8x128x64, 32x128x128) test cases covering Task 10 requirements. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent bfe0916 commit 47f4af3

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed

tests/test_gemm_nvfp4.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,77 @@ def test_random_data_larger(self):
266266
print(f" Mean relative error: {rel_err:.4f}")
267267
assert rel_err < 0.5, f"Relative error {rel_err:.4f} too large"
268268

269+
def _run_gemm_test(self, M, N, K, seed=42):
270+
"""Helper: quantize random data, run GEMM, compare against reference."""
271+
torch.manual_seed(seed)
272+
A_float = torch.randn(M, K, dtype=torch.float32, device="cuda")
273+
B_float = torch.randn(N, K, dtype=torch.float32, device="cuda")
274+
275+
A_packed, A_scales, A_ts = cuda_quantize_nvfp4(A_float.reshape(-1))
276+
B_packed, B_scales, B_ts = cuda_quantize_nvfp4(B_float.reshape(-1))
277+
278+
A_deq = cuda_dequantize_nvfp4(A_packed, A_scales, A_ts, M * K).reshape(M, K)
279+
B_deq = cuda_dequantize_nvfp4(B_packed, B_scales, B_ts, N * K).reshape(N, K)
280+
281+
D_ref = A_deq @ B_deq.T
282+
D_kernel = cuda_gemm_nvfp4(A_packed, B_packed, A_scales, B_scales, M, N, K)
283+
D_out = D_kernel * A_ts * B_ts
284+
285+
abs_err = (D_out - D_ref).abs()
286+
ref_mag = D_ref.abs().mean().item()
287+
mean_err = abs_err.mean().item()
288+
max_err = abs_err.max().item()
289+
290+
if ref_mag > 0:
291+
rel_err = mean_err / ref_mag
292+
else:
293+
rel_err = mean_err
294+
295+
return rel_err, max_err, mean_err, ref_mag
296+
297+
def test_gemm_medium(self):
298+
"""Medium matrices (128x128x128) — multiple tiles in all dimensions."""
299+
rel_err, max_err, mean_err, ref_mag = self._run_gemm_test(128, 128, 128)
300+
print(f"Medium (128x128x128): rel_err={rel_err:.6f}, max_err={max_err:.4f}")
301+
assert rel_err < 0.01, f"Relative error {rel_err:.6f} too large"
302+
303+
def test_gemm_large(self):
304+
"""Larger matrices (256x256x256)."""
305+
rel_err, max_err, mean_err, ref_mag = self._run_gemm_test(256, 256, 256)
306+
print(f"Large (256x256x256): rel_err={rel_err:.6f}, max_err={max_err:.4f}")
307+
assert rel_err < 0.01, f"Relative error {rel_err:.6f} too large"
308+
309+
@pytest.mark.parametrize(
310+
"M,N,K",
311+
[
312+
(16, 8, 128), # Single M/N tile, multi K
313+
(48, 24, 64), # M,N not multiples of tile (16,8)
314+
(32, 8, 192), # K not multiple of 64 (3 K-tiles)
315+
(80, 40, 64), # Larger non-aligned M,N
316+
],
317+
ids=["16x8x128", "48x24x64", "32x8x192", "80x40x64"],
318+
)
319+
def test_gemm_various_shapes(self, M, N, K):
320+
"""Test various matrix shapes including non-tile-aligned."""
321+
rel_err, max_err, mean_err, ref_mag = self._run_gemm_test(M, N, K)
322+
print(f"Shape ({M}x{N}x{K}): rel_err={rel_err:.6f}, ref_mag={ref_mag:.4f}")
323+
assert rel_err < 0.01, f"Relative error {rel_err:.6f} too large for {M}x{N}x{K}"
324+
325+
@pytest.mark.parametrize(
326+
"M,N,K",
327+
[
328+
(1, 128, 64), # Single row (batch=1 inference)
329+
(8, 128, 64), # Small batch
330+
(32, 128, 128), # Medium batch
331+
],
332+
ids=["1x128x64", "8x128x64", "32x128x128"],
333+
)
334+
def test_gemm_tall_skinny(self, M, N, K):
335+
"""Test tall/skinny shapes typical of LLM inference."""
336+
rel_err, max_err, mean_err, ref_mag = self._run_gemm_test(M, N, K)
337+
print(f"Tall/skinny ({M}x{N}x{K}): rel_err={rel_err:.6f}, ref_mag={ref_mag:.4f}")
338+
assert rel_err < 0.01, f"Relative error {rel_err:.6f} too large for {M}x{N}x{K}"
339+
269340

270341
if __name__ == "__main__":
271342
pytest.main([__file__, "-v", "-s"])

0 commit comments

Comments
 (0)