diff --git a/tests/test_modules.py b/tests/test_modules.py index 1081b4b9a..b8f7c4f9f 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -143,7 +143,7 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0 and device.type not in ("cpu", "xpu"): + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None assert mlp.fc2.state.idx is not None