Skip to content

Commit 5eafa74

Browse files
committed
Improve typing of forward_backward.compute_gramian
1 parent f2d0d1b commit 5eafa74

2 files changed

Lines changed: 4 additions & 9 deletions

File tree

tests/unit/autogram/test_gramian_utils.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from utils.forward_backwards import compute_gramian
55
from utils.tensors import randn_
66

7-
from torchjd._linalg import is_psd_generalized_matrix, is_psd_matrix
7+
from torchjd._linalg import is_psd_matrix
88
from torchjd.autogram._gramian_utils import flatten, movedim, reshape
99

1010

@@ -36,7 +36,6 @@ def test_reshape_equivarience(original_shape: list[int], target_shape: list[int]
3636
original_gramian = compute_gramian(original_matrix)
3737
target_gramian = compute_gramian(target_matrix)
3838

39-
assert is_psd_generalized_matrix(original_gramian)
4039
reshaped_gramian = reshape(original_gramian, target_shape)
4140

4241
assert_close(reshaped_gramian, target_gramian)
@@ -60,7 +59,6 @@ def test_reshape_equivarience(original_shape: list[int], target_shape: list[int]
6059
def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]):
6160
matrix = randn_(original_shape + [2])
6261
gramian = compute_gramian(matrix)
63-
assert is_psd_generalized_matrix(gramian)
6462
reshaped_gramian = reshape(gramian, target_shape)
6563
assert_psd_generalized_matrix(reshaped_gramian, atol=1e-04, rtol=0.0)
6664

@@ -78,7 +76,6 @@ def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]):
7876
def test_flatten_yields_matrix(shape: list[int]):
7977
matrix = randn_(shape + [2])
8078
gramian = compute_gramian(matrix)
81-
assert is_psd_generalized_matrix(gramian)
8279
flattened_gramian = flatten(gramian)
8380
assert is_psd_matrix(flattened_gramian)
8481

@@ -96,7 +93,6 @@ def test_flatten_yields_matrix(shape: list[int]):
9693
def test_flatten_yields_psd(shape: list[int]):
9794
matrix = randn_(shape + [2])
9895
gramian = compute_gramian(matrix)
99-
assert is_psd_generalized_matrix(gramian)
10096
flattened_gramian = flatten(gramian)
10197
assert_psd_matrix(flattened_gramian, atol=1e-04, rtol=0.0)
10298

@@ -128,7 +124,6 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination:
128124
original_gramian = compute_gramian(original_matrix)
129125
target_gramian = compute_gramian(target_matrix)
130126

131-
assert is_psd_generalized_matrix(original_gramian)
132127
moveddim_gramian = movedim(original_gramian, source, destination)
133128

134129
assert_close(moveddim_gramian, target_gramian)
@@ -155,6 +150,5 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination:
155150
def test_movedim_yields_psd(shape: list[int], source: list[int], destination: list[int]):
156151
matrix = randn_(shape + [2])
157152
gramian = compute_gramian(matrix)
158-
assert is_psd_generalized_matrix(gramian)
159153
moveddim_gramian = movedim(gramian, source, destination)
160154
assert_psd_generalized_matrix(moveddim_gramian)

tests/utils/forward_backwards.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from utils.architectures import get_in_out_shapes
99
from utils.contexts import fork_rng
1010

11+
from torchjd._linalg import PSDGeneralizedMatrix
1112
from torchjd.aggregation import Aggregator, Weighting
1213
from torchjd.autogram import Engine
1314
from torchjd.autojac import backward
@@ -116,7 +117,7 @@ def reshape_raw_losses(raw_losses: Tensor) -> Tensor:
116117

117118
def compute_gramian_with_autograd(
118119
output: Tensor, params: list[nn.Parameter], retain_graph: bool = False
119-
) -> Tensor:
120+
) -> PSDGeneralizedMatrix:
120121
"""
121122
Computes the Gramian of the Jacobian of the outputs with respect to the params using vmapped
122123
calls to the autograd engine.
@@ -141,7 +142,7 @@ def get_vjp(grad_outputs: Tensor) -> list[Tensor]:
141142
return gramian
142143

143144

144-
def compute_gramian(matrix: Tensor) -> Tensor:
145+
def compute_gramian(matrix: Tensor) -> PSDGeneralizedMatrix:
145146
"""Contracts the last dimension of matrix to make it into a Gramian."""
146147

147148
indices = list(range(matrix.ndim))

0 commit comments

Comments
 (0)