Skip to content

Commit 24572c8

Browse files
committed
fix(tests): skip torch Gemm on CPU half-precision
ATen `addmm`/`baddbmm` does not support `float16`/`bfloat16` on CPU.
1 parent c13b378 commit 24572c8

1 file changed

Lines changed: 7 additions & 0 deletions

File tree

tests/test_gemm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ def test_gemm(
 64  64          if implementation_index == 1 and dtype in (torch.float16, torch.bfloat16):
 65  65              pytest.skip("cuBLASLt half-precision exceeds current tolerances")
 66  66  
     67+         if (
     68+             implementation_index == 2
     69+             and device == "cpu"
     70+             and dtype in (torch.float16, torch.bfloat16)
     71+         ):
     72+             pytest.skip("ATen CPU `addmm`/`baddbmm` does not support half-precision")
     73+
 67  74          a = randn_strided(a_shape, a_strides, dtype=dtype, device=device)
 68  75          b = randn_strided(b_shape, b_strides, dtype=dtype, device=device)
 69  76  

0 commit comments

Comments (0)