44from utils .forward_backwards import compute_gramian
55from utils .tensors import randn_
66
7- from torchjd ._linalg import is_psd_generalized_matrix , is_psd_matrix
7+ from torchjd ._linalg import is_psd_matrix
88from torchjd .autogram ._gramian_utils import flatten , movedim , reshape
99
1010
@@ -36,7 +36,6 @@ def test_reshape_equivarience(original_shape: list[int], target_shape: list[int]
3636 original_gramian = compute_gramian (original_matrix )
3737 target_gramian = compute_gramian (target_matrix )
3838
39- assert is_psd_generalized_matrix (original_gramian )
4039 reshaped_gramian = reshape (original_gramian , target_shape )
4140
4241 assert_close (reshaped_gramian , target_gramian )
@@ -60,7 +59,6 @@ def test_reshape_equivarience(original_shape: list[int], target_shape: list[int]
6059def test_reshape_yields_psd (original_shape : list [int ], target_shape : list [int ]):
6160 matrix = randn_ (original_shape + [2 ])
6261 gramian = compute_gramian (matrix )
63- assert is_psd_generalized_matrix (gramian )
6462 reshaped_gramian = reshape (gramian , target_shape )
6563 assert_psd_generalized_matrix (reshaped_gramian , atol = 1e-04 , rtol = 0.0 )
6664
@@ -78,7 +76,6 @@ def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]):
7876def test_flatten_yields_matrix (shape : list [int ]):
7977 matrix = randn_ (shape + [2 ])
8078 gramian = compute_gramian (matrix )
81- assert is_psd_generalized_matrix (gramian )
8279 flattened_gramian = flatten (gramian )
8380 assert is_psd_matrix (flattened_gramian )
8481
@@ -96,7 +93,6 @@ def test_flatten_yields_matrix(shape: list[int]):
9693def test_flatten_yields_psd (shape : list [int ]):
9794 matrix = randn_ (shape + [2 ])
9895 gramian = compute_gramian (matrix )
99- assert is_psd_generalized_matrix (gramian )
10096 flattened_gramian = flatten (gramian )
10197 assert_psd_matrix (flattened_gramian , atol = 1e-04 , rtol = 0.0 )
10298
@@ -128,7 +124,6 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination:
128124 original_gramian = compute_gramian (original_matrix )
129125 target_gramian = compute_gramian (target_matrix )
130126
131- assert is_psd_generalized_matrix (original_gramian )
132127 moveddim_gramian = movedim (original_gramian , source , destination )
133128
134129 assert_close (moveddim_gramian , target_gramian )
@@ -155,6 +150,5 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination:
155150def test_movedim_yields_psd (shape : list [int ], source : list [int ], destination : list [int ]):
156151 matrix = randn_ (shape + [2 ])
157152 gramian = compute_gramian (matrix )
158- assert is_psd_generalized_matrix (gramian )
159153 moveddim_gramian = movedim (gramian , source , destination )
160154 assert_psd_generalized_matrix (moveddim_gramian )
0 commit comments