diff --git a/CHANGELOG.md b/CHANGELOG.md
index c21064e48..faa29afc5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,11 @@ changes that do not affect the user.
 - Refactored the underlying optimization problem that `UPGrad` and `DualProj` have to solve to
   project onto the dual cone. This may minimally affect the output of these aggregators.
 
+### Fixed
+- Removed arbitrary exception handling in `IMTLG` and `AlignedMTL` when the computation fails. In
+  practice, this fix should only affect some matrices with extremely large values, which should
+  not usually happen.
+
 ## [0.5.0] - 2025-02-01
 
 ### Added
diff --git a/src/torchjd/aggregation/aligned_mtl.py b/src/torchjd/aggregation/aligned_mtl.py
index c1423b65e..b1f5f8628 100644
--- a/src/torchjd/aggregation/aligned_mtl.py
+++ b/src/torchjd/aggregation/aligned_mtl.py
@@ -27,7 +27,6 @@
 import torch
 from torch import Tensor
-from torch.linalg import LinAlgError
 
 from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
 from .bases import _WeightedAggregator, _Weighting
 
@@ -102,12 +101,7 @@ def forward(self, matrix: Tensor) -> Tensor:
 
     def _compute_balance_transformation(G: Tensor) -> Tensor:
         M = G.T @ G
-        try:
-            lambda_, V = torch.linalg.eigh(M, UPLO="U")  # More modern equivalent to torch.symeig
-        except LinAlgError:  # This can happen when the matrix has extremely large values
-            identity = torch.eye(len(M), dtype=M.dtype, device=M.device)
-            return identity
-
+        lambda_, V = torch.linalg.eigh(M, UPLO="U")  # More modern equivalent to torch.symeig
         tol = torch.max(lambda_) * len(M) * torch.finfo().eps
         rank = sum(lambda_ > tol)
 
diff --git a/src/torchjd/aggregation/imtl_g.py b/src/torchjd/aggregation/imtl_g.py
index 291edeb6a..f4fb82322 100644
--- a/src/torchjd/aggregation/imtl_g.py
+++ b/src/torchjd/aggregation/imtl_g.py
@@ -39,13 +39,9 @@ class _IMTLGWeighting(_Weighting):
 
     def forward(self, matrix: Tensor) -> Tensor:
         d = torch.linalg.norm(matrix, dim=1)
-
-        try:
-            v = torch.linalg.pinv(matrix @ matrix.T) @ d
-        except RuntimeError:  # This can happen when the matrix has extremely large values
-            v = torch.ones(matrix.shape[0], device=matrix.device, dtype=matrix.dtype)
-
+        v = torch.linalg.pinv(matrix @ matrix.T) @ d
         v_sum = v.sum()
+
         if v_sum.abs() < 1e-12:
             weights = torch.zeros_like(v)
         else: