Skip to content

Commit 739757d

Browse files
refactor(aggregation): Make utility functions public (#307)
* Make `project_weights` public * Make `pref_vector_to_weighting` and `pref_vector_to_str_suffix` public
1 parent 2c26403 commit 739757d

File tree

8 files changed

+25
-25
lines changed

8 files changed

+25
-25
lines changed

src/torchjd/aggregation/_dual_cone_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch import Tensor
77

88

9-
def _project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor:
9+
def project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor:
1010
"""
1111
Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the
1212
rows of a matrix whose Gramian is provided.

src/torchjd/aggregation/_pref_vector_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .constant import _ConstantWeighting
66

77

8-
def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -> _Weighting:
8+
def pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -> _Weighting:
99
"""
1010
Returns the weighting associated to a given preference vector, with a fallback to a default
1111
weighting if the preference vector is None.
@@ -22,7 +22,7 @@ def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -
2222
return _ConstantWeighting(pref_vector)
2323

2424

25-
def _pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str:
25+
def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str:
2626
"""Returns a suffix string containing the representation of the optional preference vector."""
2727

2828
if pref_vector is None:

src/torchjd/aggregation/aligned_mtl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import torch
2929
from torch import Tensor
3030

31-
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
31+
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
3232
from .bases import _WeightedAggregator, _Weighting
3333
from .mean import _MeanWeighting
3434

@@ -61,7 +61,7 @@ class AlignedMTL(_WeightedAggregator):
6161
"""
6262

6363
def __init__(self, pref_vector: Tensor | None = None):
64-
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
64+
weighting = pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
6565
self._pref_vector = pref_vector
6666

6767
super().__init__(weighting=_AlignedMTLWrapper(weighting))
@@ -70,7 +70,7 @@ def __repr__(self) -> str:
7070
return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"
7171

7272
def __str__(self) -> str:
73-
return f"AlignedMTL{_pref_vector_to_str_suffix(self._pref_vector)}"
73+
return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}"
7474

7575

7676
class _AlignedMTLWrapper(_Weighting):

src/torchjd/aggregation/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import torch
2929
from torch import Tensor
3030

31-
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
31+
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
3232
from .bases import Aggregator
3333
from .sum import _SumWeighting
3434

@@ -62,7 +62,7 @@ class ConFIG(Aggregator):
6262

6363
def __init__(self, pref_vector: Tensor | None = None):
6464
super().__init__()
65-
self.weighting = _pref_vector_to_weighting(pref_vector, default=_SumWeighting())
65+
self.weighting = pref_vector_to_weighting(pref_vector, default=_SumWeighting())
6666
self._pref_vector = pref_vector
6767

6868
def forward(self, matrix: Tensor) -> Tensor:
@@ -80,4 +80,4 @@ def __repr__(self) -> str:
8080
return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"
8181

8282
def __str__(self) -> str:
83-
return f"ConFIG{_pref_vector_to_str_suffix(self._pref_vector)}"
83+
return f"ConFIG{pref_vector_to_str_suffix(self._pref_vector)}"

src/torchjd/aggregation/dualproj.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from torch import Tensor
44

5-
from ._dual_cone_utils import _project_weights
5+
from ._dual_cone_utils import project_weights
66
from ._gramian_utils import _compute_regularized_normalized_gramian
7-
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
7+
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
88
from .bases import _WeightedAggregator, _Weighting
99
from .mean import _MeanWeighting
1010

@@ -47,7 +47,7 @@ def __init__(
4747
reg_eps: float = 0.0001,
4848
solver: Literal["quadprog"] = "quadprog",
4949
):
50-
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
50+
weighting = pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
5151
self._pref_vector = pref_vector
5252

5353
super().__init__(
@@ -64,7 +64,7 @@ def __repr__(self) -> str:
6464
)
6565

6666
def __str__(self) -> str:
67-
return f"DualProj{_pref_vector_to_str_suffix(self._pref_vector)}"
67+
return f"DualProj{pref_vector_to_str_suffix(self._pref_vector)}"
6868

6969

7070
class _DualProjWrapper(_Weighting):
@@ -101,5 +101,5 @@ def __init__(
101101
def forward(self, matrix: Tensor) -> Tensor:
102102
u = self.weighting(matrix)
103103
G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
104-
w = _project_weights(u, G, self.solver)
104+
w = project_weights(u, G, self.solver)
105105
return w

src/torchjd/aggregation/upgrad.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import torch
44
from torch import Tensor
55

6-
from ._dual_cone_utils import _project_weights
6+
from ._dual_cone_utils import project_weights
77
from ._gramian_utils import _compute_regularized_normalized_gramian
8-
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
8+
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
99
from .bases import _WeightedAggregator, _Weighting
1010
from .mean import _MeanWeighting
1111

@@ -47,7 +47,7 @@ def __init__(
4747
reg_eps: float = 0.0001,
4848
solver: Literal["quadprog"] = "quadprog",
4949
):
50-
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
50+
weighting = pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
5151
self._pref_vector = pref_vector
5252

5353
super().__init__(
@@ -64,7 +64,7 @@ def __repr__(self) -> str:
6464
)
6565

6666
def __str__(self) -> str:
67-
return f"UPGrad{_pref_vector_to_str_suffix(self._pref_vector)}"
67+
return f"UPGrad{pref_vector_to_str_suffix(self._pref_vector)}"
6868

6969

7070
class _UPGradWrapper(_Weighting):
@@ -97,5 +97,5 @@ def __init__(
9797
def forward(self, matrix: Tensor) -> Tensor:
9898
U = torch.diag(self.weighting(matrix))
9999
G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
100-
W = _project_weights(U, G, self.solver)
100+
W = project_weights(U, G, self.solver)
101101
return torch.sum(W, dim=0)

tests/unit/aggregation/test_dual_cone_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pytest import mark, raises
44
from torch.testing import assert_close
55

6-
from torchjd.aggregation._dual_cone_utils import _project_weight_vector, _project_weights
6+
from torchjd.aggregation._dual_cone_utils import _project_weight_vector, project_weights
77

88

99
@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
@@ -33,7 +33,7 @@ def test_solution_weights(shape: tuple[int, int]):
3333
G = J @ J.T
3434
u = torch.rand(shape[0])
3535

36-
w = _project_weights(u, G, "quadprog")
36+
w = project_weights(u, G, "quadprog")
3737
dual_gap = w - u
3838

3939
# Dual feasibility
@@ -64,8 +64,8 @@ def test_tensorization_shape(shape: tuple[int, ...]):
6464

6565
G = matrix @ matrix.T
6666

67-
W_tensor = _project_weights(U_tensor, G, "quadprog")
68-
W_matrix = _project_weights(U_matrix, G, "quadprog")
67+
W_tensor = project_weights(U_tensor, G, "quadprog")
68+
W_matrix = project_weights(U_matrix, G, "quadprog")
6969

7070
assert_close(W_matrix.reshape(shape), W_tensor)
7171

tests/unit/aggregation/test_pref_vector_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torch import Tensor
66
from unit._utils import ExceptionContext
77

8-
from torchjd.aggregation._pref_vector_utils import _pref_vector_to_weighting
8+
from torchjd.aggregation._pref_vector_utils import pref_vector_to_weighting
99
from torchjd.aggregation.mean import _MeanWeighting
1010

1111

@@ -23,4 +23,4 @@
2323
)
2424
def test_pref_vector_to_weighting_check(pref_vector: Tensor | None, expectation: ExceptionContext):
2525
with expectation:
26-
_ = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
26+
_ = pref_vector_to_weighting(pref_vector, default=_MeanWeighting())

0 commit comments

Comments
 (0)