File tree Expand file tree Collapse file tree 3 files changed +8
-13
lines changed
Expand file tree Collapse file tree 3 files changed +8
-13
lines changed Original file line number Diff line number Diff line change @@ -17,6 +17,11 @@ changes that do not affect the user.
1717- Refactored the underlying optimization problem that ` UPGrad ` and ` DualProj ` have to solve to
1818 project onto the dual cone. This may minimally affect the output of these aggregators.
1919
20+ ### Fixed
21+ - Removed arbitrary exception handling in ` IMTLG ` and ` AlignedMTL ` when the computation fails. In
22+ practice, this fix should only affect some matrices with extremely large values, which should
23+ not usually happen.
24+
2025## [ 0.5.0] - 2025-02-01
2126
2227### Added
Original file line number Diff line number Diff line change 2727
2828import torch
2929from torch import Tensor
30- from torch .linalg import LinAlgError
3130
3231from ._pref_vector_utils import _pref_vector_to_str_suffix , _pref_vector_to_weighting
3332from .bases import _WeightedAggregator , _Weighting
@@ -102,12 +101,7 @@ def forward(self, matrix: Tensor) -> Tensor:
102101 def _compute_balance_transformation (G : Tensor ) -> Tensor :
103102 M = G .T @ G
104103
105- try :
106- lambda_ , V = torch .linalg .eigh (M , UPLO = "U" ) # More modern equivalent to torch.symeig
107- except LinAlgError : # This can happen when the matrix has extremely large values
108- identity = torch .eye (len (M ), dtype = M .dtype , device = M .device )
109- return identity
110-
104+ lambda_ , V = torch .linalg .eigh (M , UPLO = "U" ) # More modern equivalent to torch.symeig
111105 tol = torch .max (lambda_ ) * len (M ) * torch .finfo ().eps
112106 rank = sum (lambda_ > tol )
113107
Original file line number Diff line number Diff line change @@ -39,13 +39,9 @@ class _IMTLGWeighting(_Weighting):
3939
4040 def forward (self , matrix : Tensor ) -> Tensor :
4141 d = torch .linalg .norm (matrix , dim = 1 )
42-
43- try :
44- v = torch .linalg .pinv (matrix @ matrix .T ) @ d
45- except RuntimeError : # This can happen when the matrix has extremely large values
46- v = torch .ones (matrix .shape [0 ], device = matrix .device , dtype = matrix .dtype )
47-
42+ v = torch .linalg .pinv (matrix @ matrix .T ) @ d
4843 v_sum = v .sum ()
44+
4945 if v_sum .abs () < 1e-12 :
5046 weights = torch .zeros_like (v )
5147 else :
You can’t perform that action at this time.
0 commit comments