diff --git a/src/torchjd/aggregation/_dual_cone_utils.py b/src/torchjd/aggregation/_dual_cone_utils.py index 3dc85373..539685be 100644 --- a/src/torchjd/aggregation/_dual_cone_utils.py +++ b/src/torchjd/aggregation/_dual_cone_utils.py @@ -6,7 +6,7 @@ from torch import Tensor -def _project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor: +def project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor: """ Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the rows of a matrix whose Gramian is provided. diff --git a/src/torchjd/aggregation/_pref_vector_utils.py b/src/torchjd/aggregation/_pref_vector_utils.py index 459419a1..ce441e5d 100644 --- a/src/torchjd/aggregation/_pref_vector_utils.py +++ b/src/torchjd/aggregation/_pref_vector_utils.py @@ -5,7 +5,7 @@ from .constant import _ConstantWeighting -def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -> _Weighting: +def pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -> _Weighting: """ Returns the weighting associated to a given preference vector, with a fallback to a default weighting if the preference vector is None. @@ -22,7 +22,7 @@ def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) - return _ConstantWeighting(pref_vector) -def _pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str: +def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str: """Returns a suffix string containing the representation of the optional preference vector.""" if pref_vector is None: diff --git a/src/torchjd/aggregation/aligned_mtl.py b/src/torchjd/aggregation/aligned_mtl.py index b1f5f862..6f54a5ad 100644 --- a/src/torchjd/aggregation/aligned_mtl.py +++ b/src/torchjd/aggregation/aligned_mtl.py @@ -28,7 +28,7 @@ import torch from torch import Tensor -from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting +from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting from .bases import _WeightedAggregator, _Weighting from .mean import _MeanWeighting @@ -61,7 +61,7 @@ class AlignedMTL(_WeightedAggregator): """ def __init__(self, pref_vector: Tensor | None = None): - weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting()) + weighting = pref_vector_to_weighting(pref_vector, default=_MeanWeighting()) self._pref_vector = pref_vector super().__init__(weighting=_AlignedMTLWrapper(weighting)) @@ -70,7 +70,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})" def __str__(self) -> str: - return f"AlignedMTL{_pref_vector_to_str_suffix(self._pref_vector)}" + return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}" class _AlignedMTLWrapper(_Weighting): diff --git a/src/torchjd/aggregation/config.py b/src/torchjd/aggregation/config.py index 12502c8e..c7876d0c 100644 --- a/src/torchjd/aggregation/config.py +++ b/src/torchjd/aggregation/config.py @@ -28,7 +28,7 @@ import torch from torch import Tensor -from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting +from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting from .bases import Aggregator from .sum import _SumWeighting @@ -62,7 +62,7 @@ class ConFIG(Aggregator): def __init__(self, pref_vector: Tensor | None = None): super().__init__() - self.weighting = _pref_vector_to_weighting(pref_vector, default=_SumWeighting()) + self.weighting = pref_vector_to_weighting(pref_vector, default=_SumWeighting()) self._pref_vector = pref_vector def forward(self, matrix: Tensor) -> Tensor: @@ -80,4 +80,4 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})" def __str__(self) -> str: - return f"ConFIG{_pref_vector_to_str_suffix(self._pref_vector)}" + return f"ConFIG{pref_vector_to_str_suffix(self._pref_vector)}" diff --git a/src/torchjd/aggregation/dualproj.py b/src/torchjd/aggregation/dualproj.py index 9ca2abc2..d0e20d75 100644 --- a/src/torchjd/aggregation/dualproj.py +++ b/src/torchjd/aggregation/dualproj.py @@ -2,9 +2,9 @@ from torch import Tensor -from ._dual_cone_utils import _project_weights +from ._dual_cone_utils import project_weights from ._gramian_utils import _compute_regularized_normalized_gramian -from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting +from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting from .bases import _WeightedAggregator, _Weighting from .mean import _MeanWeighting @@ -47,7 +47,7 @@ def __init__( reg_eps: float = 0.0001, solver: Literal["quadprog"] = "quadprog", ): - weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting()) + weighting = pref_vector_to_weighting(pref_vector, default=_MeanWeighting()) self._pref_vector = pref_vector super().__init__( @@ -64,7 +64,7 @@ def __repr__(self) -> str: ) def __str__(self) -> str: - return f"DualProj{_pref_vector_to_str_suffix(self._pref_vector)}" + return f"DualProj{pref_vector_to_str_suffix(self._pref_vector)}" class _DualProjWrapper(_Weighting): @@ -101,5 +101,5 @@ def __init__( def forward(self, matrix: Tensor) -> Tensor: u = self.weighting(matrix) G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps) - w = _project_weights(u, G, self.solver) + w = project_weights(u, G, self.solver) return w diff --git a/src/torchjd/aggregation/upgrad.py b/src/torchjd/aggregation/upgrad.py index 31892c10..fe258402 100644 --- a/src/torchjd/aggregation/upgrad.py +++ b/src/torchjd/aggregation/upgrad.py @@ -3,9 +3,9 @@ import torch from torch import Tensor -from ._dual_cone_utils import _project_weights +from ._dual_cone_utils import project_weights from ._gramian_utils import _compute_regularized_normalized_gramian -from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting +from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting from .bases import _WeightedAggregator, _Weighting from .mean import _MeanWeighting @@ -47,7 +47,7 @@ def __init__( reg_eps: float = 0.0001, solver: Literal["quadprog"] = "quadprog", ): - weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting()) + weighting = pref_vector_to_weighting(pref_vector, default=_MeanWeighting()) self._pref_vector = pref_vector super().__init__( @@ -64,7 +64,7 @@ def __repr__(self) -> str: ) def __str__(self) -> str: - return f"UPGrad{_pref_vector_to_str_suffix(self._pref_vector)}" + return f"UPGrad{pref_vector_to_str_suffix(self._pref_vector)}" class _UPGradWrapper(_Weighting): @@ -97,5 +97,5 @@ def __init__( def forward(self, matrix: Tensor) -> Tensor: U = torch.diag(self.weighting(matrix)) G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps) - W = _project_weights(U, G, self.solver) + W = project_weights(U, G, self.solver) return torch.sum(W, dim=0) diff --git a/tests/unit/aggregation/test_dual_cone_utils.py b/tests/unit/aggregation/test_dual_cone_utils.py index 506f63ec..bf192323 100644 --- a/tests/unit/aggregation/test_dual_cone_utils.py +++ b/tests/unit/aggregation/test_dual_cone_utils.py @@ -3,7 +3,7 @@ from pytest import mark, raises from torch.testing import assert_close -from torchjd.aggregation._dual_cone_utils import _project_weight_vector, _project_weights +from torchjd.aggregation._dual_cone_utils import _project_weight_vector, project_weights @mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)]) @@ -33,7 +33,7 @@ def test_solution_weights(shape: tuple[int, int]): G = J @ J.T u = torch.rand(shape[0]) - w = _project_weights(u, G, "quadprog") + w = project_weights(u, G, "quadprog") dual_gap = w - u # Dual feasibility @@ -64,8 +64,8 @@ def test_tensorization_shape(shape: tuple[int, ...]): G = matrix @ matrix.T - W_tensor = _project_weights(U_tensor, G, "quadprog") - W_matrix = _project_weights(U_matrix, G, "quadprog") + W_tensor = project_weights(U_tensor, G, "quadprog") + W_matrix = project_weights(U_matrix, G, "quadprog") assert_close(W_matrix.reshape(shape), W_tensor) diff --git a/tests/unit/aggregation/test_pref_vector_utils.py b/tests/unit/aggregation/test_pref_vector_utils.py index 16535ca0..aa904f5d 100644 --- a/tests/unit/aggregation/test_pref_vector_utils.py +++ b/tests/unit/aggregation/test_pref_vector_utils.py @@ -5,7 +5,7 @@ from torch import Tensor from unit._utils import ExceptionContext -from torchjd.aggregation._pref_vector_utils import _pref_vector_to_weighting +from torchjd.aggregation._pref_vector_utils import pref_vector_to_weighting from torchjd.aggregation.mean import _MeanWeighting @@ -23,4 +23,4 @@ ) def test_pref_vector_to_weighting_check(pref_vector: Tensor | None, expectation: ExceptionContext): with expectation: - _ = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting()) + _ = pref_vector_to_weighting(pref_vector, default=_MeanWeighting())