@@ -929,3 +929,281 @@ def test_prod_mblock1_matches_reference(self):
929929 f"M_BLOCKS=1 regression: prod does not match reference.\n "
930930 f"Max diff: { (C_prod_cpu - C_direct ).abs ().max ().item ():.6f} "
931931 )
932+
933+
934+ # ===========================================================================
935+ # VQ Codebook Tests: dequant+cuBLAS path and vq_linear dispatch
936+ # ===========================================================================
937+
938+
def _vq_dequant_matmul_ref(A, W, p, codebook=None):
    """Reference: quantize W with VQ, dequantize, then matmul in float32.

    Returns (C_ref, packed_flat, absmax_flat, codebook).
    """
    from bitsandbytes.functional import create_vq_codebook, dequantize_vq, quantize_vq

    if codebook is None:
        codebook = create_vq_codebook(p, device="cuda")

    # Quantize the weight on-device, then reconstruct it at its full length.
    packed, absmax, codebook = quantize_vq(W.half().cuda(), p=p, codebook=codebook)
    W_deq = dequantize_vq(
        packed, absmax, codebook, p=p, n=W.numel(), dtype=torch.float16
    ).reshape(W.shape)

    # A float32 matmul against the dequantized weight serves as ground truth.
    C_ref = (A.float().cuda() @ W_deq.float().T).cpu()
    return C_ref, packed, absmax, codebook
956+
957+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestVQDequantCublas:
    """Test VQ dequantize + cuBLAS matmul path (M > 4 fallback in vq_linear)."""

    @staticmethod
    def _quantize_and_repack(W_gpu, p, K_dim, N):
        """Quantize W_gpu with a fresh VQ codebook and repack into the tiled layout.

        Returns (codebook, packed_flat, absmax_flat, packed_tiled, absmax_tiled).
        """
        from bitsandbytes.functional import create_vq_codebook, quantize_vq, repack_vq

        codebook = create_vq_codebook(p, device="cuda")
        packed_flat, absmax_flat, _ = quantize_vq(W_gpu, p=p, codebook=codebook)
        packed_tiled, absmax_tiled = repack_vq(packed_flat, absmax_flat, K_dim, N, p=p)
        return codebook, packed_flat, absmax_flat, packed_tiled, absmax_tiled

    @pytest.mark.parametrize("p", [2, 4])
    @pytest.mark.parametrize("M", [8, 16, 32, 64])
    def test_dequant_cublas_correctness(self, p, M):
        """Tiled dequant + matmul matches flat dequant + matmul reference."""
        from bitsandbytes.functional import dequantize_vq

        K_dim, N = 512, 256
        torch.manual_seed(42)

        W = torch.randn(N, K_dim)
        codebook, packed_flat, absmax_flat, packed_tiled, absmax_tiled = (
            self._quantize_and_repack(W.half().cuda(), p, K_dim, N)
        )

        # Tiled dequant
        W_tiled = torch.ops.bitsandbytes.dequantize_vq_tiled(
            packed_tiled, codebook, absmax_tiled, p, K_dim, N, torch.float16,
        )
        W_tiled = W_tiled.reshape(N, K_dim)

        A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
        C_tiled = (A.float() @ W_tiled.float().T).half()

        # Flat dequant reference
        W_flat = dequantize_vq(packed_flat, absmax_flat, codebook, p=p, n=N * K_dim)
        W_flat = W_flat.reshape(N, K_dim)
        C_ref = (A.float() @ W_flat.float().T).half()

        # Both paths dequantize to the same values, so only matmul rounding may
        # differ; a small relative tolerance covers fp16 accumulation order.
        diff = (C_tiled.float() - C_ref.float()).abs()
        scale = C_ref.float().abs().clamp(min=1.0)
        rel_err = (diff / scale).max().item()
        assert rel_err < 0.01, (
            f"p={p}, M={M}: tiled dequant+matmul vs flat dequant+matmul mismatch. "
            f"Max rel err: {rel_err:.6f}"
        )

    @pytest.mark.parametrize("p", [2, 4])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_dequant_cublas_dtype(self, p, dtype):
        """Tiled dequant works with both fp16 and bf16."""
        K_dim, N, M = 512, 256, 16
        torch.manual_seed(42)

        # Quantize from the parametrized dtype; the tiled dequant kernel is
        # still asked for fp16 output (the codebook itself is fp16).
        W = torch.randn(N, K_dim)
        codebook, _, _, packed_tiled, absmax_tiled = (
            self._quantize_and_repack(W.to(dtype).cuda(), p, K_dim, N)
        )

        W_tiled = torch.ops.bitsandbytes.dequantize_vq_tiled(
            packed_tiled, codebook, absmax_tiled, p, K_dim, N, torch.float16,
        )
        # Output should be fp16 (codebook is fp16)
        assert W_tiled.dtype == torch.float16, f"Expected fp16, got {W_tiled.dtype}"
1022+
1023+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestVQLinearDispatch:
    """Test the vq_linear dispatch function across the full M range."""

    @staticmethod
    def _quantize_and_repack(W_gpu, p, K_dim, N):
        """Quantize W_gpu with a fresh VQ codebook and repack into the tiled layout.

        Returns (codebook, packed_flat, absmax_flat, packed_tiled, absmax_tiled).
        """
        from bitsandbytes.functional import create_vq_codebook, quantize_vq, repack_vq

        codebook = create_vq_codebook(p, device="cuda")
        packed_flat, absmax_flat, _ = quantize_vq(W_gpu, p=p, codebook=codebook)
        packed_tiled, absmax_tiled = repack_vq(packed_flat, absmax_flat, K_dim, N, p=p)
        return codebook, packed_flat, absmax_flat, packed_tiled, absmax_tiled

    @staticmethod
    def _reference(A, packed_flat, absmax_flat, codebook, p, K_dim, N):
        """Flat dequant + float32 matmul reference, cast back to A's dtype."""
        from bitsandbytes.functional import dequantize_vq

        W_deq = dequantize_vq(packed_flat, absmax_flat, codebook, p=p, n=N * K_dim)
        W_deq = W_deq.reshape(N, K_dim)
        return (A.float() @ W_deq.float().T).to(A.dtype)

    @staticmethod
    def _max_rel_err(C, C_ref):
        """Max element-wise relative error, with |ref| clamped to >= 1."""
        diff = (C.float() - C_ref.float()).abs()
        scale = C_ref.float().abs().clamp(min=1.0)
        return (diff / scale).max().item()

    @pytest.mark.parametrize("p", [2, 4])
    @pytest.mark.parametrize("M", [1, 2, 3, 4])
    def test_vq_linear_scalar_gemv_path(self, p, M):
        """vq_linear dispatches to scalar GEMV for M<=4."""
        from bitsandbytes.functional import vq_linear

        K_dim, N = 2048, 512
        torch.manual_seed(42)

        W = torch.randn(N, K_dim)
        codebook, packed_flat, absmax_flat, packed_tiled, absmax_tiled = (
            self._quantize_and_repack(W.half().cuda(), p, K_dim, N)
        )

        A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
        C = vq_linear(A, packed_tiled, absmax_tiled, codebook, p, K_dim, N)
        C_ref = self._reference(A, packed_flat, absmax_flat, codebook, p, K_dim, N)

        rel_err = self._max_rel_err(C, C_ref)
        assert rel_err < 0.10, (
            f"p={p}, M={M}: vq_linear scalar GEMV path mismatch. Max rel err: {rel_err:.6f}"
        )

    @pytest.mark.parametrize("p", [2, 4])
    @pytest.mark.parametrize("M", [8, 16, 32, 64])
    def test_vq_linear_cublas_path(self, p, M):
        """vq_linear dispatches to dequant+cuBLAS for M>4."""
        from bitsandbytes.functional import vq_linear

        K_dim, N = 512, 256
        torch.manual_seed(42)

        W = torch.randn(N, K_dim)
        codebook, packed_flat, absmax_flat, packed_tiled, absmax_tiled = (
            self._quantize_and_repack(W.half().cuda(), p, K_dim, N)
        )

        A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
        C = vq_linear(A, packed_tiled, absmax_tiled, codebook, p, K_dim, N)
        C_ref = self._reference(A, packed_flat, absmax_flat, codebook, p, K_dim, N)

        rel_err = self._max_rel_err(C, C_ref)
        assert rel_err < 0.05, (
            f"p={p}, M={M}: vq_linear cuBLAS path mismatch. Max rel err: {rel_err:.6f}"
        )

    @pytest.mark.parametrize("p", [2, 4])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_vq_linear_output_dtype(self, p, dtype):
        """vq_linear output has same dtype as input A."""
        from bitsandbytes.functional import vq_linear

        K_dim, N, M = 512, 256, 2
        torch.manual_seed(42)

        W = torch.randn(N, K_dim)
        codebook, _, _, packed_tiled, absmax_tiled = (
            self._quantize_and_repack(W.to(dtype).cuda(), p, K_dim, N)
        )

        A = torch.randn(M, K_dim, dtype=dtype, device="cuda")
        C = vq_linear(A, packed_tiled, absmax_tiled, codebook, p, K_dim, N)
        assert C.dtype == dtype, f"Expected {dtype}, got {C.dtype}"

    @pytest.mark.parametrize("p", [2, 4])
    @pytest.mark.parametrize(
        "K_dim,N",
        [
            (2048, 5120),
            (5120, 2048),
            (2048, 4096),
        ],
    )
    def test_vq_linear_real_shapes(self, p, K_dim, N):
        """vq_linear works with shapes from real model projections."""
        from bitsandbytes.functional import vq_linear

        M = 1
        torch.manual_seed(42)

        W = torch.randn(N, K_dim)
        codebook, packed_flat, absmax_flat, packed_tiled, absmax_tiled = (
            self._quantize_and_repack(W.half().cuda(), p, K_dim, N)
        )

        A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
        C = vq_linear(A, packed_tiled, absmax_tiled, codebook, p, K_dim, N)
        C_ref = self._reference(A, packed_flat, absmax_flat, codebook, p, K_dim, N)

        rel_err = self._max_rel_err(C, C_ref)
        assert rel_err < 0.10, (
            f"p={p}, ({K_dim},{N}): vq_linear mismatch. Max rel err: {rel_err:.6f}"
        )

    @pytest.mark.parametrize("p", [2, 4])
    def test_vq_linear_workspace(self, p):
        """vq_linear works with pre-allocated workspace."""
        from bitsandbytes.functional import vq_linear, vq_linear_workspace

        K_dim, N, M = 512, 256, 32
        torch.manual_seed(42)

        W = torch.randn(N, K_dim)
        codebook, _, _, packed_tiled, absmax_tiled = (
            self._quantize_and_repack(W.half().cuda(), p, K_dim, N)
        )

        workspace = vq_linear_workspace(M, K_dim, N, p, torch.float16, torch.device("cuda"))

        A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
        C = vq_linear(A, packed_tiled, absmax_tiled, codebook, p, K_dim, N, workspace=workspace)

        # Reference (without workspace) — both paths must be bit-identical.
        C_ref = vq_linear(A, packed_tiled, absmax_tiled, codebook, p, K_dim, N)

        assert torch.equal(C, C_ref), (
            f"p={p}: workspace path differs from non-workspace path. "
            f"Max diff: {(C.float() - C_ref.float()).abs().max().item()}"
        )

    @pytest.mark.parametrize("p", [2, 4])
    def test_vq_linear_preallocated_output(self, p):
        """vq_linear works with pre-allocated output tensor."""
        from bitsandbytes.functional import vq_linear

        K_dim, N, M = 512, 256, 2
        torch.manual_seed(42)

        W = torch.randn(N, K_dim)
        codebook, _, _, packed_tiled, absmax_tiled = (
            self._quantize_and_repack(W.half().cuda(), p, K_dim, N)
        )

        A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
        out = torch.empty(M, N, dtype=torch.float16, device="cuda")
        C = vq_linear(A, packed_tiled, absmax_tiled, codebook, p, K_dim, N, out=out)

        # Verify it used the pre-allocated buffer
        assert C.data_ptr() == out.data_ptr(), "vq_linear didn't use pre-allocated output"

        # Verify correctness
        C_ref = vq_linear(A, packed_tiled, absmax_tiled, codebook, p, K_dim, N)
        assert torch.equal(C, C_ref), "Pre-allocated output differs from fresh output"

    @pytest.mark.parametrize("p", [2, 4])
    @pytest.mark.parametrize("M", [5, 8, 16, 32])
    @pytest.mark.skip(reason="Task 5 (VQ MMA kernel) not yet implemented")
    def test_vq_mma_kernel(self, p, M):
        """VQ MMA kernel correctness (placeholder for Task 5)."""
        pass
0 commit comments