11from pytest import mark
22from torch .testing import assert_close
33from utils .asserts import assert_is_psd_generalized_matrix , assert_is_psd_matrix
4- from utils .forward_backwards import compute_gramian
54from utils .tensors import randn_
65
7- from torchjd ._linalg import is_psd_matrix
6+ from torchjd ._linalg import compute_gramian , is_psd_matrix
87from torchjd .autogram ._gramian_utils import flatten , movedim , reshape
98
109
@@ -33,8 +32,8 @@ def test_reshape_equivarience(original_shape: list[int], target_shape: list[int]
3332 original_matrix = randn_ (original_shape + [2 ])
3433 target_matrix = original_matrix .reshape (target_shape + [2 ])
3534
36- original_gramian = compute_gramian (original_matrix )
37- target_gramian = compute_gramian (target_matrix )
35+ original_gramian = compute_gramian (original_matrix , 1 )
36+ target_gramian = compute_gramian (target_matrix , 1 )
3837
3938 reshaped_gramian = reshape (original_gramian , target_shape )
4039
@@ -58,7 +57,7 @@ def test_reshape_equivarience(original_shape: list[int], target_shape: list[int]
5857)
5958def test_reshape_yields_psd (original_shape : list [int ], target_shape : list [int ]):
6059 matrix = randn_ (original_shape + [2 ])
61- gramian = compute_gramian (matrix )
60+ gramian = compute_gramian (matrix , 1 )
6261 reshaped_gramian = reshape (gramian , target_shape )
6362 assert_is_psd_generalized_matrix (reshaped_gramian , atol = 1e-04 , rtol = 0.0 )
6463
@@ -75,7 +74,7 @@ def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]):
7574)
7675def test_flatten_yields_matrix (shape : list [int ]):
7776 matrix = randn_ (shape + [2 ])
78- gramian = compute_gramian (matrix )
77+ gramian = compute_gramian (matrix , 1 )
7978 flattened_gramian = flatten (gramian )
8079 assert is_psd_matrix (flattened_gramian )
8180
@@ -92,7 +91,7 @@ def test_flatten_yields_matrix(shape: list[int]):
9291)
9392def test_flatten_yields_psd (shape : list [int ]):
9493 matrix = randn_ (shape + [2 ])
95- gramian = compute_gramian (matrix )
94+ gramian = compute_gramian (matrix , 1 )
9695 flattened_gramian = flatten (gramian )
9796 assert_is_psd_matrix (flattened_gramian , atol = 1e-04 , rtol = 0.0 )
9897
@@ -121,8 +120,8 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination:
121120 original_matrix = randn_ (shape + [2 ])
122121 target_matrix = original_matrix .movedim (source , destination )
123122
124- original_gramian = compute_gramian (original_matrix )
125- target_gramian = compute_gramian (target_matrix )
123+ original_gramian = compute_gramian (original_matrix , 1 )
124+ target_gramian = compute_gramian (target_matrix , 1 )
126125
127126 moveddim_gramian = movedim (original_gramian , source , destination )
128127
@@ -149,6 +148,6 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination:
149148)
150149def test_movedim_yields_psd (shape : list [int ], source : list [int ], destination : list [int ]):
151150 matrix = randn_ (shape + [2 ])
152- gramian = compute_gramian (matrix )
151+ gramian = compute_gramian (matrix , 1 )
153152 moveddim_gramian = movedim (gramian , source , destination )
154153 assert_is_psd_generalized_matrix (moveddim_gramian )
0 commit comments