@@ -266,6 +266,77 @@ def test_random_data_larger(self):
266266 print (f" Mean relative error: { rel_err :.4f} " )
267267 assert rel_err < 0.5 , f"Relative error { rel_err :.4f} too large"
268268
def _run_gemm_test(self, M, N, K, seed=42):
    """Run one seeded NVFP4 GEMM accuracy measurement.

    Draws an (M, K) operand and an (N, K) operand from ``torch.randn`` on
    the CUDA device, quantizes both with ``cuda_quantize_nvfp4``, and
    compares the rescaled output of ``cuda_gemm_nvfp4`` against a float
    reference ``A_deq @ B_deq.T`` built from the *dequantized* operands.
    Both sides therefore start from the same quantized data.

    Args:
        M: Row count of the first operand (and of the product).
        N: Row count of the second operand (column count of the product).
        K: Shared inner dimension.
        seed: Value passed to ``torch.manual_seed`` before sampling.

    Returns:
        ``(rel_err, max_err, mean_err, ref_mag)`` as Python floats, where
        ``mean_err``/``max_err`` are the mean/max absolute difference,
        ``ref_mag`` is the mean absolute value of the reference, and
        ``rel_err`` is ``mean_err / ref_mag`` (or ``mean_err`` itself when
        ``ref_mag`` is not positive).
    """
    torch.manual_seed(seed)
    # Sampling order (first operand, then second) is fixed by the seed.
    lhs = torch.randn(M, K, dtype=torch.float32, device="cuda")
    rhs = torch.randn(N, K, dtype=torch.float32, device="cuda")

    # The quantizer is handed each operand flattened to 1-D.
    # NOTE(review): the third return value looks like a per-tensor scale
    # (it multiplies the kernel output below) -- confirm against the binding.
    lhs_packed, lhs_scales, lhs_ts = cuda_quantize_nvfp4(lhs.reshape(-1))
    rhs_packed, rhs_scales, rhs_ts = cuda_quantize_nvfp4(rhs.reshape(-1))

    # Reference product from the round-tripped (quantize -> dequantize) data.
    lhs_deq = cuda_dequantize_nvfp4(lhs_packed, lhs_scales, lhs_ts, M * K)
    rhs_deq = cuda_dequantize_nvfp4(rhs_packed, rhs_scales, rhs_ts, N * K)
    expected = torch.matmul(lhs_deq.reshape(M, K), rhs_deq.reshape(N, K).T)

    # The kernel result is multiplied by both third-position quantizer
    # outputs before comparison.
    raw = cuda_gemm_nvfp4(lhs_packed, rhs_packed, lhs_scales, rhs_scales, M, N, K)
    actual = raw * lhs_ts * rhs_ts

    diff = torch.abs(actual - expected)
    ref_mag = torch.abs(expected).mean().item()
    mean_err = diff.mean().item()
    max_err = diff.max().item()

    # Fall back to the absolute mean error when the reference is all zeros.
    rel_err = mean_err / ref_mag if ref_mag > 0 else mean_err
    return rel_err, max_err, mean_err, ref_mag
296+
def test_gemm_medium(self):
    """128x128x128 GEMM: mean relative error must stay below 0.01."""
    tolerance = 0.01
    shape = (128, 128, 128)
    # Only the relative and max errors are reported; the rest are unused.
    rel_err, max_err, _mean_err, _ref_mag = self._run_gemm_test(*shape)
    print(f"Medium (128x128x128): rel_err={ rel_err :.6f} , max_err={ max_err :.4f} ")
    assert rel_err < tolerance, f"Relative error { rel_err :.6f} too large"
302+
def test_gemm_large(self):
    """256x256x256 GEMM: mean relative error must stay below 0.01."""
    tolerance = 0.01
    shape = (256, 256, 256)
    # Only the relative and max errors are reported; the rest are unused.
    rel_err, max_err, _mean_err, _ref_mag = self._run_gemm_test(*shape)
    print(f"Large (256x256x256): rel_err={ rel_err :.6f} , max_err={ max_err :.4f} ")
    assert rel_err < tolerance, f"Relative error { rel_err :.6f} too large"
308+
# Tile counts in the comments assume a 16x8 (M x N) output tile and a
# 64-wide K step -- TODO confirm against the kernel.
# NOTE(review): every shape listed is an exact multiple of 16 / 8 / 64, so
# no partially filled tile is exercised by this test.
@pytest.mark.parametrize(
    "M,N,K",
    [
        pytest.param(16, 8, 128, id="16x8x128"),  # 1 x 1 output tile, 2 K steps
        pytest.param(48, 24, 64, id="48x24x64"),  # 3 x 3 output tiles, 1 K step
        pytest.param(32, 8, 192, id="32x8x192"),  # 2 x 1 output tiles, 3 K steps
        pytest.param(80, 40, 64, id="80x40x64"),  # 5 x 5 output tiles, 1 K step
    ],
)
def test_gemm_various_shapes(self, M, N, K):
    """Check GEMM accuracy across several M/N/K combinations.

    Each case must keep the mean relative error below 0.01.
    """
    tolerance = 0.01
    rel_err, _max_err, _mean_err, ref_mag = self._run_gemm_test(M, N, K)
    print(f"Shape ({ M } x{ N } x{ K } ): rel_err={ rel_err :.6f} , ref_mag={ ref_mag :.4f} ")
    assert rel_err < tolerance, f"Relative error { rel_err :.6f} too large for { M } x{ N } x{ K } "
324+
@pytest.mark.parametrize(
    "M,N,K",
    [
        pytest.param(1, 128, 64, id="1x128x64"),  # single row (batch of 1)
        pytest.param(8, 128, 64, id="8x128x64"),  # small batch
        pytest.param(32, 128, 128, id="32x128x128"),  # medium batch
    ],
)
def test_gemm_tall_skinny(self, M, N, K):
    """Check GEMM accuracy for small-M shapes (M = 1, 8, 32) with N = 128.

    Each case must keep the mean relative error below 0.01.
    """
    tolerance = 0.01
    rel_err, _max_err, _mean_err, ref_mag = self._run_gemm_test(M, N, K)
    print(f"Tall/skinny ({ M } x{ N } x{ K } ): rel_err={ rel_err :.6f} , ref_mag={ ref_mag :.4f} ")
    assert rel_err < tolerance, f"Relative error { rel_err :.6f} too large for { M } x{ N } x{ K } "
339+
269340
if __name__ == "__main__":
    # Run this file's tests verbosely with output capture disabled (-s), and
    # hand pytest's exit code back to the shell so a failing run does not
    # exit with status 0 when the file is executed directly.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))
0 commit comments