Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ changes that do not affect the user.
- Refactored the underlying optimization problem that `UPGrad` and `DualProj` have to solve to
project onto the dual cone. This may minimally affect the output of these aggregators.

### Fixed
- Removed the arbitrary exception handling that `IMTLG` and `AlignedMTL` used when their internal
  computation failed. In practice, this fix should only affect matrices with extremely large
  values, which should not normally occur.

## [0.5.0] - 2025-02-01

### Added
Expand Down
8 changes: 1 addition & 7 deletions src/torchjd/aggregation/aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

import torch
from torch import Tensor
from torch.linalg import LinAlgError

from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
from .bases import _WeightedAggregator, _Weighting
Expand Down Expand Up @@ -102,12 +101,7 @@ def forward(self, matrix: Tensor) -> Tensor:
def _compute_balance_transformation(G: Tensor) -> Tensor:
M = G.T @ G

try:
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
except LinAlgError: # This can happen when the matrix has extremely large values
identity = torch.eye(len(M), dtype=M.dtype, device=M.device)
return identity

lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
rank = sum(lambda_ > tol)

Expand Down
8 changes: 2 additions & 6 deletions src/torchjd/aggregation/imtl_g.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,9 @@ class _IMTLGWeighting(_Weighting):

def forward(self, matrix: Tensor) -> Tensor:
d = torch.linalg.norm(matrix, dim=1)

try:
v = torch.linalg.pinv(matrix @ matrix.T) @ d
except RuntimeError: # This can happen when the matrix has extremely large values
v = torch.ones(matrix.shape[0], device=matrix.device, dtype=matrix.dtype)

v = torch.linalg.pinv(matrix @ matrix.T) @ d
v_sum = v.sum()

if v_sum.abs() < 1e-12:
weights = torch.zeros_like(v)
else:
Expand Down