Skip to content

Commit b280b10

Browse files
refactor(linalg): Create _linalg package (#520)
* Create `_linalg` package, factorize computation of Gramian and put resulting `compute_gramian` in `_linalg`. * Move Matrix and PSDMatrix to `_linalg/matrix.py` * Properly type gramians as `PSDMatrix`. * Add assert_psd_matrix * test that compute_gramian returns a psd_matrix
1 parent e56b515 commit b280b10

30 files changed

+133
-58
lines changed

src/torchjd/_linalg/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .gramian import compute_gramian
2+
from .matrix import Matrix, PSDMatrix
3+
4+
__all__ = ["compute_gramian", "Matrix", "PSDMatrix"]

src/torchjd/_linalg/gramian.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .matrix import Matrix, PSDMatrix
2+
3+
4+
def compute_gramian(matrix: Matrix) -> PSDMatrix:
5+
"""
6+
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix.
7+
"""
8+
9+
return matrix @ matrix.T

src/torchjd/_linalg/matrix.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from typing import Annotated
2+
3+
from torch import Tensor
4+
5+
Matrix = Annotated[Tensor, "ndim=2"]
6+
PSDMatrix = Annotated[Matrix, "Positive semi-definite"]

src/torchjd/aggregation/_aggregator_bases.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
from torch import Tensor, nn
44

5-
from ._utils.gramian import compute_gramian
6-
from ._weighting_bases import Matrix, PSDMatrix, Weighting
5+
from torchjd._linalg import Matrix, PSDMatrix, compute_gramian
6+
7+
from ._weighting_bases import Weighting
78

89

910
class Aggregator(nn.Module, ABC):

src/torchjd/aggregation/_aligned_mtl.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@
2828
import torch
2929
from torch import Tensor
3030

31+
from torchjd._linalg import PSDMatrix
32+
3133
from ._aggregator_bases import GramianWeightedAggregator
3234
from ._mean import MeanWeighting
3335
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
34-
from ._weighting_bases import PSDMatrix, Weighting
36+
from ._weighting_bases import Weighting
3537

3638

3739
class AlignedMTL(GramianWeightedAggregator):
@@ -73,7 +75,7 @@ def __init__(self, pref_vector: Tensor | None = None):
7375
self._pref_vector = pref_vector
7476
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
7577

76-
def forward(self, gramian: Tensor) -> Tensor:
78+
def forward(self, gramian: PSDMatrix) -> Tensor:
7779
w = self.weighting(gramian)
7880
B = self._compute_balance_transformation(gramian)
7981
alpha = B @ w

src/torchjd/aggregation/_cagrad.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from typing import cast
22

3+
from torchjd._linalg import PSDMatrix
4+
35
from ._utils.check_dependencies import check_dependencies_are_installed
4-
from ._weighting_bases import PSDMatrix, Weighting
6+
from ._weighting_bases import Weighting
57

68
check_dependencies_are_installed(["cvxpy", "clarabel"])
79

@@ -73,7 +75,7 @@ def __init__(self, c: float, norm_eps: float = 0.0001):
7375
self.c = c
7476
self.norm_eps = norm_eps
7577

76-
def forward(self, gramian: Tensor) -> Tensor:
78+
def forward(self, gramian: PSDMatrix) -> Tensor:
7779
U, S, _ = torch.svd(normalize(gramian, self.norm_eps))
7880

7981
reduced_matrix = U @ S.sqrt().diag()

src/torchjd/aggregation/_constant.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from torch import Tensor
22

3+
from torchjd._linalg import Matrix
4+
35
from ._aggregator_bases import WeightedAggregator
46
from ._utils.str import vector_to_str
5-
from ._weighting_bases import Matrix, Weighting
7+
from ._weighting_bases import Weighting
68

79

810
class Constant(WeightedAggregator):

src/torchjd/aggregation/_dualproj.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
from torch import Tensor
44

5+
from torchjd._linalg import PSDMatrix
6+
57
from ._aggregator_bases import GramianWeightedAggregator
68
from ._mean import MeanWeighting
79
from ._utils.dual_cone import project_weights
810
from ._utils.gramian import normalize, regularize
911
from ._utils.non_differentiable import raise_non_differentiable_error
1012
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
11-
from ._weighting_bases import PSDMatrix, Weighting
13+
from ._weighting_bases import Weighting
1214

1315

1416
class DualProj(GramianWeightedAggregator):
@@ -86,7 +88,7 @@ def __init__(
8688
self.reg_eps = reg_eps
8789
self.solver = solver
8890

89-
def forward(self, gramian: Tensor) -> Tensor:
91+
def forward(self, gramian: PSDMatrix) -> Tensor:
9092
u = self.weighting(gramian)
9193
G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
9294
w = project_weights(u, G, self.solver)

src/torchjd/aggregation/_flattening.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from torch import Tensor
44

5-
from torchjd.aggregation._weighting_bases import GeneralizedWeighting, PSDMatrix, Weighting
5+
from torchjd._linalg.matrix import PSDMatrix
6+
from torchjd.aggregation._weighting_bases import GeneralizedWeighting, Weighting
67
from torchjd.autogram._gramian_utils import reshape_gramian
78

89

src/torchjd/aggregation/_imtl_g.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import torch
22
from torch import Tensor
33

4+
from torchjd._linalg import PSDMatrix
5+
46
from ._aggregator_bases import GramianWeightedAggregator
57
from ._utils.non_differentiable import raise_non_differentiable_error
6-
from ._weighting_bases import PSDMatrix, Weighting
8+
from ._weighting_bases import Weighting
79

810

911
class IMTLG(GramianWeightedAggregator):
@@ -27,7 +29,7 @@ class IMTLGWeighting(Weighting[PSDMatrix]):
2729
:class:`~torchjd.aggregation.IMTLG`.
2830
"""
2931

30-
def forward(self, gramian: Tensor) -> Tensor:
32+
def forward(self, gramian: PSDMatrix) -> Tensor:
3133
d = torch.sqrt(torch.diagonal(gramian))
3234
v = torch.linalg.pinv(gramian) @ d
3335
v_sum = v.sum()

0 commit comments

Comments
 (0)