Skip to content

Commit 2d7bf7f

Browse files
refactor: Improve PSD typing (#522)
* Add PSDTensor * Remove GeneralizedMatrix * Use classes for Matrix, PSDMatrix and PSDTensor instead of type annotations * Use casting in compute_gramian, normalize and regularize * Add typeguard functions is_matrix, is_psd_tensor and is_psd_matrix * Move normalize and regularize to _linalg * Move _check_is_matrix to `Aggregator.__call__` * Improve internal type hints * Rename reshape_gramian to reshape and movedim_gramian to movedim * Add _gramian_utils.flatten * Rename a few tests * Add some parametrizations to test_gramian_is_psd * Add test_reshape_yields_psd, test_flatten_yields_matrix, test_flatten_yields_psd, test_movedim_yields_psd, test_normalize_yields_psd and test_regularize_yields_psd * Add assert_is_psd_tensor --------- Co-authored-by: Valérian Rey <valerian.rey@gmail.com>
1 parent 711e778 commit 2d7bf7f

30 files changed

+383
-158
lines changed

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,9 @@ full = [
107107

108108
[tool.pytest.ini_options]
109109
xfail_strict = true
110+
111+
[tool.coverage.report]
112+
exclude_lines = [
113+
"pragma: not covered",
114+
"@overload",
115+
]

src/torchjd/_linalg/__init__.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1-
from .gramian import compute_gramian
2-
from .matrix import Matrix, PSDMatrix
1+
from ._gramian import compute_gramian, normalize, regularize
2+
from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor
33

4-
__all__ = ["compute_gramian", "Matrix", "PSDMatrix"]
4+
__all__ = [
5+
"compute_gramian",
6+
"normalize",
7+
"regularize",
8+
"Matrix",
9+
"PSDMatrix",
10+
"PSDTensor",
11+
"is_matrix",
12+
"is_psd_matrix",
13+
"is_psd_tensor",
14+
]

src/torchjd/_linalg/_gramian.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from typing import Literal, cast, overload
2+
3+
import torch
4+
from torch import Tensor
5+
6+
from ._matrix import Matrix, PSDMatrix, PSDTensor
7+
8+
9+
@overload
def compute_gramian(t: Tensor) -> PSDMatrix:
    pass


@overload
def compute_gramian(t: Tensor, contracted_dims: Literal[-1]) -> PSDMatrix:
    pass


@overload
def compute_gramian(t: Matrix, contracted_dims: Literal[1]) -> PSDMatrix:
    pass


def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
    """
    Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of the input.

    `contracted_dims` specifies the number of trailing dimensions to contract. If negative,
    it indicates the number of leading dimensions to preserve (e.g., ``-1`` preserves the
    first dimension).
    """
    if contracted_dims < 0:
        # A negative value counts preserved leading dimensions instead.
        contracted_dims += t.ndim
    n_preserved = t.ndim - contracted_dims
    # Build the "transpose": preserved leading dims are moved to the end, in reversed order,
    # so that the contracted dims of `t` line up with the leading dims of `transposed`.
    source = list(range(n_preserved))
    destination = [t.ndim - 1 - i for i in range(n_preserved)]
    transposed = t.movedim(source, destination)
    gramian = torch.tensordot(t, transposed, dims=contracted_dims)
    return cast(PSDTensor, gramian)
39+
40+
41+
def normalize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
    """
    Normalizes the gramian `G=AA^T` with respect to the Frobenius norm of `A`.

    If `G=A A^T`, then the Frobenius norm of `A` is the square root of the trace of `G`, i.e., the
    sqrt of the sum of the diagonal elements. The gramian of the (Frobenius) normalization of `A` is
    therefore `G` divided by the sum of its diagonal elements.
    """
    # trace(G) == squared Frobenius norm of A.
    trace = gramian.diagonal().sum()
    # Near-zero norms would blow up the division, so fall back to an all-zero gramian.
    result = torch.zeros_like(gramian) if trace < eps else gramian / trace
    return cast(PSDMatrix, result)
55+
56+
57+
def regularize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
    """
    Adds a regularization term to the gramian to enforce positive definiteness.

    Because of numerical errors, `gramian` might have slightly negative eigenvalue(s). Adding a
    regularization term which is a small proportion of the identity matrix ensures that the gramian
    is positive definite.
    """
    identity = torch.eye(gramian.shape[0], dtype=gramian.dtype, device=gramian.device)
    regularized = gramian + eps * identity
    return cast(PSDMatrix, regularized)

src/torchjd/_linalg/_matrix.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from typing import TypeGuard
2+
3+
from torch import Tensor
4+
5+
# Note: we're using classes and inheritance instead of NewType because it's possible to have
# multiple inheritance but there is no type intersection. However, these classes should never be
# instantiated: they're only used for static type checking.
8+
9+
10+
class Matrix(Tensor):
    """A tensor that is known to have exactly two dimensions."""
12+
13+
14+
class PSDTensor(Tensor):
    """
    Tensor encoding a quadratic form. Its leading half of dimensions is the mirror image of its
    trailing half (e.g. shape=[4, 3, 3, 4]), and reshaping it into a matrix should yield a
    positive semi-definite result.
    """
20+
21+
22+
class PSDMatrix(PSDTensor, Matrix):
    """A two-dimensional positive semi-definite matrix."""
24+
25+
26+
def is_matrix(t: Tensor) -> TypeGuard[Matrix]:
    """Narrows `t` to `Matrix` when it has exactly two dimensions."""
    return len(t.shape) == 2
28+
29+
30+
def is_psd_tensor(t: Tensor) -> TypeGuard[PSDTensor]:
    """Narrows `t` to `PSDTensor` when its shape has the required mirror symmetry."""
    # We do not check that t is PSD as it is expensive, but this must be checked in the tests of
    # every function that uses this TypeGuard by using `assert_is_psd_tensor`.
    if t.ndim % 2 != 0:
        return False
    half_dim = t.ndim // 2
    # The leading half of the shape must equal the trailing half read backwards.
    return t.shape[:half_dim] == t.shape[: half_dim - 1 : -1]
35+
36+
37+
def is_psd_matrix(t: Tensor) -> TypeGuard[PSDMatrix]:
    """Narrows `t` to `PSDMatrix` when it is a square, two-dimensional tensor."""
    # We do not check that t is PSD as it is expensive, but this must be checked in the tests of
    # every function that uses this TypeGuard, by using `assert_is_psd_matrix`.
    shape = t.shape
    return len(shape) == 2 and shape[0] == shape[1]

src/torchjd/_linalg/gramian.py

Lines changed: 0 additions & 9 deletions
This file was deleted.

src/torchjd/_linalg/matrix.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

src/torchjd/aggregation/_aggregator_bases.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from torch import Tensor, nn
44

5-
from torchjd._linalg import Matrix, PSDMatrix, compute_gramian
5+
from torchjd._linalg import Matrix, PSDMatrix, compute_gramian, is_matrix
66

77
from ._weighting_bases import Weighting
88

@@ -18,20 +18,19 @@ def __init__(self):
1818

1919
@staticmethod
2020
def _check_is_matrix(matrix: Tensor) -> None:
21-
if len(matrix.shape) != 2:
21+
if not is_matrix(matrix):
2222
raise ValueError(
2323
"Parameter `matrix` should be a tensor of dimension 2. Found `matrix.shape = "
2424
f"{matrix.shape}`."
2525
)
2626

2727
@abstractmethod
28-
def forward(self, matrix: Tensor) -> Tensor:
28+
def forward(self, matrix: Matrix) -> Tensor:
2929
"""Computes the aggregation from the input matrix."""
3030

31-
# Override to make type hints and documentation more specific
3231
def __call__(self, matrix: Tensor) -> Tensor:
3332
"""Computes the aggregation from the input matrix and applies all registered hooks."""
34-
33+
Aggregator._check_is_matrix(matrix)
3534
return super().__call__(matrix)
3635

3736
def __repr__(self) -> str:
@@ -54,7 +53,7 @@ def __init__(self, weighting: Weighting[Matrix]):
5453
self.weighting = weighting
5554

5655
@staticmethod
57-
def combine(matrix: Tensor, weights: Tensor) -> Tensor:
56+
def combine(matrix: Matrix, weights: Tensor) -> Tensor:
5857
"""
5958
Aggregates a matrix by making a linear combination of its rows, using the provided vector of
6059
weights.
@@ -63,8 +62,7 @@ def combine(matrix: Tensor, weights: Tensor) -> Tensor:
6362
vector = weights @ matrix
6463
return vector
6564

66-
def forward(self, matrix: Tensor) -> Tensor:
67-
self._check_is_matrix(matrix)
65+
def forward(self, matrix: Matrix) -> Tensor:
6866
weights = self.weighting(matrix)
6967
vector = self.combine(matrix, weights)
7068
return vector

src/torchjd/aggregation/_cagrad.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
import torch
1313
from torch import Tensor
1414

15+
from torchjd._linalg import normalize
16+
1517
from ._aggregator_bases import GramianWeightedAggregator
16-
from ._utils.gramian import normalize
1718
from ._utils.non_differentiable import raise_non_differentiable_error
1819

1920

src/torchjd/aggregation/_config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
import torch
2929
from torch import Tensor
3030

31+
from torchjd._linalg import Matrix
32+
3133
from ._aggregator_bases import Aggregator
3234
from ._sum import SumWeighting
3335
from ._utils.non_differentiable import raise_non_differentiable_error
@@ -56,7 +58,7 @@ def __init__(self, pref_vector: Tensor | None = None):
5658
# This prevents computing gradients that can be very wrong.
5759
self.register_full_backward_pre_hook(raise_non_differentiable_error)
5860

59-
def forward(self, matrix: Tensor) -> Tensor:
61+
def forward(self, matrix: Matrix) -> Tensor:
6062
weights = self.weighting(matrix)
6163
units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0)
6264
best_direction = torch.linalg.pinv(units) @ weights

src/torchjd/aggregation/_dualproj.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
from torch import Tensor
44

5-
from torchjd._linalg import PSDMatrix
5+
from torchjd._linalg import PSDMatrix, normalize, regularize
66

77
from ._aggregator_bases import GramianWeightedAggregator
88
from ._mean import MeanWeighting
99
from ._utils.dual_cone import project_weights
10-
from ._utils.gramian import normalize, regularize
1110
from ._utils.non_differentiable import raise_non_differentiable_error
1211
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
1312
from ._weighting_bases import Weighting

0 commit comments

Comments
 (0)