Skip to content

Commit bc06c81

Browse files
authored
refactor(aggregation): Make _pref_vector_to_weighting check input itself (#265)
* Move pref_vector validation into _pref_vector_to_weighting
* Add test_pref_vector_to_weighting_check
1 parent 99d7651 commit bc06c81

File tree

6 files changed

+37
-37
lines changed

6 files changed

+37
-37
lines changed

src/torchjd/aggregation/_pref_vector_utils.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,6 @@
55
from .constant import _ConstantWeighting
66

77

8-
def _check_pref_vector(pref_vector: Tensor | None) -> None:
9-
"""Checks the correctness of the parameter pref_vector."""
10-
11-
if pref_vector is not None:
12-
if pref_vector.ndim != 1:
13-
raise ValueError(
14-
"Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = "
15-
f"{pref_vector.ndim}`."
16-
)
17-
18-
198
def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -> _Weighting:
209
"""
2110
Returns the weighting associated to a given preference vector, with a fallback to a default
@@ -25,6 +14,11 @@ def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -
2514
if pref_vector is None:
2615
return default
2716
else:
17+
if pref_vector.ndim != 1:
18+
raise ValueError(
19+
"Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = "
20+
f"{pref_vector.ndim}`."
21+
)
2822
return _ConstantWeighting(pref_vector)
2923

3024

src/torchjd/aggregation/aligned_mtl.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,7 @@
2929
from torch import Tensor
3030
from torch.linalg import LinAlgError
3131

32-
from ._pref_vector_utils import (
33-
_check_pref_vector,
34-
_pref_vector_to_str_suffix,
35-
_pref_vector_to_weighting,
36-
)
32+
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
3733
from .bases import _WeightedAggregator, _Weighting
3834
from .mean import _MeanWeighting
3935

@@ -66,7 +62,6 @@ class AlignedMTL(_WeightedAggregator):
6662
"""
6763

6864
def __init__(self, pref_vector: Tensor | None = None):
69-
_check_pref_vector(pref_vector)
7065
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
7166
self._pref_vector = pref_vector
7267

src/torchjd/aggregation/config.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,9 @@
2828
import torch
2929
from torch import Tensor
3030

31-
from torchjd.aggregation._pref_vector_utils import (
32-
_check_pref_vector,
33-
_pref_vector_to_str_suffix,
34-
_pref_vector_to_weighting,
35-
)
36-
from torchjd.aggregation.bases import Aggregator
37-
from torchjd.aggregation.sum import _SumWeighting
31+
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
32+
from .bases import Aggregator
33+
from .sum import _SumWeighting
3834

3935

4036
class ConFIG(Aggregator):
@@ -66,7 +62,6 @@ class ConFIG(Aggregator):
6662

6763
def __init__(self, pref_vector: Tensor | None = None):
6864
super().__init__()
69-
_check_pref_vector(pref_vector)
7065
self.weighting = _pref_vector_to_weighting(pref_vector, default=_SumWeighting())
7166
self._pref_vector = pref_vector
7267

src/torchjd/aggregation/dualproj.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,7 @@
44

55
from ._dual_cone_utils import _project_weights
66
from ._gramian_utils import _compute_regularized_normalized_gramian
7-
from ._pref_vector_utils import (
8-
_check_pref_vector,
9-
_pref_vector_to_str_suffix,
10-
_pref_vector_to_weighting,
11-
)
7+
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
128
from .bases import _WeightedAggregator, _Weighting
139
from .mean import _MeanWeighting
1410

@@ -51,7 +47,6 @@ def __init__(
5147
reg_eps: float = 0.0001,
5248
solver: Literal["quadprog"] = "quadprog",
5349
):
54-
_check_pref_vector(pref_vector)
5550
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
5651
self._pref_vector = pref_vector
5752

src/torchjd/aggregation/upgrad.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,7 @@
55

66
from ._dual_cone_utils import _project_weights
77
from ._gramian_utils import _compute_regularized_normalized_gramian
8-
from ._pref_vector_utils import (
9-
_check_pref_vector,
10-
_pref_vector_to_str_suffix,
11-
_pref_vector_to_weighting,
12-
)
8+
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
139
from .bases import _WeightedAggregator, _Weighting
1410
from .mean import _MeanWeighting
1511

@@ -51,7 +47,6 @@ def __init__(
5147
reg_eps: float = 0.0001,
5248
solver: Literal["quadprog"] = "quadprog",
5349
):
54-
_check_pref_vector(pref_vector)
5550
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
5651
self._pref_vector = pref_vector
5752

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from contextlib import nullcontext as does_not_raise
2+
3+
import torch
4+
from pytest import mark, raises
5+
from torch import Tensor
6+
from unit._utils import ExceptionContext
7+
8+
from torchjd.aggregation._pref_vector_utils import _pref_vector_to_weighting
9+
from torchjd.aggregation.mean import _MeanWeighting
10+
11+
12+
@mark.parametrize(
13+
["pref_vector", "expectation"],
14+
[
15+
(None, does_not_raise()),
16+
(torch.ones([]), raises(ValueError)),
17+
(torch.ones([0]), does_not_raise()),
18+
(torch.ones([1]), does_not_raise()),
19+
(torch.ones([5]), does_not_raise()),
20+
(torch.ones([1, 1]), raises(ValueError)),
21+
(torch.ones([1, 1, 1]), raises(ValueError)),
22+
],
23+
)
24+
def test_pref_vector_to_weighting_check(pref_vector: Tensor | None, expectation: ExceptionContext):
25+
with expectation:
26+
_ = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())

0 commit comments

Comments (0)