@@ -338,5 +338,179 @@ def test_gemm_tall_skinny(self, M, N, K):
338338 assert rel_err < 0.01 , f"Relative error { rel_err :.6f} too large for { M } x{ N } x{ K } "
339339
340340
class TestGemmNVFP4Output:
    """Test GEMM with NVFP4 output (layer chaining) via Python API.

    Each test quantizes FP32 inputs to NVFP4, runs the fused GEMM that also
    emits an NVFP4-quantized result, dequantizes, and compares against a
    reference matmul on the dequantized inputs. Because the output is
    quantized a second time, error tolerances are looser (rel_err < 0.5)
    than for the FP32-output GEMM tests.
    """

    def test_gemm_nvfp4_output_basic(self):
        """GEMM with NVFP4 output: quantize → GEMM → quantize output → dequantize → compare."""
        from bitsandbytes.functional import (
            dequantize_nvfp4,
            gemm_nvfp4_to_nvfp4,
            quantize_nvfp4,
        )

        torch.manual_seed(42)
        M, N, K = 32, 32, 64

        A_float = torch.randn(M, K, dtype=torch.float32, device="cuda")
        B_float = torch.randn(N, K, dtype=torch.float32, device="cuda")

        # Quantize inputs
        A_packed, A_state = quantize_nvfp4(A_float)
        B_packed, B_state = quantize_nvfp4(B_float)

        # GEMM with NVFP4 output
        out_packed, out_state = gemm_nvfp4_to_nvfp4(A_packed, A_state, B_packed, B_state)

        # Dequantize output
        D_deq = dequantize_nvfp4(out_packed, out_state, out_dtype=torch.float32)

        # Reference: dequantize inputs → matmul (B is stored row-major as (N, K),
        # so the reference contraction is A @ B.T)
        A_deq = dequantize_nvfp4(A_packed, A_state, out_dtype=torch.float32)
        B_deq = dequantize_nvfp4(B_packed, B_state, out_dtype=torch.float32)
        D_ref = A_deq @ B_deq.T

        # NVFP4 output adds a second layer of quantization error
        ref_mag = D_ref.abs().mean().item()
        mean_err = (D_deq - D_ref).abs().mean().item()
        # Guard against division by zero for a degenerate all-zero reference
        rel_err = mean_err / ref_mag if ref_mag > 0 else mean_err

        print(f"GEMM NVFP4 output (M={M}, N={N}, K={K}):")
        print(f" Reference magnitude: {ref_mag:.4f}")
        print(f" Mean abs error: {mean_err:.4f}")
        print(f" Relative error: {rel_err:.4f}")
        print(f" Output shape: {D_deq.shape}")

        assert D_deq.shape == (M, N), f"Wrong shape: {D_deq.shape}"
        # Double quantization error: once for inputs, once for output
        assert rel_err < 0.5, f"Relative error {rel_err:.4f} too large"

    def test_gemm_nvfp4_output_alpha(self):
        """GEMM with alpha scaling and NVFP4 output."""
        from bitsandbytes.functional import (
            dequantize_nvfp4,
            gemm_nvfp4,
            gemm_nvfp4_to_nvfp4,
            quantize_nvfp4,
        )

        torch.manual_seed(123)
        M, N, K = 16, 16, 64
        alpha = 2.5

        A_float = torch.randn(M, K, dtype=torch.float32, device="cuda")
        B_float = torch.randn(N, K, dtype=torch.float32, device="cuda")

        A_packed, A_state = quantize_nvfp4(A_float)
        B_packed, B_state = quantize_nvfp4(B_float)

        # GEMM without alpha (FP32 output)
        D_fp32 = gemm_nvfp4(A_packed, A_state, B_packed, B_state)

        # GEMM with alpha and NVFP4 output
        out_packed, out_state = gemm_nvfp4_to_nvfp4(
            A_packed, A_state, B_packed, B_state, alpha=alpha
        )
        D_nvfp4 = dequantize_nvfp4(out_packed, out_state, out_dtype=torch.float32)

        # Reference: alpha * FP32 output
        D_ref = D_fp32 * alpha

        # Verify alpha is reflected in the output (within NVFP4 quantization error)
        ref_mag = D_ref.abs().mean().item()
        mean_err = (D_nvfp4 - D_ref).abs().mean().item()
        rel_err = mean_err / ref_mag if ref_mag > 0 else mean_err

        print(f"Alpha test (alpha={alpha}): rel_err={rel_err:.4f}")
        assert rel_err < 0.5, f"Relative error {rel_err:.4f} too large"

    def test_gemm_nvfp4_output_non_aligned_N(self):
        """GEMM with NVFP4 output where N is not a multiple of 16."""
        from bitsandbytes.functional import (
            dequantize_nvfp4,
            gemm_nvfp4_to_nvfp4,
            quantize_nvfp4,
        )

        torch.manual_seed(77)
        M, N, K = 16, 24, 64  # N=24, not multiple of 16

        A_float = torch.randn(M, K, dtype=torch.float32, device="cuda")
        B_float = torch.randn(N, K, dtype=torch.float32, device="cuda")

        A_packed, A_state = quantize_nvfp4(A_float)
        B_packed, B_state = quantize_nvfp4(B_float)

        out_packed, out_state = gemm_nvfp4_to_nvfp4(A_packed, A_state, B_packed, B_state)
        D_deq = dequantize_nvfp4(out_packed, out_state, out_dtype=torch.float32)

        # Reference
        A_deq = dequantize_nvfp4(A_packed, A_state, out_dtype=torch.float32)
        B_deq = dequantize_nvfp4(B_packed, B_state, out_dtype=torch.float32)
        D_ref = A_deq @ B_deq.T

        assert D_deq.shape == (M, N), f"Wrong shape: {D_deq.shape}"
        ref_mag = D_ref.abs().mean().item()
        mean_err = (D_deq - D_ref).abs().mean().item()
        rel_err = mean_err / ref_mag if ref_mag > 0 else mean_err
        print(f"Non-aligned N test ({M}x{N}x{K}): rel_err={rel_err:.4f}")
        assert rel_err < 0.5, f"Relative error {rel_err:.4f} too large"
458+
459+
class TestNVFP4QuantStateSerialization:
    """Test NVFP4QuantState save/load.

    Covers an in-memory state_dict round trip (fields and dequantized values
    must match exactly) and a torch.save/torch.load round trip through a
    temporary file.
    """

    def test_state_dict_round_trip(self):
        """Serialize and deserialize NVFP4QuantState."""
        from bitsandbytes.functional import NVFP4QuantState, dequantize_nvfp4, quantize_nvfp4

        torch.manual_seed(42)
        x = torch.randn(256, dtype=torch.float32, device="cuda")
        packed, state = quantize_nvfp4(x)

        # Serialize: the dict must expose every field needed to rebuild the state
        sd = state.state_dict()
        assert "packed_data" in sd
        assert "block_scales" in sd
        assert "tensor_scale" in sd
        assert "shape" in sd
        assert "dtype" in sd

        # Deserialize
        state2 = NVFP4QuantState.from_state_dict(sd, device="cuda")

        # Verify fields match
        assert torch.equal(state.packed_data, state2.packed_data)
        assert torch.equal(state.block_scales, state2.block_scales)
        assert state.tensor_scale == state2.tensor_scale
        assert state.shape == state2.shape
        assert state.dtype == state2.dtype
        assert state.rotated == state2.rotated

        # Verify dequantization produces bitwise-identical results
        out1 = dequantize_nvfp4(packed, state, out_dtype=torch.float32)
        out2 = dequantize_nvfp4(state2.packed_data, state2, out_dtype=torch.float32)
        assert torch.equal(out1, out2), "Dequantized outputs differ after serialization"

    def test_state_dict_save_load_file(self):
        """Save to file and reload."""
        import tempfile

        from bitsandbytes.functional import NVFP4QuantState, quantize_nvfp4

        torch.manual_seed(99)
        x = torch.randn(128, dtype=torch.float16, device="cuda")
        _, state = quantize_nvfp4(x)

        # NOTE: reopening a NamedTemporaryFile by name works on POSIX;
        # these CUDA tests are not expected to run on Windows.
        with tempfile.NamedTemporaryFile(suffix=".pt") as f:
            torch.save(state.state_dict(), f.name)
            # weights_only=False: the state dict carries non-tensor metadata
            loaded = torch.load(f.name, weights_only=False)
            state2 = NVFP4QuantState.from_state_dict(loaded, device="cuda")

        assert torch.equal(state.packed_data, state2.packed_data)
        assert state.tensor_scale == state2.tensor_scale
        assert state.dtype == state2.dtype
513+
514+
# Allow running this test module directly: `python <file>` runs pytest
# verbosely with captured output shown (-s) so the print() diagnostics appear.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])