diff --git a/tests/test_modules.py b/tests/test_modules.py index b8f7c4f9f..beaea6e12 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -155,9 +155,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to(device) @@ -166,9 +165,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 @@ -188,9 +186,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 @@ -210,9 +207,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.weight.dtype == torch.int8