|
27 | 27 |
|
28 | 28 | import torch |
29 | 29 | from torch import Tensor |
30 | | -from torch.linalg import LinAlgError |
31 | 30 |
|
32 | | -from ._pref_vector_utils import ( |
33 | | - _check_pref_vector, |
34 | | - _pref_vector_to_str_suffix, |
35 | | - _pref_vector_to_weighting, |
36 | | -) |
| 31 | +from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting |
37 | 32 | from .bases import _WeightedAggregator, _Weighting |
38 | 33 | from .mean import _MeanWeighting |
39 | 34 |
|
@@ -66,7 +61,6 @@ class AlignedMTL(_WeightedAggregator): |
66 | 61 | """ |
67 | 62 |
|
68 | 63 | def __init__(self, pref_vector: Tensor | None = None): |
69 | | - _check_pref_vector(pref_vector) |
70 | 64 | weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting()) |
71 | 65 | self._pref_vector = pref_vector |
72 | 66 |
|
@@ -107,12 +101,7 @@ def forward(self, matrix: Tensor) -> Tensor: |
107 | 101 | def _compute_balance_transformation(G: Tensor) -> Tensor: |
108 | 102 | M = G.T @ G |
109 | 103 |
|
110 | | - try: |
111 | | - lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig |
112 | | - except LinAlgError: # This can happen when the matrix has extremely large values |
113 | | - identity = torch.eye(len(M), dtype=M.dtype, device=M.device) |
114 | | - return identity |
115 | | - |
| 104 | + lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig |
116 | 105 | tol = torch.max(lambda_) * len(M) * torch.finfo().eps |
117 | 106 | rank = sum(lambda_ > tol) |
118 | 107 |
|
|
0 commit comments