Skip to content

Commit 08805f6

Browse files
TimDettmersclaude
andcommitted
fix: Account for tensor scales in GEMM test, improve random test
The GEMM kernel produces raw block-scaled output without tensor scales. Apply A_ts * B_ts post-hoc. Add quantitative error checks. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 92dc4ee commit 08805f6

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

tests/test_gemm_nvfp4.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,11 @@ class TestGemmNVFP4:
191191
"""Test NVFP4 GEMM kernel correctness."""
192192

193193
def _run_gemm(self, M, N, K, seed=42):
194-
"""Run the GEMM kernel and return (output, reference)."""
194+
"""Run the GEMM kernel and return (output, reference).
195+
196+
The kernel computes D_raw = (A_fp4 * SFA) @ (B_fp4 * SFB)^T.
197+
The tensor scales are applied post-hoc: D = D_raw * A_ts * B_ts.
198+
"""
195199
lib = get_lib()
196200
assert hasattr(lib, "cgemm_nvfp4"), "cgemm_nvfp4 symbol not found in library"
197201

@@ -211,19 +215,31 @@ def _run_gemm(self, M, N, K, seed=42):
211215
)
212216
torch.cuda.synchronize()
213217

214-
return D_out.cpu(), D_ref
218+
# Apply tensor scales (not handled by kernel)
219+
D_out_scaled = D_out.cpu() * A_ts * B_ts
220+
221+
return D_out_scaled, D_ref
215222

216-
def test_gemm_nvfp4_minimal(self):
217-
"""Test 16x8x64 (single MMA tile)."""
223+
def test_gemm_nvfp4_random_single_tile(self):
224+
"""Test 16x8x64 (single MMA tile) with random data."""
218225
D_out, D_ref = self._run_gemm(16, 8, 64)
219226
print(f"Output[0:4, 0:4]:\n{D_out[0:4, 0:4]}")
220227
print(f"Reference[0:4, 0:4]:\n{D_ref[0:4, 0:4]}")
221-
# Just check it runs and produces finite values
222228
assert torch.isfinite(D_out).all(), "Output contains non-finite values"
223-
# Check rough magnitude match (within 10x)
224-
if D_ref.abs().max() > 0:
225-
ratio = D_out.abs().max() / D_ref.abs().max()
226-
print(f"Max magnitude ratio (out/ref): {ratio:.3f}")
229+
# Compare: both are products of FP4-quantized values, so they should
230+
# be close. The main error source is quantization of the input.
231+
abs_err = (D_out - D_ref).abs()
232+
max_abs_err = abs_err.max().item()
233+
mean_abs_err = abs_err.mean().item()
234+
ref_magnitude = D_ref.abs().mean().item()
235+
print(f"Max abs error: {max_abs_err:.4f}")
236+
print(f"Mean abs error: {mean_abs_err:.4f}")
237+
print(f"Reference mean magnitude: {ref_magnitude:.4f}")
238+
# Relative error should be reasonable (FP4 quantization has ~25% relative error)
239+
if ref_magnitude > 0:
240+
rel_err = mean_abs_err / ref_magnitude
241+
print(f"Relative error: {rel_err:.4f}")
242+
assert rel_err < 2.0, f"Relative error {rel_err:.4f} too large"
227243

228244
def test_gemm_nvfp4_identity_scales(self):
229245
"""Test with all-ones data and scale=1 to verify basic MMA correctness."""

0 commit comments

Comments
 (0)