@@ -118,11 +118,11 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
118118 print ("qB.t() = " ,qB .t ())
119119 C3 = torch .matmul (A , B .t ())
120120 #pdb.set_trace()
121- C2 = F .gemv_4bit (A , qB .t (), state = state )
121+ C2 = F .gemv_4bit (A , qB .t (), state = state ). bfloat16 ()
122122 #pdb.set_trace()
123123 print ("C3.sum() = " , C3 .sum ())
124124 print ("C2.sum() = " , C2 .sum ())
125- diff = C2 . bfloat16 () - C3
125+ diff = C2 - C3
126126 print ("diff/C2 = " , diff .sum ()/ C3 .sum ())
127127 print (C3 )
128128 print (C2 )
@@ -139,7 +139,7 @@ def test_gemm_4bit(self, device, dim, dtype, storage_type, quant_storage, double
139139 #print("B[0] = ",B[0])
140140 C3 = torch .matmul (A , B .t ())
141141 #pdb.set_trace()
142- C2 = F .gemv_4bit (A , qB .t (), state = state )
142+ C2 = F .gemv_4bit (A , qB .t (), state = state ). bfloat16 ()
143143 pdb .set_trace ()
144144 #print("C3.sum() = ", C3.sum())
145145 #print("C2.sum() = ", C2.sum())
@@ -294,6 +294,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
294294 A = torch .randn (1 , dim , dtype = dtype , device = device )
295295 B = torch .randn (dim * 3 , dim , dtype = dtype , device = device ) / math .sqrt (dim )
296296
297+ #pdb.set_trace()
297298 qB , state = F .quantize_4bit (
298299 B ,
299300 quant_type = storage_type ,
@@ -303,10 +304,10 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
303304 #pdb.set_trace()
304305 C3 = torch .matmul (A , B .t ())
305306 #pdb.set_trace()
306- C2 = F .gemv_4bit (A , qB .t (), state = state ). bfloat16 ()
307+ C2 = F .gemv_4bit (A , qB .t (), state = state )
307308 #print("C2[0] = ", C2[0])
308309 A .requires_grad = True
309- C1 = bnb . matmul_4bit (A , qB .t (), state ) #.bfloat16( )
310+ C1 = F . gemv_4bit (A , qB .t (), state = state ) #bnb.matmul_4bit(A, qB.t(), state )
310311 #pdb.set_trace()
311312
312313 err1 = (C1 - C2 ).abs ().float ()
0 commit comments