We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
solve_int
1 parent 1c86b79 commit 76bb476Copy full SHA for 76bb476
src/torchjd/sparse/linalg.py
@@ -0,0 +1,24 @@
1
+import torch
2
+from torch import Tensor
3
+
4
5
+def solve_int(A: Tensor, B: Tensor, tol=1e-9) -> Tensor | None:
6
+ """
7
+ Solve A X = B where A, B and X have integer dtype.
8
+ Return X if such a matrix exists and otherwise None.
9
10
11
+ A_ = A.to(torch.float64)
12
+ B_ = B.to(torch.float64)
13
14
+ try:
15
+ X = torch.linalg.solve(A_, B_)
16
+ except RuntimeError:
17
+ return None
18
19
+ X_rounded = X.round()
20
+ if not torch.all(torch.isclose(X, X_rounded, atol=tol)):
21
22
23
+ # TODO: Verify that the round operation cannot fail
24
+ return X_rounded.to(torch.int64)
0 commit comments