Skip to content

Commit 9ee68f8

Browse files
committed
several things:
- Move Matrix and PSDMatrix to compute_gramian (probably not the best location, but they should live in _utils).
- Change the return type of compute_gramian to PSDMatrix.
- Add compute_gramian_sum (note that the responsibility of casting to PSDMatrix now lies with _utils).
- Add a _gramian_based version of jac_to_grad. Note that we could put the tensordot(weights, jacobian, dims=1) in _utils as a weight_generalize_matrix method.
1 parent bede311 commit 9ee68f8

20 files changed

+65
-25
lines changed

src/torchjd/_utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .compute_gramian import compute_gramian
1+
from .compute_gramian import Matrix, PSDMatrix, compute_gramian, compute_gramian_sum
22

3-
__all__ = ["compute_gramian"]
3+
__all__ = ["compute_gramian", "compute_gramian_sum", "Matrix", "PSDMatrix"]
Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
from typing import Annotated, cast
2+
13
import torch
24
from torch import Tensor
35

6+
Matrix = Annotated[Tensor, "ndim=2"]
7+
PSDMatrix = Annotated[Matrix, "Positive semi-definite"]
8+
49

5-
def compute_gramian(generalized_matrix: Tensor) -> Tensor:
10+
def compute_gramian(generalized_matrix: Tensor) -> PSDMatrix:
611
"""
712
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given
813
generalized matrix. Specifically, this is equivalent to
@@ -12,4 +17,9 @@ def compute_gramian(generalized_matrix: Tensor) -> Tensor:
1217
"""
1318
dims = list(range(1, generalized_matrix.ndim))
1419
gramian = torch.tensordot(generalized_matrix, generalized_matrix, dims=(dims, dims))
15-
return gramian
20+
return cast(PSDMatrix, gramian)
21+
22+
23+
def compute_gramian_sum(generalized_matrices: list[Tensor]) -> PSDMatrix:
24+
gramian = sum([compute_gramian(matrix) for matrix in generalized_matrices])
25+
return cast(PSDMatrix, gramian)

src/torchjd/aggregation/_aggregator_bases.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from torchjd._utils import compute_gramian
66

7-
from ._weighting_bases import Matrix, PSDMatrix, Weighting
7+
from .._utils.compute_gramian import Matrix, PSDMatrix
8+
from ._weighting_bases import Weighting
89

910

1011
class Aggregator(nn.Module, ABC):
@@ -80,3 +81,4 @@ class GramianWeightedAggregator(WeightedAggregator):
8081

8182
def __init__(self, weighting: Weighting[PSDMatrix]):
8283
super().__init__(weighting << compute_gramian)
84+
self.psd_weighting = weighting

src/torchjd/aggregation/_aligned_mtl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@
2828
import torch
2929
from torch import Tensor
3030

31+
from .._utils.compute_gramian import PSDMatrix
3132
from ._aggregator_bases import GramianWeightedAggregator
3233
from ._mean import MeanWeighting
3334
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
34-
from ._weighting_bases import PSDMatrix, Weighting
35+
from ._weighting_bases import Weighting
3536

3637

3738
class AlignedMTL(GramianWeightedAggregator):

src/torchjd/aggregation/_cagrad.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from typing import cast
22

3+
from .._utils.compute_gramian import PSDMatrix
34
from ._utils.check_dependencies import check_dependencies_are_installed
4-
from ._weighting_bases import PSDMatrix, Weighting
5+
from ._weighting_bases import Weighting
56

67
check_dependencies_are_installed(["cvxpy", "clarabel"])
78

src/torchjd/aggregation/_constant.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from torch import Tensor
22

3+
from .._utils.compute_gramian import Matrix
34
from ._aggregator_bases import WeightedAggregator
45
from ._utils.str import vector_to_str
5-
from ._weighting_bases import Matrix, Weighting
6+
from ._weighting_bases import Weighting
67

78

89
class Constant(WeightedAggregator):

src/torchjd/aggregation/_dualproj.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22

33
from torch import Tensor
44

5+
from .._utils.compute_gramian import PSDMatrix
56
from ._aggregator_bases import GramianWeightedAggregator
67
from ._mean import MeanWeighting
78
from ._utils.dual_cone import project_weights
89
from ._utils.gramian import normalize, regularize
910
from ._utils.non_differentiable import raise_non_differentiable_error
1011
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
11-
from ._weighting_bases import PSDMatrix, Weighting
12+
from ._weighting_bases import Weighting
1213

1314

1415
class DualProj(GramianWeightedAggregator):

src/torchjd/aggregation/_flattening.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from torch import Tensor
44

5-
from torchjd.aggregation._weighting_bases import GeneralizedWeighting, PSDMatrix, Weighting
5+
from torchjd._utils.compute_gramian import PSDMatrix
6+
from torchjd.aggregation._weighting_bases import GeneralizedWeighting, Weighting
67
from torchjd.autogram._gramian_utils import reshape_gramian
78

89

src/torchjd/aggregation/_imtl_g.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import torch
22
from torch import Tensor
33

4+
from .._utils.compute_gramian import PSDMatrix
45
from ._aggregator_bases import GramianWeightedAggregator
56
from ._utils.non_differentiable import raise_non_differentiable_error
6-
from ._weighting_bases import PSDMatrix, Weighting
7+
from ._weighting_bases import Weighting
78

89

910
class IMTLG(GramianWeightedAggregator):

src/torchjd/aggregation/_krum.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
from torch import Tensor
33
from torch.nn import functional as F
44

5+
from .._utils.compute_gramian import PSDMatrix
56
from ._aggregator_bases import GramianWeightedAggregator
6-
from ._weighting_bases import PSDMatrix, Weighting
7+
from ._weighting_bases import Weighting
78

89

910
class Krum(GramianWeightedAggregator):

0 commit comments

Comments
 (0)