Skip to content

Commit fc83099

Browse files
Merge branch 'main' into stationarity_property
2 parents 100ef3f + dd2325d commit fc83099

File tree

10 files changed

+63
-58
lines changed

10 files changed

+63
-58
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ changes that do not affect the user.
1717
- Refactored the underlying optimization problem that `UPGrad` and `DualProj` have to solve to
1818
project onto the dual cone. This may minimally affect the output of these aggregators.
1919

20+
### Fixed
21+
- Removed arbitrary exception handling in `IMTLG` and `AlignedMTL` when the computation fails. In
22+
practice, this fix should only affect some matrices with extremely large values, which should
23+
not usually happen.
24+
2025
## [0.5.0] - 2025-02-01
2126

2227
### Added

src/torchjd/aggregation/_gramian_utils.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import torch
22
from torch import Tensor
3-
from torch.linalg import LinAlgError
43

54

65
def _compute_gramian(matrix: Tensor) -> Tensor:
@@ -33,13 +32,7 @@ def _compute_normalized_gramian(matrix: Tensor, eps: float) -> Tensor:
3332
:math:`n` through the SVD algorithm which is efficient, therefore this is rather fast.
3433
"""
3534

36-
try:
37-
left_unitary_matrix, singular_values, _ = torch.linalg.svd(matrix, full_matrices=False)
38-
except LinAlgError as error: # Not sure if this can happen
39-
raise ValueError(
40-
f"Unexpected failure of the svd computation on matrix {matrix}. Please open an "
41-
"issue on https://github.com/TorchJD/torchjd/issues and paste this error message in it."
42-
) from error
35+
left_unitary_matrix, singular_values, _ = torch.linalg.svd(matrix, full_matrices=False)
4336
max_singular_value = torch.max(singular_values)
4437
if max_singular_value < eps:
4538
scaled_singular_values = torch.zeros_like(singular_values)

src/torchjd/aggregation/_pref_vector_utils.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,6 @@
55
from .constant import _ConstantWeighting
66

77

8-
def _check_pref_vector(pref_vector: Tensor | None) -> None:
9-
"""Checks the correctness of the parameter pref_vector."""
10-
11-
if pref_vector is not None:
12-
if pref_vector.ndim != 1:
13-
raise ValueError(
14-
"Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = "
15-
f"{pref_vector.ndim}`."
16-
)
17-
18-
198
def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -> _Weighting:
209
"""
2110
Returns the weighting associated to a given preference vector, with a fallback to a default
@@ -25,6 +14,11 @@ def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -
2514
if pref_vector is None:
2615
return default
2716
else:
17+
if pref_vector.ndim != 1:
18+
raise ValueError(
19+
"Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = "
20+
f"{pref_vector.ndim}`."
21+
)
2822
return _ConstantWeighting(pref_vector)
2923

3024

src/torchjd/aggregation/aligned_mtl.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,8 @@
2727

2828
import torch
2929
from torch import Tensor
30-
from torch.linalg import LinAlgError
3130

32-
from ._pref_vector_utils import (
33-
_check_pref_vector,
34-
_pref_vector_to_str_suffix,
35-
_pref_vector_to_weighting,
36-
)
31+
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
3732
from .bases import _WeightedAggregator, _Weighting
3833
from .mean import _MeanWeighting
3934

@@ -66,7 +61,6 @@ class AlignedMTL(_WeightedAggregator):
6661
"""
6762

6863
def __init__(self, pref_vector: Tensor | None = None):
69-
_check_pref_vector(pref_vector)
7064
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
7165
self._pref_vector = pref_vector
7266

@@ -107,12 +101,7 @@ def forward(self, matrix: Tensor) -> Tensor:
107101
def _compute_balance_transformation(G: Tensor) -> Tensor:
108102
M = G.T @ G
109103

110-
try:
111-
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
112-
except LinAlgError: # This can happen when the matrix has extremely large values
113-
identity = torch.eye(len(M), dtype=M.dtype, device=M.device)
114-
return identity
115-
104+
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
116105
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
117106
rank = sum(lambda_ > tol)
118107

src/torchjd/aggregation/config.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,9 @@
2828
import torch
2929
from torch import Tensor
3030

31-
from torchjd.aggregation._pref_vector_utils import (
32-
_check_pref_vector,
33-
_pref_vector_to_str_suffix,
34-
_pref_vector_to_weighting,
35-
)
36-
from torchjd.aggregation.bases import Aggregator
37-
from torchjd.aggregation.sum import _SumWeighting
31+
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
32+
from .bases import Aggregator
33+
from .sum import _SumWeighting
3834

3935

4036
class ConFIG(Aggregator):
@@ -66,7 +62,6 @@ class ConFIG(Aggregator):
6662

6763
def __init__(self, pref_vector: Tensor | None = None):
6864
super().__init__()
69-
_check_pref_vector(pref_vector)
7065
self.weighting = _pref_vector_to_weighting(pref_vector, default=_SumWeighting())
7166
self._pref_vector = pref_vector
7267

src/torchjd/aggregation/dualproj.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,7 @@
44

55
from ._dual_cone_utils import _project_weights
66
from ._gramian_utils import _compute_regularized_normalized_gramian
7-
from ._pref_vector_utils import (
8-
_check_pref_vector,
9-
_pref_vector_to_str_suffix,
10-
_pref_vector_to_weighting,
11-
)
7+
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
128
from .bases import _WeightedAggregator, _Weighting
139
from .mean import _MeanWeighting
1410

@@ -51,7 +47,6 @@ def __init__(
5147
reg_eps: float = 0.0001,
5248
solver: Literal["quadprog"] = "quadprog",
5349
):
54-
_check_pref_vector(pref_vector)
5550
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
5651
self._pref_vector = pref_vector
5752

src/torchjd/aggregation/imtl_g.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,9 @@ class _IMTLGWeighting(_Weighting):
3939

4040
def forward(self, matrix: Tensor) -> Tensor:
4141
d = torch.linalg.norm(matrix, dim=1)
42-
43-
try:
44-
v = torch.linalg.pinv(matrix @ matrix.T) @ d
45-
except RuntimeError: # This can happen when the matrix has extremely large values
46-
v = torch.ones(matrix.shape[0], device=matrix.device, dtype=matrix.dtype)
47-
42+
v = torch.linalg.pinv(matrix @ matrix.T) @ d
4843
v_sum = v.sum()
44+
4945
if v_sum.abs() < 1e-12:
5046
weights = torch.zeros_like(v)
5147
else:

src/torchjd/aggregation/upgrad.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,7 @@
55

66
from ._dual_cone_utils import _project_weights
77
from ._gramian_utils import _compute_regularized_normalized_gramian
8-
from ._pref_vector_utils import (
9-
_check_pref_vector,
10-
_pref_vector_to_str_suffix,
11-
_pref_vector_to_weighting,
12-
)
8+
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
139
from .bases import _WeightedAggregator, _Weighting
1410
from .mean import _MeanWeighting
1511

@@ -51,7 +47,6 @@ def __init__(
5147
reg_eps: float = 0.0001,
5248
solver: Literal["quadprog"] = "quadprog",
5349
):
54-
_check_pref_vector(pref_vector)
5550
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
5651
self._pref_vector = pref_vector
5752

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from contextlib import nullcontext as does_not_raise
2+
3+
import torch
4+
from pytest import mark, raises
5+
from torch import Tensor
6+
from unit._utils import ExceptionContext
7+
8+
from torchjd.aggregation._pref_vector_utils import _pref_vector_to_weighting
9+
from torchjd.aggregation.mean import _MeanWeighting
10+
11+
12+
@mark.parametrize(
13+
["pref_vector", "expectation"],
14+
[
15+
(None, does_not_raise()),
16+
(torch.ones([]), raises(ValueError)),
17+
(torch.ones([0]), does_not_raise()),
18+
(torch.ones([1]), does_not_raise()),
19+
(torch.ones([5]), does_not_raise()),
20+
(torch.ones([1, 1]), raises(ValueError)),
21+
(torch.ones([1, 1, 1]), raises(ValueError)),
22+
],
23+
)
24+
def test_pref_vector_to_weighting_check(pref_vector: Tensor | None, expectation: ExceptionContext):
25+
with expectation:
26+
_ = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())

tests/unit/autojac/_transform/test_tensor_dict.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,20 @@ def _assert_class_checks_properly(
110110

111111
def _make_tensor_dict(value_shapes: list[list[int]]) -> dict[Tensor, Tensor]:
112112
return {torch.zeros(key): torch.zeros(value) for key, value in zip(_key_shapes, value_shapes)}
113+
114+
115+
def test_immutability():
116+
"""Tests that it's impossible to modify an existing TensorDict."""
117+
118+
t = Gradients({})
119+
with raises(TypeError):
120+
t[torch.ones(1)] = torch.ones(1)
121+
122+
assert t == Gradients({})
123+
124+
125+
def test_empty_tensor_dict():
126+
"""Tests that it's impossible to instantiate a non-empty EmptyTensorDict."""
127+
128+
with raises(ValueError):
129+
_ = EmptyTensorDict({torch.ones(1): torch.ones(1)})

0 commit comments

Comments
 (0)