Skip to content

Commit 87c0eb1

Browse files
committed
Add try/except around linalg.qr for MPS
1 parent 59cd35e commit 87c0eb1

1 file changed

Lines changed: 10 additions & 1 deletion

File tree

tests/unit/aggregation/_matrix_samplers.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,5 +167,14 @@ def _sample_semi_orthonormal_complement(Q: Tensor, rng: torch.Generator | None =
167167
# project A onto the orthogonal complement of Q
168168
A_proj = A - Q @ (Q.T @ A)
169169

170-
Q_prime, _ = torch.linalg.qr(A_proj)
170+
try:
171+
Q_prime, _ = torch.linalg.qr(A_proj)
172+
except NotImplementedError:
173+
# This will happen on MPS until they add support for aten::linalg_qr.out
174+
# See status in https://github.com/pytorch/pytorch/issues/141287
175+
# In this case, perform the qr on CPU and move back to the original device
176+
original_device = A_proj.device
177+
Q_prime, _ = torch.linalg.qr(A_proj.to(device="cpu"))
178+
Q_prime = Q_prime.to(device=original_device)
179+
171180
return Q_prime

0 commit comments

Comments
 (0)