Skip to content

Commit 9531e3c

Browse files
committed
Improve imports
1 parent 9ee68f8 commit 9531e3c

File tree

14 files changed

+27
-15
lines changed

14 files changed

+27
-15
lines changed

src/torchjd/aggregation/_aggregator_bases.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33
from torch import Tensor, nn
44

5-
from torchjd._utils import compute_gramian
5+
from torchjd._utils import Matrix, PSDMatrix, compute_gramian
66

7-
from .._utils.compute_gramian import Matrix, PSDMatrix
87
from ._weighting_bases import Weighting
98

109

src/torchjd/aggregation/_aligned_mtl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
import torch
2929
from torch import Tensor
3030

31-
from .._utils.compute_gramian import PSDMatrix
31+
from torchjd._utils.compute_gramian import PSDMatrix
32+
3233
from ._aggregator_bases import GramianWeightedAggregator
3334
from ._mean import MeanWeighting
3435
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting

src/torchjd/aggregation/_cagrad.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import cast
22

3-
from .._utils.compute_gramian import PSDMatrix
3+
from torchjd._utils.compute_gramian import PSDMatrix
4+
45
from ._utils.check_dependencies import check_dependencies_are_installed
56
from ._weighting_bases import Weighting
67

src/torchjd/aggregation/_constant.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from torch import Tensor
22

3-
from .._utils.compute_gramian import Matrix
3+
from torchjd._utils.compute_gramian import Matrix
4+
45
from ._aggregator_bases import WeightedAggregator
56
from ._utils.str import vector_to_str
67
from ._weighting_bases import Weighting

src/torchjd/aggregation/_dualproj.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from torch import Tensor
44

5-
from .._utils.compute_gramian import PSDMatrix
5+
from torchjd._utils.compute_gramian import PSDMatrix
6+
67
from ._aggregator_bases import GramianWeightedAggregator
78
from ._mean import MeanWeighting
89
from ._utils.dual_cone import project_weights

src/torchjd/aggregation/_imtl_g.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import torch
22
from torch import Tensor
33

4-
from .._utils.compute_gramian import PSDMatrix
4+
from torchjd._utils.compute_gramian import PSDMatrix
5+
56
from ._aggregator_bases import GramianWeightedAggregator
67
from ._utils.non_differentiable import raise_non_differentiable_error
78
from ._weighting_bases import Weighting

src/torchjd/aggregation/_krum.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from torch import Tensor
33
from torch.nn import functional as F
44

5-
from .._utils.compute_gramian import PSDMatrix
5+
from torchjd._utils.compute_gramian import PSDMatrix
6+
67
from ._aggregator_bases import GramianWeightedAggregator
78
from ._weighting_bases import Weighting
89

src/torchjd/aggregation/_mean.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import torch
22
from torch import Tensor
33

4-
from .._utils.compute_gramian import Matrix
4+
from torchjd._utils.compute_gramian import Matrix
5+
56
from ._aggregator_bases import WeightedAggregator
67
from ._weighting_bases import Weighting
78

src/torchjd/aggregation/_mgda.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import torch
22
from torch import Tensor
33

4-
from .._utils.compute_gramian import PSDMatrix
4+
from torchjd._utils.compute_gramian import PSDMatrix
5+
56
from ._aggregator_bases import GramianWeightedAggregator
67
from ._weighting_bases import Weighting
78

src/torchjd/aggregation/_nash_mtl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525

2626
# mypy: ignore-errors
2727

28-
from .._utils.compute_gramian import Matrix
28+
from torchjd._utils.compute_gramian import Matrix
29+
2930
from ._utils.check_dependencies import check_dependencies_are_installed
3031
from ._weighting_bases import Weighting
3132

0 commit comments

Comments
 (0)