Skip to content

Commit 6da02d4

Browse files
authored
fix(aggregation): Remove arbitrary Linalg exception handling (#269)
* Remove catching of RuntimeError (should have been more precisely LinalgError) in IMTLG when pinv fails * Remove catching of LinalgError in AlignedMTL when eigh fails * Add changelog entry
1 parent bc06c81 commit 6da02d4

File tree

3 files changed

+8
-13
lines changed

3 files changed

+8
-13
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ changes that do not affect the user.
1717
- Refactored the underlying optimization problem that `UPGrad` and `DualProj` have to solve to
1818
project onto the dual cone. This may minimally affect the output of these aggregators.
1919

20+
### Fixed
21+
- Removed arbitrary exception handling in `IMTLG` and `AlignedMTL` when the computation fails. In
22+
practice, this fix should only affect some matrices with extremely large values, which should
23+
not usually happen.
24+
2025
## [0.5.0] - 2025-02-01
2126

2227
### Added

src/torchjd/aggregation/aligned_mtl.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727

2828
import torch
2929
from torch import Tensor
30-
from torch.linalg import LinAlgError
3130

3231
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
3332
from .bases import _WeightedAggregator, _Weighting
@@ -102,12 +101,7 @@ def forward(self, matrix: Tensor) -> Tensor:
102101
def _compute_balance_transformation(G: Tensor) -> Tensor:
103102
M = G.T @ G
104103

105-
try:
106-
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
107-
except LinAlgError: # This can happen when the matrix has extremely large values
108-
identity = torch.eye(len(M), dtype=M.dtype, device=M.device)
109-
return identity
110-
104+
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
111105
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
112106
rank = sum(lambda_ > tol)
113107

src/torchjd/aggregation/imtl_g.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,9 @@ class _IMTLGWeighting(_Weighting):
3939

4040
def forward(self, matrix: Tensor) -> Tensor:
4141
d = torch.linalg.norm(matrix, dim=1)
42-
43-
try:
44-
v = torch.linalg.pinv(matrix @ matrix.T) @ d
45-
except RuntimeError: # This can happen when the matrix has extremely large values
46-
v = torch.ones(matrix.shape[0], device=matrix.device, dtype=matrix.dtype)
47-
42+
v = torch.linalg.pinv(matrix @ matrix.T) @ d
4843
v_sum = v.sum()
44+
4945
if v_sum.abs() < 1e-12:
5046
weights = torch.zeros_like(v)
5147
else:

0 commit comments

Comments
 (0)