@@ -191,7 +191,11 @@ class TestGemmNVFP4:
191191 """Test NVFP4 GEMM kernel correctness."""
192192
193193 def _run_gemm (self , M , N , K , seed = 42 ):
194- """Run the GEMM kernel and return (output, reference)."""
194+ """Run the GEMM kernel and return (output, reference).
195+
196+ The kernel computes D_raw = (A_fp4 * SFA) @ (B_fp4 * SFB)^T.
197+ The tensor scales are applied post-hoc: D = D_raw * A_ts * B_ts.
198+ """
195199 lib = get_lib ()
196200 assert hasattr (lib , "cgemm_nvfp4" ), "cgemm_nvfp4 symbol not found in library"
197201
@@ -211,19 +215,31 @@ def _run_gemm(self, M, N, K, seed=42):
211215 )
212216 torch .cuda .synchronize ()
213217
214- return D_out .cpu (), D_ref
218+ # Apply tensor scales (not handled by kernel)
219+ D_out_scaled = D_out .cpu () * A_ts * B_ts
220+
221+ return D_out_scaled , D_ref
215222
216- def test_gemm_nvfp4_minimal (self ):
217- """Test 16x8x64 (single MMA tile)."""
223+ def test_gemm_nvfp4_random_single_tile (self ):
224+ """Test 16x8x64 (single MMA tile) with random data."""
218225 D_out , D_ref = self ._run_gemm (16 , 8 , 64 )
219226 print (f"Output[0:4, 0:4]:\n { D_out [0 :4 , 0 :4 ]} " )
220227 print (f"Reference[0:4, 0:4]:\n { D_ref [0 :4 , 0 :4 ]} " )
221- # Just check it runs and produces finite values
222228 assert torch .isfinite (D_out ).all (), "Output contains non-finite values"
223- # Check rough magnitude match (within 10x)
224- if D_ref .abs ().max () > 0 :
225- ratio = D_out .abs ().max () / D_ref .abs ().max ()
226- print (f"Max magnitude ratio (out/ref): { ratio :.3f} " )
229+ # Compare: both are products of FP4-quantized values, so they should
230+ # be close. The main error source is quantization of the input.
231+ abs_err = (D_out - D_ref ).abs ()
232+ max_abs_err = abs_err .max ().item ()
233+ mean_abs_err = abs_err .mean ().item ()
234+ ref_magnitude = D_ref .abs ().mean ().item ()
235+ print (f"Max abs error: { max_abs_err :.4f} " )
236+ print (f"Mean abs error: { mean_abs_err :.4f} " )
237+ print (f"Reference mean magnitude: { ref_magnitude :.4f} " )
238+ # Loose sanity bound: FP4 quantization has ~25% relative error, while the assert below only rejects rel_err >= 2.0
239+ if ref_magnitude > 0 :
240+ rel_err = mean_abs_err / ref_magnitude
241+ print (f"Relative error: { rel_err :.4f} " )
242+ assert rel_err < 2.0 , f"Relative error { rel_err :.4f} too large"
227243
228244 def test_gemm_nvfp4_identity_scales (self ):
229245 """Test with all-ones data and scale=1 to verify basic MMA correctness."""
0 commit comments