diff --git a/src/torchjd/aggregation/_pref_vector_utils.py b/src/torchjd/aggregation/_pref_vector_utils.py
index 22068cfef..459419a1e 100644
--- a/src/torchjd/aggregation/_pref_vector_utils.py
+++ b/src/torchjd/aggregation/_pref_vector_utils.py
@@ -5,17 +5,6 @@
 from .constant import _ConstantWeighting
 
 
-def _check_pref_vector(pref_vector: Tensor | None) -> None:
-    """Checks the correctness of the parameter pref_vector."""
-
-    if pref_vector is not None:
-        if pref_vector.ndim != 1:
-            raise ValueError(
-                "Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = "
-                f"{pref_vector.ndim}`."
-            )
-
-
 def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -> _Weighting:
     """
     Returns the weighting associated to a given preference vector, with a fallback to a default
@@ -25,6 +14,11 @@ def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -
     if pref_vector is None:
         return default
     else:
+        if pref_vector.ndim != 1:
+            raise ValueError(
+                "Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = "
+                f"{pref_vector.ndim}`."
+            )
         return _ConstantWeighting(pref_vector)
 
 
diff --git a/src/torchjd/aggregation/aligned_mtl.py b/src/torchjd/aggregation/aligned_mtl.py
index c9d76d89b..c1423b65e 100644
--- a/src/torchjd/aggregation/aligned_mtl.py
+++ b/src/torchjd/aggregation/aligned_mtl.py
@@ -29,11 +29,7 @@
 from torch import Tensor
 from torch.linalg import LinAlgError
 
-from ._pref_vector_utils import (
-    _check_pref_vector,
-    _pref_vector_to_str_suffix,
-    _pref_vector_to_weighting,
-)
+from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
 from .bases import _WeightedAggregator, _Weighting
 from .mean import _MeanWeighting
 
@@ -66,7 +62,6 @@ class AlignedMTL(_WeightedAggregator):
     """
 
     def __init__(self, pref_vector: Tensor | None = None):
-        _check_pref_vector(pref_vector)
         weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
         self._pref_vector = pref_vector
 
diff --git a/src/torchjd/aggregation/config.py b/src/torchjd/aggregation/config.py
index 6ae0f4f3c..816633974 100644
--- a/src/torchjd/aggregation/config.py
+++ b/src/torchjd/aggregation/config.py
@@ -28,13 +28,9 @@
 import torch
 from torch import Tensor
 
-from torchjd.aggregation._pref_vector_utils import (
-    _check_pref_vector,
-    _pref_vector_to_str_suffix,
-    _pref_vector_to_weighting,
-)
-from torchjd.aggregation.bases import Aggregator
-from torchjd.aggregation.sum import _SumWeighting
+from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
+from .bases import Aggregator
+from .sum import _SumWeighting
 
 
 class ConFIG(Aggregator):
@@ -66,7 +62,6 @@
 
     def __init__(self, pref_vector: Tensor | None = None):
         super().__init__()
-        _check_pref_vector(pref_vector)
         self.weighting = _pref_vector_to_weighting(pref_vector, default=_SumWeighting())
         self._pref_vector = pref_vector
 
diff --git a/src/torchjd/aggregation/dualproj.py b/src/torchjd/aggregation/dualproj.py
index eb55cefe2..9ca2abc23 100644
--- a/src/torchjd/aggregation/dualproj.py
+++ b/src/torchjd/aggregation/dualproj.py
@@ -4,11 +4,7 @@
 
 from ._dual_cone_utils import _project_weights
 from ._gramian_utils import _compute_regularized_normalized_gramian
-from ._pref_vector_utils import (
-    _check_pref_vector,
-    _pref_vector_to_str_suffix,
-    _pref_vector_to_weighting,
-)
+from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
 from .bases import _WeightedAggregator, _Weighting
 from .mean import _MeanWeighting
 
@@ -51,7 +47,6 @@ def __init__(
         reg_eps: float = 0.0001,
         solver: Literal["quadprog"] = "quadprog",
     ):
-        _check_pref_vector(pref_vector)
         weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
         self._pref_vector = pref_vector
 
diff --git a/src/torchjd/aggregation/upgrad.py b/src/torchjd/aggregation/upgrad.py
index fdc83945f..31892c105 100644
--- a/src/torchjd/aggregation/upgrad.py
+++ b/src/torchjd/aggregation/upgrad.py
@@ -5,11 +5,7 @@
 
 from ._dual_cone_utils import _project_weights
 from ._gramian_utils import _compute_regularized_normalized_gramian
-from ._pref_vector_utils import (
-    _check_pref_vector,
-    _pref_vector_to_str_suffix,
-    _pref_vector_to_weighting,
-)
+from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
 from .bases import _WeightedAggregator, _Weighting
 from .mean import _MeanWeighting
 
@@ -51,7 +47,6 @@ def __init__(
         reg_eps: float = 0.0001,
         solver: Literal["quadprog"] = "quadprog",
     ):
-        _check_pref_vector(pref_vector)
         weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
         self._pref_vector = pref_vector
 
diff --git a/tests/unit/aggregation/test_pref_vector_utils.py b/tests/unit/aggregation/test_pref_vector_utils.py
new file mode 100644
index 000000000..16535ca0b
--- /dev/null
+++ b/tests/unit/aggregation/test_pref_vector_utils.py
@@ -0,0 +1,26 @@
+from contextlib import nullcontext as does_not_raise
+
+import torch
+from pytest import mark, raises
+from torch import Tensor
+from unit._utils import ExceptionContext
+
+from torchjd.aggregation._pref_vector_utils import _pref_vector_to_weighting
+from torchjd.aggregation.mean import _MeanWeighting
+
+
+@mark.parametrize(
+    ["pref_vector", "expectation"],
+    [
+        (None, does_not_raise()),
+        (torch.ones([]), raises(ValueError)),
+        (torch.ones([0]), does_not_raise()),
+        (torch.ones([1]), does_not_raise()),
+        (torch.ones([5]), does_not_raise()),
+        (torch.ones([1, 1]), raises(ValueError)),
+        (torch.ones([1, 1, 1]), raises(ValueError)),
+    ],
+)
+def test_pref_vector_to_weighting_check(pref_vector: Tensor | None, expectation: ExceptionContext):
+    with expectation:
+        _ = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
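
A minimal usage sketch of the behaviour this change set preserves (illustration only, not part of the diff): _pref_vector_to_weighting now performs the shape check itself, falling back to the given default weighting for None and raising the same ValueError for tensors that are not 1D.

# Sketch assuming the refactored torchjd sources above are importable.
import torch

from torchjd.aggregation._pref_vector_utils import _pref_vector_to_weighting
from torchjd.aggregation.mean import _MeanWeighting

default = _MeanWeighting()

# None falls back to the provided default weighting object.
assert _pref_vector_to_weighting(None, default=default) is default

# Any 1D tensor (including an empty one) yields a constant weighting.
_pref_vector_to_weighting(torch.ones([3]), default=default)

# Non-1D tensors raise the ValueError previously raised by _check_pref_vector.
try:
    _pref_vector_to_weighting(torch.ones([2, 2]), default=default)
except ValueError as error:
    print(error)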