Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_dual_cone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch import Tensor


def _project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor:
def project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor:
"""
Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the
rows of a matrix whose Gramian is provided.
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_pref_vector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .constant import _ConstantWeighting


def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -> _Weighting:
def pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -> _Weighting:
"""
Returns the weighting associated to a given preference vector, with a fallback to a default
weighting if the preference vector is None.
Expand All @@ -22,7 +22,7 @@ def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -
return _ConstantWeighting(pref_vector)


def _pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str:
def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str:
"""Returns a suffix string containing the representation of the optional preference vector."""

if pref_vector is None:
Expand Down
6 changes: 3 additions & 3 deletions src/torchjd/aggregation/aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import torch
from torch import Tensor

from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
from .bases import _WeightedAggregator, _Weighting
from .mean import _MeanWeighting

Expand Down Expand Up @@ -61,7 +61,7 @@ class AlignedMTL(_WeightedAggregator):
"""

def __init__(self, pref_vector: Tensor | None = None):
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
weighting = pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
self._pref_vector = pref_vector

super().__init__(weighting=_AlignedMTLWrapper(weighting))
Expand All @@ -70,7 +70,7 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"

def __str__(self) -> str:
return f"AlignedMTL{_pref_vector_to_str_suffix(self._pref_vector)}"
return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}"


class _AlignedMTLWrapper(_Weighting):
Expand Down
6 changes: 3 additions & 3 deletions src/torchjd/aggregation/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import torch
from torch import Tensor

from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
from .bases import Aggregator
from .sum import _SumWeighting

Expand Down Expand Up @@ -62,7 +62,7 @@ class ConFIG(Aggregator):

def __init__(self, pref_vector: Tensor | None = None):
super().__init__()
self.weighting = _pref_vector_to_weighting(pref_vector, default=_SumWeighting())
self.weighting = pref_vector_to_weighting(pref_vector, default=_SumWeighting())
self._pref_vector = pref_vector

def forward(self, matrix: Tensor) -> Tensor:
Expand All @@ -80,4 +80,4 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"

def __str__(self) -> str:
return f"ConFIG{_pref_vector_to_str_suffix(self._pref_vector)}"
return f"ConFIG{pref_vector_to_str_suffix(self._pref_vector)}"
10 changes: 5 additions & 5 deletions src/torchjd/aggregation/dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from torch import Tensor

from ._dual_cone_utils import _project_weights
from ._dual_cone_utils import project_weights
from ._gramian_utils import _compute_regularized_normalized_gramian
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
from .bases import _WeightedAggregator, _Weighting
from .mean import _MeanWeighting

Expand Down Expand Up @@ -47,7 +47,7 @@ def __init__(
reg_eps: float = 0.0001,
solver: Literal["quadprog"] = "quadprog",
):
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
weighting = pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
self._pref_vector = pref_vector

super().__init__(
Expand All @@ -64,7 +64,7 @@ def __repr__(self) -> str:
)

def __str__(self) -> str:
return f"DualProj{_pref_vector_to_str_suffix(self._pref_vector)}"
return f"DualProj{pref_vector_to_str_suffix(self._pref_vector)}"


class _DualProjWrapper(_Weighting):
Expand Down Expand Up @@ -101,5 +101,5 @@ def __init__(
def forward(self, matrix: Tensor) -> Tensor:
u = self.weighting(matrix)
G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
w = _project_weights(u, G, self.solver)
w = project_weights(u, G, self.solver)
return w
10 changes: 5 additions & 5 deletions src/torchjd/aggregation/upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import torch
from torch import Tensor

from ._dual_cone_utils import _project_weights
from ._dual_cone_utils import project_weights
from ._gramian_utils import _compute_regularized_normalized_gramian
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
from .bases import _WeightedAggregator, _Weighting
from .mean import _MeanWeighting

Expand Down Expand Up @@ -47,7 +47,7 @@ def __init__(
reg_eps: float = 0.0001,
solver: Literal["quadprog"] = "quadprog",
):
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
weighting = pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
self._pref_vector = pref_vector

super().__init__(
Expand All @@ -64,7 +64,7 @@ def __repr__(self) -> str:
)

def __str__(self) -> str:
return f"UPGrad{_pref_vector_to_str_suffix(self._pref_vector)}"
return f"UPGrad{pref_vector_to_str_suffix(self._pref_vector)}"


class _UPGradWrapper(_Weighting):
Expand Down Expand Up @@ -97,5 +97,5 @@ def __init__(
def forward(self, matrix: Tensor) -> Tensor:
U = torch.diag(self.weighting(matrix))
G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
W = _project_weights(U, G, self.solver)
W = project_weights(U, G, self.solver)
return torch.sum(W, dim=0)
8 changes: 4 additions & 4 deletions tests/unit/aggregation/test_dual_cone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pytest import mark, raises
from torch.testing import assert_close

from torchjd.aggregation._dual_cone_utils import _project_weight_vector, _project_weights
from torchjd.aggregation._dual_cone_utils import _project_weight_vector, project_weights


@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
Expand Down Expand Up @@ -33,7 +33,7 @@ def test_solution_weights(shape: tuple[int, int]):
G = J @ J.T
u = torch.rand(shape[0])

w = _project_weights(u, G, "quadprog")
w = project_weights(u, G, "quadprog")
dual_gap = w - u

# Dual feasibility
Expand Down Expand Up @@ -64,8 +64,8 @@ def test_tensorization_shape(shape: tuple[int, ...]):

G = matrix @ matrix.T

W_tensor = _project_weights(U_tensor, G, "quadprog")
W_matrix = _project_weights(U_matrix, G, "quadprog")
W_tensor = project_weights(U_tensor, G, "quadprog")
W_matrix = project_weights(U_matrix, G, "quadprog")

assert_close(W_matrix.reshape(shape), W_tensor)

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/aggregation/test_pref_vector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch import Tensor
from unit._utils import ExceptionContext

from torchjd.aggregation._pref_vector_utils import _pref_vector_to_weighting
from torchjd.aggregation._pref_vector_utils import pref_vector_to_weighting
from torchjd.aggregation.mean import _MeanWeighting


Expand All @@ -23,4 +23,4 @@
)
def test_pref_vector_to_weighting_check(pref_vector: Tensor | None, expectation: ExceptionContext):
with expectation:
_ = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
_ = pref_vector_to_weighting(pref_vector, default=_MeanWeighting())