Skip to content

Commit a793693

Browse files
committed
Can parametrize number of dimensions to contract in compute_gramian
1 parent 57af9f1 commit a793693

File tree

5 files changed

+41
-29
lines changed

5 files changed

+41
-29
lines changed

src/torchjd/_linalg/_gramian.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,40 @@
1-
from typing import cast
1+
from typing import Literal, cast, overload
22

33
import torch
44

5-
from ._matrix import GeneralizedMatrix, PSDMatrix
5+
from ._matrix import GeneralizedMatrix, PSDGeneralizedMatrix, PSDMatrix
66

77

8+
@overload
89
def compute_gramian(matrix: GeneralizedMatrix) -> PSDMatrix:
10+
pass
11+
12+
13+
@overload
14+
def compute_gramian(matrix: GeneralizedMatrix, contracted_dims: Literal[-1]) -> PSDMatrix:
15+
pass
16+
17+
18+
@overload
19+
def compute_gramian(matrix: GeneralizedMatrix, contracted_dims: int) -> PSDGeneralizedMatrix:
20+
pass
21+
22+
23+
def compute_gramian(matrix: GeneralizedMatrix, contracted_dims: int = -1) -> PSDGeneralizedMatrix:
924
"""
10-
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix.
25+
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of the input.
26+
27+
`contracted_dims` specifies the number of trailing dimensions to contract. If negative,
28+
its absolute value gives the number of leading dimensions to preserve (e.g., ``-1`` preserves the
29+
first dimension).
1130
"""
1231

13-
indices = list(range(1, matrix.ndim))
14-
gramian = torch.tensordot(matrix, matrix, dims=(indices, indices))
15-
return cast(PSDMatrix, gramian)
32+
contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + matrix.ndim
33+
indices_source = list(range(matrix.ndim - contracted_dims))
34+
indices_dest = list(range(matrix.ndim - 1, contracted_dims - 1, -1))
35+
transposed_matrix = matrix.movedim(indices_source, indices_dest)
36+
gramian = torch.tensordot(matrix, transposed_matrix, dims=contracted_dims)
37+
return cast(PSDGeneralizedMatrix, gramian)
1638

1739

1840
def normalize(gramian: PSDMatrix, eps: float) -> PSDMatrix:

tests/unit/autogram/test_engine.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@
6868
CloneParams,
6969
autograd_forward_backward,
7070
autogram_forward_backward,
71-
compute_gramian,
7271
compute_gramian_with_autograd,
7372
forward_pass,
7473
make_mse_loss_fn,
@@ -79,6 +78,7 @@
7978
)
8079
from utils.tensors import make_inputs_and_targets, ones_, randn_, zeros_
8180

81+
from torchjd._linalg import compute_gramian
8282
from torchjd.aggregation import UPGradWeighting
8383
from torchjd.autogram._engine import Engine
8484
from torchjd.autogram._gramian_utils import movedim, reshape
@@ -418,9 +418,9 @@ def test_compute_gramian_manual():
418418
weight_jacobian = zeros_([out_dims, model.weight.numel()])
419419
for j in range(out_dims):
420420
weight_jacobian[j, j * in_dims : (j + 1) * in_dims] = input
421-
weight_gramian = compute_gramian(weight_jacobian)
421+
weight_gramian = compute_gramian(weight_jacobian, 1)
422422
bias_jacobian = torch.diag(ones_(out_dims))
423-
bias_gramian = compute_gramian(bias_jacobian)
423+
bias_gramian = compute_gramian(bias_jacobian, 1)
424424
expected_gramian = weight_gramian + bias_gramian
425425

426426
assert_close(gramian, expected_gramian)

tests/unit/autogram/test_gramian_utils.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from pytest import mark
22
from torch.testing import assert_close
33
from utils.asserts import assert_is_psd_generalized_matrix, assert_is_psd_matrix
4-
from utils.forward_backwards import compute_gramian
54
from utils.tensors import randn_
65

7-
from torchjd._linalg import is_psd_matrix
6+
from torchjd._linalg import compute_gramian, is_psd_matrix
87
from torchjd.autogram._gramian_utils import flatten, movedim, reshape
98

109

@@ -33,8 +32,8 @@ def test_reshape_equivarience(original_shape: list[int], target_shape: list[int]
3332
original_matrix = randn_(original_shape + [2])
3433
target_matrix = original_matrix.reshape(target_shape + [2])
3534

36-
original_gramian = compute_gramian(original_matrix)
37-
target_gramian = compute_gramian(target_matrix)
35+
original_gramian = compute_gramian(original_matrix, 1)
36+
target_gramian = compute_gramian(target_matrix, 1)
3837

3938
reshaped_gramian = reshape(original_gramian, target_shape)
4039

@@ -58,7 +57,7 @@ def test_reshape_equivarience(original_shape: list[int], target_shape: list[int]
5857
)
5958
def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]):
6059
matrix = randn_(original_shape + [2])
61-
gramian = compute_gramian(matrix)
60+
gramian = compute_gramian(matrix, 1)
6261
reshaped_gramian = reshape(gramian, target_shape)
6362
assert_is_psd_generalized_matrix(reshaped_gramian, atol=1e-04, rtol=0.0)
6463

@@ -75,7 +74,7 @@ def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]):
7574
)
7675
def test_flatten_yields_matrix(shape: list[int]):
7776
matrix = randn_(shape + [2])
78-
gramian = compute_gramian(matrix)
77+
gramian = compute_gramian(matrix, 1)
7978
flattened_gramian = flatten(gramian)
8079
assert is_psd_matrix(flattened_gramian)
8180

@@ -92,7 +91,7 @@ def test_flatten_yields_matrix(shape: list[int]):
9291
)
9392
def test_flatten_yields_psd(shape: list[int]):
9493
matrix = randn_(shape + [2])
95-
gramian = compute_gramian(matrix)
94+
gramian = compute_gramian(matrix, 1)
9695
flattened_gramian = flatten(gramian)
9796
assert_is_psd_matrix(flattened_gramian, atol=1e-04, rtol=0.0)
9897

@@ -121,8 +120,8 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination:
121120
original_matrix = randn_(shape + [2])
122121
target_matrix = original_matrix.movedim(source, destination)
123122

124-
original_gramian = compute_gramian(original_matrix)
125-
target_gramian = compute_gramian(target_matrix)
123+
original_gramian = compute_gramian(original_matrix, 1)
124+
target_gramian = compute_gramian(target_matrix, 1)
126125

127126
moveddim_gramian = movedim(original_gramian, source, destination)
128127

@@ -149,6 +148,6 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination:
149148
)
150149
def test_movedim_yields_psd(shape: list[int], source: list[int], destination: list[int]):
151150
matrix = randn_(shape + [2])
152-
gramian = compute_gramian(matrix)
151+
gramian = compute_gramian(matrix, 1)
153152
moveddim_gramian = movedim(gramian, source, destination)
154153
assert_is_psd_generalized_matrix(moveddim_gramian)

tests/unit/linalg/test_gramian.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
from utils.asserts import assert_is_psd_matrix
33
from utils.tensors import randn_
44

5-
from torchjd._linalg import compute_gramian, is_generalized_matrix, is_matrix
6-
from torchjd._linalg._gramian import normalize, regularize
5+
from torchjd._linalg import compute_gramian, is_generalized_matrix, is_matrix, normalize, regularize
76

87

98
@mark.parametrize(

tests/utils/forward_backwards.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,6 @@ def get_vjp(grad_outputs: Tensor) -> list[Tensor]:
142142
return gramian
143143

144144

145-
def compute_gramian(matrix: Tensor) -> PSDGeneralizedMatrix:
146-
"""Contracts the last dimension of matrix to make it into a Gramian."""
147-
148-
indices = list(range(matrix.ndim))
149-
transposed_matrix = matrix.movedim(indices, indices[::-1])
150-
return torch.tensordot(matrix, transposed_matrix, dims=([-1], [0]))
151-
152-
153145
class CloneParams:
154146
"""
155147
ContextManager enabling the computation of per-usage gradients.

0 commit comments

Comments
 (0)