Skip to content

Commit 07867c0

Browse files
committed
factorise computing gramians
1 parent 179bdfd commit 07867c0

File tree

7 files changed

+24
-16
lines changed

7 files changed

+24
-16
lines changed

src/torchjd/_utils/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .compute_gramian import compute_gramian
2+
3+
__all__ = ["compute_gramian"]
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import torch
2+
from torch import Tensor
3+
4+
5+
def compute_gramian(generalized_matrix: Tensor) -> Tensor:
6+
"""
7+
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given
8+
generalized matrix. Specifically, this is equivalent to
9+
10+
matrix = generalized_matrix.reshape([generalized_matrix.shape[0], -1])
11+
return matrix @ matrix.T
12+
"""
13+
dims = list(range(1, generalized_matrix.ndim))
14+
gramian = torch.tensordot(generalized_matrix, generalized_matrix, dims=(dims, dims))
15+
return gramian

src/torchjd/aggregation/_aggregator_bases.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from torch import Tensor, nn
44

5-
from ._utils.gramian import compute_gramian
5+
from torchjd._utils import compute_gramian
6+
67
from ._weighting_bases import Matrix, PSDMatrix, Weighting
78

89

src/torchjd/aggregation/_utils/gramian.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,6 @@
22
from torch import Tensor
33

44

5-
def compute_gramian(matrix: Tensor) -> Tensor:
6-
"""
7-
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix.
8-
"""
9-
10-
return matrix @ matrix.T
11-
12-
135
def normalize(gramian: Tensor, eps: float) -> Tensor:
146
"""
157
Normalizes the gramian `G=AA^T` with respect to the Frobenius norm of `A`.

src/torchjd/autogram/_gramian_computer.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from torch import Tensor
55
from torch.utils._pytree import PyTree
66

7+
from torchjd._utils import compute_gramian
78
from torchjd.autogram._jacobian_computer import JacobianComputer
89

910

@@ -29,10 +30,6 @@ class JacobianBasedGramianComputer(GramianComputer, ABC):
2930
def __init__(self, jacobian_computer):
3031
self.jacobian_computer = jacobian_computer
3132

32-
@staticmethod
33-
def _to_gramian(jacobian: Tensor) -> Tensor:
34-
return jacobian @ jacobian.T
35-
3633

3734
class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
3835
"""
@@ -71,7 +68,7 @@ def __call__(
7168
self.remaining_counter -= 1
7269

7370
if self.remaining_counter == 0:
74-
gramian = self._to_gramian(self.summed_jacobian)
71+
gramian = compute_gramian(self.summed_jacobian)
7572
del self.summed_jacobian
7673
return gramian
7774
else:

tests/unit/aggregation/test_mgda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from torch.testing import assert_close
44
from utils.tensors import ones_, randn_
55

6+
from torchjd._utils.compute_gramian import compute_gramian
67
from torchjd.aggregation import MGDA
78
from torchjd.aggregation._mgda import MGDAWeighting
8-
from torchjd.aggregation._utils.gramian import compute_gramian
99

1010
from ._asserts import (
1111
assert_expected_structure,

tests/unit/aggregation/test_pcgrad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from torch.testing import assert_close
44
from utils.tensors import ones_, randn_
55

6+
from torchjd._utils.compute_gramian import compute_gramian
67
from torchjd.aggregation import PCGrad
78
from torchjd.aggregation._pcgrad import PCGradWeighting
89
from torchjd.aggregation._upgrad import UPGradWeighting
9-
from torchjd.aggregation._utils.gramian import compute_gramian
1010

1111
from ._asserts import assert_expected_structure, assert_non_differentiable
1212
from ._inputs import scaled_matrices, typical_matrices

0 commit comments

Comments
 (0)