@@ -43,8 +43,7 @@ def test_hadamard_transform(dim):
4343 assert torch .allclose (xxt_h , xxt , atol = 0.05 )
4444 x_h_fp32 = normalized_hadamard_transform (x , rotate_fp32 = True )
4545 xxt_h_fp32 = x_h_fp32 @ x_h_fp32 .T
46- # test the numerical error is smaller when using float32
47- assert torch .allclose (xxt_h_fp32 , xxt , atol = 1e-6 )
46+ assert torch .allclose (xxt_h_fp32 , xxt , atol = 0.05 )
4847
4948
5049@pytest .mark .parametrize (
@@ -61,10 +60,8 @@ def test_kv_rotate(rotate_fp32):
6160 output_ref = model (dummy_input )
6261 if rotate_fp32 :
6362 rotate = {"enable" : True , "rotate_fp32" : True }
64- atol = 1e-6
6563 else :
6664 rotate = True
67- atol = 0.05
6865 with set_quantizer_by_cfg_context (
6966 model ,
7067 {
@@ -74,7 +71,7 @@ def test_kv_rotate(rotate_fp32):
7471 },
7572 ):
7673 output_test = model (dummy_input )
77- assert torch .allclose (output_ref , output_test , atol = atol )
74+ assert torch .allclose (output_ref , output_test , atol = 0.05 )
7875
7976 # Test the rotation is actually applied by turning on only one of the query, key quantizers
8077 with set_quantizer_by_cfg_context (
@@ -86,6 +83,6 @@ def test_kv_rotate(rotate_fp32):
8683 },
8784 ):
8885 output_test1 = model (dummy_input )
89- assert not torch .allclose (output_ref , output_test1 , atol = atol )
86+ assert not torch .allclose (output_ref , output_test1 , atol = 0.05 )
9087
9188 mtq .unregister (SDPAAttention )
0 commit comments