@@ -155,9 +155,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
155155 b1 = torch .randn (16 , 8 , 32 , device = device , dtype = torch .float16 )
156156 o1 = mlp (b1 )
157157 assert o1 .dtype == torch .float16
158- if threshold > 0 :
158+ if threshold > 0 and device not in ( "cpu" , "xpu" ) :
159159 assert mlp .fc1 .state .idx is not None
160- if threshold > 0 :
161160 assert mlp .fc2 .state .idx is not None
162161
163162 mlp = MLP8bit (32 , 64 , threshold = threshold , has_fp16_weights = False ).half ().to (device )
@@ -166,9 +165,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
166165 b1 = torch .randn (16 , 8 , 32 , device = device , dtype = torch .float16 )
167166 o1 = mlp (b1 )
168167 assert o1 .dtype == torch .float16
169- if threshold > 0 :
168+ if threshold > 0 and device not in ( "cpu" , "xpu" ) :
170169 assert mlp .fc1 .state .idx is not None
171- if threshold > 0 :
172170 assert mlp .fc2 .state .idx is not None
173171 assert mlp .fc1 .weight .dtype == torch .int8
174172 assert mlp .fc2 .weight .dtype == torch .int8
@@ -188,9 +186,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
188186 b1 = torch .randn (16 , 8 , 32 , device = device , dtype = torch .float16 )
189187 o1 = mlp (b1 )
190188 assert o1 .dtype == torch .float16
191- if threshold > 0 :
189+ if threshold > 0 and device not in ( "cpu" , "xpu" ) :
192190 assert mlp .fc1 .state .idx is not None
193- if threshold > 0 :
194191 assert mlp .fc2 .state .idx is not None
195192 assert mlp .fc1 .weight .dtype == torch .int8
196193 assert mlp .fc2 .weight .dtype == torch .int8
@@ -210,9 +207,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
210207 b1 = torch .randn (16 , 8 , 32 , device = device , dtype = torch .float16 )
211208 o1 = mlp (b1 )
212209 assert o1 .dtype == torch .float16
213- if threshold > 0 :
210+ if threshold > 0 and device not in ( "cpu" , "xpu" ) :
214211 assert mlp .fc1 .state .idx is not None
215- if threshold > 0 :
216212 assert mlp .fc2 .state .idx is not None
217213
218214 assert mlp .fc1 .weight .dtype == torch .int8
0 commit comments