Skip to content

Commit 5ba58fd

Browse files
committed
Make pre_vector_to_weighting and pref_vector_to_str_suffix public
1 parent 4806c88 commit 5ba58fd

6 files changed

Lines changed: 16 additions & 16 deletions

File tree

src/torchjd/aggregation/_pref_vector_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .constant import _ConstantWeighting
66

77

8-
def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -> _Weighting:
8+
def pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -> _Weighting:
99
"""
1010
Returns the weighting associated to a given preference vector, with a fallback to a default
1111
weighting if the preference vector is None.
@@ -22,7 +22,7 @@ def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -
2222
return _ConstantWeighting(pref_vector)
2323

2424

25-
def _pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str:
25+
def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str:
2626
"""Returns a suffix string containing the representation of the optional preference vector."""
2727

2828
if pref_vector is None:

src/torchjd/aggregation/aligned_mtl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import torch
2929
from torch import Tensor
3030

31-
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
31+
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
3232
from .bases import _WeightedAggregator, _Weighting
3333
from .mean import _MeanWeighting
3434

@@ -61,7 +61,7 @@ class AlignedMTL(_WeightedAggregator):
6161
"""
6262

6363
def __init__(self, pref_vector: Tensor | None = None):
64-
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
64+
weighting = pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
6565
self._pref_vector = pref_vector
6666

6767
super().__init__(weighting=_AlignedMTLWrapper(weighting))
@@ -70,7 +70,7 @@ def __repr__(self) -> str:
7070
return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"
7171

7272
def __str__(self) -> str:
73-
return f"AlignedMTL{_pref_vector_to_str_suffix(self._pref_vector)}"
73+
return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}"
7474

7575

7676
class _AlignedMTLWrapper(_Weighting):

src/torchjd/aggregation/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import torch
2929
from torch import Tensor
3030

31-
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
31+
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
3232
from .bases import Aggregator
3333
from .sum import _SumWeighting
3434

@@ -62,7 +62,7 @@ class ConFIG(Aggregator):
6262

6363
def __init__(self, pref_vector: Tensor | None = None):
6464
super().__init__()
65-
self.weighting = _pref_vector_to_weighting(pref_vector, default=_SumWeighting())
65+
self.weighting = pref_vector_to_weighting(pref_vector, default=_SumWeighting())
6666
self._pref_vector = pref_vector
6767

6868
def forward(self, matrix: Tensor) -> Tensor:
@@ -80,4 +80,4 @@ def __repr__(self) -> str:
8080
return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"
8181

8282
def __str__(self) -> str:
83-
return f"ConFIG{_pref_vector_to_str_suffix(self._pref_vector)}"
83+
return f"ConFIG{pref_vector_to_str_suffix(self._pref_vector)}"

src/torchjd/aggregation/dualproj.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ._dual_cone_utils import project_weights
66
from ._gramian_utils import _compute_regularized_normalized_gramian
7-
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
7+
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
88
from .bases import _WeightedAggregator, _Weighting
99
from .mean import _MeanWeighting
1010

@@ -47,7 +47,7 @@ def __init__(
4747
reg_eps: float = 0.0001,
4848
solver: Literal["quadprog"] = "quadprog",
4949
):
50-
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
50+
weighting = pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
5151
self._pref_vector = pref_vector
5252

5353
super().__init__(
@@ -64,7 +64,7 @@ def __repr__(self) -> str:
6464
)
6565

6666
def __str__(self) -> str:
67-
return f"DualProj{_pref_vector_to_str_suffix(self._pref_vector)}"
67+
return f"DualProj{pref_vector_to_str_suffix(self._pref_vector)}"
6868

6969

7070
class _DualProjWrapper(_Weighting):

src/torchjd/aggregation/upgrad.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from ._dual_cone_utils import project_weights
77
from ._gramian_utils import _compute_regularized_normalized_gramian
8-
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
8+
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
99
from .bases import _WeightedAggregator, _Weighting
1010
from .mean import _MeanWeighting
1111

@@ -47,7 +47,7 @@ def __init__(
4747
reg_eps: float = 0.0001,
4848
solver: Literal["quadprog"] = "quadprog",
4949
):
50-
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
50+
weighting = pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
5151
self._pref_vector = pref_vector
5252

5353
super().__init__(
@@ -64,7 +64,7 @@ def __repr__(self) -> str:
6464
)
6565

6666
def __str__(self) -> str:
67-
return f"UPGrad{_pref_vector_to_str_suffix(self._pref_vector)}"
67+
return f"UPGrad{pref_vector_to_str_suffix(self._pref_vector)}"
6868

6969

7070
class _UPGradWrapper(_Weighting):

tests/unit/aggregation/test_pref_vector_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torch import Tensor
66
from unit._utils import ExceptionContext
77

8-
from torchjd.aggregation._pref_vector_utils import _pref_vector_to_weighting
8+
from torchjd.aggregation._pref_vector_utils import pref_vector_to_weighting
99
from torchjd.aggregation.mean import _MeanWeighting
1010

1111

@@ -23,4 +23,4 @@
2323
)
2424
def test_pref_vector_to_weighting_check(pref_vector: Tensor | None, expectation: ExceptionContext):
2525
with expectation:
26-
_ = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
26+
_ = pref_vector_to_weighting(pref_vector, default=_MeanWeighting())

0 commit comments

Comments
 (0)