Skip to content

Commit 994932b

Browse files
committed
Rename PSDGeneralizedMatrix to PSDTensor
1 parent 7da352c commit 994932b

File tree

9 files changed

+31
-40
lines changed

9 files changed

+31
-40
lines changed

src/torchjd/_linalg/__init__.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,14 @@
11
from ._gramian import compute_gramian, normalize, regularize
2-
from ._matrix import (
3-
Matrix,
4-
PSDGeneralizedMatrix,
5-
PSDMatrix,
6-
is_matrix,
7-
is_psd_generalized_matrix,
8-
is_psd_matrix,
9-
)
2+
from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor
103

114
__all__ = [
125
"compute_gramian",
136
"normalize",
147
"regularize",
158
"Matrix",
169
"PSDMatrix",
17-
"PSDGeneralizedMatrix",
10+
"PSDTensor",
1811
"is_matrix",
1912
"is_psd_matrix",
20-
"is_psd_generalized_matrix",
13+
"is_psd_tensor",
2114
]

src/torchjd/_linalg/_gramian.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from torch import Tensor
55

6-
from ._matrix import PSDGeneralizedMatrix, PSDMatrix
6+
from ._matrix import PSDMatrix, PSDTensor
77

88

99
@overload
@@ -17,11 +17,11 @@ def compute_gramian(matrix: Tensor, contracted_dims: Literal[-1]) -> PSDMatrix:
1717

1818

1919
@overload
20-
def compute_gramian(matrix: Tensor, contracted_dims: int) -> PSDGeneralizedMatrix:
20+
def compute_gramian(matrix: Tensor, contracted_dims: int) -> PSDTensor:
2121
pass
2222

2323

24-
def compute_gramian(matrix: Tensor, contracted_dims: int = -1) -> PSDGeneralizedMatrix:
24+
def compute_gramian(matrix: Tensor, contracted_dims: int = -1) -> PSDTensor:
2525
"""
2626
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of the input.
2727
@@ -35,7 +35,7 @@ def compute_gramian(matrix: Tensor, contracted_dims: int = -1) -> PSDGeneralized
3535
indices_dest = list(range(matrix.ndim - 1, contracted_dims - 1, -1))
3636
transposed_matrix = matrix.movedim(indices_source, indices_dest)
3737
gramian = torch.tensordot(matrix, transposed_matrix, dims=contracted_dims)
38-
return cast(PSDGeneralizedMatrix, gramian)
38+
return cast(PSDTensor, gramian)
3939

4040

4141
def normalize(gramian: PSDMatrix, eps: float) -> PSDMatrix:

src/torchjd/_linalg/_matrix.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,30 @@ class Matrix(Tensor):
77
"""Tensor with exactly 2 dimensions."""
88

99

10-
class PSDGeneralizedMatrix(Tensor):
10+
class PSDTensor(Tensor):
1111
"""
1212
Tensor representing a quadratic form. The first half of its dimensions matches the reversed
1313
second half of its dimensions (e.g. shape=[4, 3, 3, 4]), and its reshaping into a matrix should
1414
be positive semi-definite.
1515
"""
1616

1717

18-
class PSDMatrix(PSDGeneralizedMatrix, Matrix):
18+
class PSDMatrix(PSDTensor, Matrix):
1919
"""Positive semi-definite matrix."""
2020

2121

2222
def is_matrix(t: Tensor) -> TypeGuard[Matrix]:
2323
return t.ndim == 2
2424

2525

26-
def is_psd_generalized_matrix(t: Tensor) -> TypeGuard[PSDGeneralizedMatrix]:
26+
def is_psd_tensor(t: Tensor) -> TypeGuard[PSDTensor]:
2727
half_dim = t.ndim // 2
2828
return t.ndim % 2 == 0 and t.shape[:half_dim] == t.shape[: half_dim - 1 : -1]
2929
# We do not check that t is PSD as it is expensive, but this must be checked in the tests of
30-
# every function that uses this TypeGuard by using `assert_psd_generalized_matrix`.
30+
# every function that uses this TypeGuard by using `assert_is_psd_tensor`.
3131

3232

3333
def is_psd_matrix(t: Tensor) -> TypeGuard[PSDMatrix]:
3434
return t.ndim == 2 and t.shape[0] == t.shape[1]
3535
# We do not check that t is PSD as it is expensive, but this must be checked in the tests of
36-
# every function that uses this TypeGuard, by using `assert_psd_matrix`.
36+
# every function that uses this TypeGuard, by using `assert_is_psd_matrix`.

src/torchjd/aggregation/_flattening.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from torch import Tensor
22

3-
from torchjd._linalg import PSDGeneralizedMatrix
3+
from torchjd._linalg import PSDTensor
44
from torchjd.aggregation._weighting_bases import GeneralizedWeighting, Weighting
55
from torchjd.autogram._gramian_utils import flatten
66

@@ -24,7 +24,7 @@ def __init__(self, weighting: Weighting):
2424
super().__init__()
2525
self.weighting = weighting
2626

27-
def forward(self, generalized_gramian: PSDGeneralizedMatrix) -> Tensor:
27+
def forward(self, generalized_gramian: PSDTensor) -> Tensor:
2828
k = generalized_gramian.ndim // 2
2929
shape = generalized_gramian.shape[:k]
3030
square_gramian = flatten(generalized_gramian)

src/torchjd/aggregation/_weighting_bases.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from torch import Tensor, nn
88

9-
from torchjd._linalg import PSDGeneralizedMatrix, is_psd_generalized_matrix
9+
from torchjd._linalg import PSDTensor, is_psd_tensor
1010

1111
_T = TypeVar("_T", contravariant=True)
1212
_FnInputT = TypeVar("_FnInputT")
@@ -65,7 +65,7 @@ def __init__(self):
6565
super().__init__()
6666

6767
@abstractmethod
68-
def forward(self, generalized_gramian: PSDGeneralizedMatrix) -> Tensor:
68+
def forward(self, generalized_gramian: PSDTensor) -> Tensor:
6969
"""Computes the vector of weights from the input generalized Gramian."""
7070

7171
def __call__(self, generalized_gramian: Tensor) -> Tensor:
@@ -74,5 +74,5 @@ def __call__(self, generalized_gramian: Tensor) -> Tensor:
7474
hooks.
7575
"""
7676

77-
assert is_psd_generalized_matrix(generalized_gramian)
77+
assert is_psd_tensor(generalized_gramian)
7878
return super().__call__(generalized_gramian)

src/torchjd/autogram/_gramian_utils.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33

44
from torch import Tensor
55

6-
from torchjd._linalg import PSDGeneralizedMatrix, PSDMatrix
6+
from torchjd._linalg import PSDMatrix, PSDTensor
77

88

9-
def flatten(gramian: PSDGeneralizedMatrix) -> PSDMatrix:
9+
def flatten(gramian: PSDTensor) -> PSDMatrix:
1010
"""
1111
Flattens a generalized Gramian into a square matrix. The first half of the dimensions are
1212
flattened into the first dimension, and the second half are flattened into the second.
@@ -24,7 +24,7 @@ def flatten(gramian: PSDGeneralizedMatrix) -> PSDMatrix:
2424
return cast(PSDMatrix, square_gramian)
2525

2626

27-
def reshape(gramian: PSDGeneralizedMatrix, half_shape: list[int]) -> PSDGeneralizedMatrix:
27+
def reshape(gramian: PSDTensor, half_shape: list[int]) -> PSDTensor:
2828
"""
2929
Reshapes a Gramian to a provided shape. The reshape of the first half of the target dimensions
3030
must be done from the left, while the reshape of the second half must be done from the right.
@@ -42,7 +42,7 @@ def reshape(gramian: PSDGeneralizedMatrix, half_shape: list[int]) -> PSDGenerali
4242
# [24, 24] -(movedim)-> [24, 24] -(reshape)-> [4, 3, 2, 4, 3, 2] -(movedim)-> [4, 3, 2, 2, 3, 4]
4343

4444
result = _revert_last_dims(_revert_last_dims(gramian).reshape(half_shape + half_shape))
45-
return cast(PSDGeneralizedMatrix, result)
45+
return cast(PSDTensor, result)
4646

4747

4848
def _revert_last_dims(t: Tensor) -> Tensor:
@@ -53,9 +53,7 @@ def _revert_last_dims(t: Tensor) -> Tensor:
5353
return t.movedim(last_dims, last_dims[::-1])
5454

5555

56-
def movedim(
57-
gramian: PSDGeneralizedMatrix, half_source: list[int], half_destination: list[int]
58-
) -> PSDGeneralizedMatrix:
56+
def movedim(gramian: PSDTensor, half_source: list[int], half_destination: list[int]) -> PSDTensor:
5957
"""
6058
Moves the dimensions of a Gramian from some source dimensions to destination dimensions. This
6159
must be done simultaneously on the first half of the dimensions and on the second half of the
@@ -86,4 +84,4 @@ def movedim(
8684
source = half_source_ + [last_dim - i for i in half_source_]
8785
destination = half_destination_ + [last_dim - i for i in half_destination_]
8886
moved_gramian = gramian.movedim(source, destination)
89-
return cast(PSDGeneralizedMatrix, moved_gramian)
87+
return cast(PSDTensor, moved_gramian)

tests/unit/autogram/test_gramian_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from pytest import mark
22
from torch.testing import assert_close
3-
from utils.asserts import assert_is_psd_generalized_matrix, assert_is_psd_matrix
3+
from utils.asserts import assert_is_psd_matrix, assert_is_psd_tensor
44
from utils.tensors import randn_
55

66
from torchjd._linalg import compute_gramian, is_psd_matrix
@@ -59,7 +59,7 @@ def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]):
5959
matrix = randn_(original_shape + [2])
6060
gramian = compute_gramian(matrix, 1)
6161
reshaped_gramian = reshape(gramian, target_shape)
62-
assert_is_psd_generalized_matrix(reshaped_gramian, atol=1e-04, rtol=0.0)
62+
assert_is_psd_tensor(reshaped_gramian, atol=1e-04, rtol=0.0)
6363

6464

6565
@mark.parametrize(
@@ -150,4 +150,4 @@ def test_movedim_yields_psd(shape: list[int], source: list[int], destination: li
150150
matrix = randn_(shape + [2])
151151
gramian = compute_gramian(matrix, 1)
152152
moveddim_gramian = movedim(gramian, source, destination)
153-
assert_is_psd_generalized_matrix(moveddim_gramian)
153+
assert_is_psd_tensor(moveddim_gramian)

tests/utils/asserts.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torch import Tensor
33
from torch.testing import assert_close
44

5-
from torchjd._linalg import is_psd_generalized_matrix, is_psd_matrix
5+
from torchjd._linalg import is_psd_matrix, is_psd_tensor
66
from torchjd.autogram._gramian_utils import flatten
77
from torchjd.autojac._accumulation import is_tensor_with_jac
88

@@ -44,7 +44,7 @@ def assert_is_psd_matrix(matrix: Tensor, **kwargs) -> None:
4444
assert_close(eig_vals, expected_eig_vals, **kwargs)
4545

4646

47-
def assert_is_psd_generalized_matrix(t: Tensor, **kwargs) -> None:
48-
assert is_psd_generalized_matrix(t)
47+
def assert_is_psd_tensor(t: Tensor, **kwargs) -> None:
48+
assert is_psd_tensor(t)
4949
matrix = flatten(t)
5050
assert_is_psd_matrix(matrix, **kwargs)

tests/utils/forward_backwards.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from utils.architectures import get_in_out_shapes
99
from utils.contexts import fork_rng
1010

11-
from torchjd._linalg import PSDGeneralizedMatrix
11+
from torchjd._linalg import PSDTensor
1212
from torchjd.aggregation import Aggregator, Weighting
1313
from torchjd.autogram import Engine
1414
from torchjd.autojac import backward
@@ -117,7 +117,7 @@ def reshape_raw_losses(raw_losses: Tensor) -> Tensor:
117117

118118
def compute_gramian_with_autograd(
119119
output: Tensor, params: list[nn.Parameter], retain_graph: bool = False
120-
) -> PSDGeneralizedMatrix:
120+
) -> PSDTensor:
121121
"""
122122
Computes the Gramian of the Jacobian of the outputs with respect to the params using vmapped
123123
calls to the autograd engine.

0 commit comments

Comments
 (0)