Skip to content

Commit 76bb476

Browse files
committed
Add solve_int
1 parent 1c86b79 commit 76bb476

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

src/torchjd/sparse/linalg.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import torch
2+
from torch import Tensor
3+
4+
5+
def solve_int(A: Tensor, B: Tensor, tol=1e-9) -> Tensor | None:
6+
"""
7+
Solve A X = B where A, B and X have integer dtype.
8+
Return X if such a matrix exists and otherwise None.
9+
"""
10+
11+
A_ = A.to(torch.float64)
12+
B_ = B.to(torch.float64)
13+
14+
try:
15+
X = torch.linalg.solve(A_, B_)
16+
except RuntimeError:
17+
return None
18+
19+
X_rounded = X.round()
20+
if not torch.all(torch.isclose(X, X_rounded, atol=tol)):
21+
return None
22+
23+
# TODO: Verify that the round operation cannot fail
24+
return X_rounded.to(torch.int64)

0 commit comments

Comments
 (0)