Skip to content

Commit 24a5c54

Browse files
refactor(aggregation): Move utils of aggregation to _utils package (#338)
* Create `_utils` package in `aggregation`. * Move utility files into `_utils` and remove their `_utils` suffix. * Create a matching `_utils` package in the `aggregation` unit tests, and move the corresponding tests there. --------- Co-authored-by: Valérian Rey <valerian.rey@gmail.com>
1 parent a5c19da commit 24a5c54

24 files changed

+40
-37
lines changed

src/torchjd/aggregation/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from ._check_dependencies import _OptionalDepsNotInstalledError
1+
from ._utils.check_dependencies import (
2+
OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError,
3+
)
24
from .aligned_mtl import AlignedMTL
35
from .bases import Aggregator
46
from .config import ConFIG

src/torchjd/aggregation/_utils/__init__.py

Whitespace-only changes.

src/torchjd/aggregation/_check_dependencies.py renamed to src/torchjd/aggregation/_utils/check_dependencies.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from importlib.util import find_spec
22

33

4-
class _OptionalDepsNotInstalledError(ModuleNotFoundError):
4+
class OptionalDepsNotInstalledError(ModuleNotFoundError):
55
pass
66

77

@@ -13,4 +13,4 @@ def check_dependencies_are_installed(dependency_names: list[str]) -> None:
1313
"""
1414

1515
if any(find_spec(name) is None for name in dependency_names):
16-
raise _OptionalDepsNotInstalledError()
16+
raise OptionalDepsNotInstalledError()
File renamed without changes.
File renamed without changes.

src/torchjd/aggregation/_pref_vector_utils.py renamed to src/torchjd/aggregation/_utils/pref_vector.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from torch import Tensor
22

3-
from ._str_utils import _vector_to_str
4-
from .bases import _Weighting
5-
from .constant import _ConstantWeighting
3+
from torchjd.aggregation.bases import _Weighting
4+
from torchjd.aggregation.constant import _ConstantWeighting
5+
6+
from .str import vector_to_str
67

78

89
def pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -> _Weighting:
@@ -28,4 +29,4 @@ def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str:
2829
if pref_vector is None:
2930
return ""
3031
else:
31-
return f"([{_vector_to_str(pref_vector)}])"
32+
return f"([{vector_to_str(pref_vector)}])"
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from torch import Tensor
22

33

4-
def _vector_to_str(vector: Tensor) -> str:
4+
def vector_to_str(vector: Tensor) -> str:
55
"""
66
Transforms a Tensor of the form `tensor([1.23456, 1.0, ...])` into a string of the form
77
`1.23, 1., ...`

src/torchjd/aggregation/aligned_mtl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
import torch
2929
from torch import Tensor
3030

31-
from ._gramian_utils import compute_gramian
32-
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
31+
from ._utils.gramian import compute_gramian
32+
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
3333
from .bases import _WeightedAggregator, _Weighting
3434
from .mean import _MeanWeighting
3535

src/torchjd/aggregation/cagrad.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._check_dependencies import check_dependencies_are_installed # noqa
1+
from ._utils.check_dependencies import check_dependencies_are_installed
22

33
check_dependencies_are_installed(["cvxpy", "clarabel"])
44

@@ -7,8 +7,8 @@
77
import torch
88
from torch import Tensor
99

10-
from ._gramian_utils import compute_gramian, normalize
11-
from ._non_differentiable import raise_non_differentiable_error
10+
from ._utils.gramian import compute_gramian, normalize
11+
from ._utils.non_differentiable import raise_non_differentiable_error
1212
from .bases import _WeightedAggregator, _Weighting
1313

1414

0 commit comments

Comments
 (0)