Skip to content

Commit b88236a

Browse files
authored
Merge pull request #8 from xiaolil1/jiqing
fix tests
2 parents 1e0f661 + 99698d2 commit b88236a

1 file changed

Lines changed: 4 additions & 8 deletions

File tree

tests/test_modules.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
155155
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
156156
o1 = mlp(b1)
157157
assert o1.dtype == torch.float16
158-
if threshold > 0:
158+
if threshold > 0 and device not in ("cpu", "xpu"):
159159
assert mlp.fc1.state.idx is not None
160-
if threshold > 0:
161160
assert mlp.fc2.state.idx is not None
162161

163162
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to(device)
@@ -166,9 +165,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
166165
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
167166
o1 = mlp(b1)
168167
assert o1.dtype == torch.float16
169-
if threshold > 0:
168+
if threshold > 0 and device not in ("cpu", "xpu"):
170169
assert mlp.fc1.state.idx is not None
171-
if threshold > 0:
172170
assert mlp.fc2.state.idx is not None
173171
assert mlp.fc1.weight.dtype == torch.int8
174172
assert mlp.fc2.weight.dtype == torch.int8
@@ -188,9 +186,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
188186
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
189187
o1 = mlp(b1)
190188
assert o1.dtype == torch.float16
191-
if threshold > 0:
189+
if threshold > 0 and device not in ("cpu", "xpu"):
192190
assert mlp.fc1.state.idx is not None
193-
if threshold > 0:
194191
assert mlp.fc2.state.idx is not None
195192
assert mlp.fc1.weight.dtype == torch.int8
196193
assert mlp.fc2.weight.dtype == torch.int8
@@ -210,9 +207,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
210207
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
211208
o1 = mlp(b1)
212209
assert o1.dtype == torch.float16
213-
if threshold > 0:
210+
if threshold > 0 and device not in ("cpu", "xpu"):
214211
assert mlp.fc1.state.idx is not None
215-
if threshold > 0:
216212
assert mlp.fc2.state.idx is not None
217213

218214
assert mlp.fc1.weight.dtype == torch.int8

0 commit comments

Comments
 (0)