11import torch
2+ from torch import Tensor
23from torch .testing import assert_close
34
4- from torchjd ._linalg import PSDGeneralizedMatrix , PSDMatrix
5+ from torchjd ._linalg import is_psd_generalized_matrix , is_psd_matrix
56from torchjd .autogram ._gramian_utils import flatten
67from torchjd .autojac ._accumulation import is_tensor_with_jac
78
89
9- def assert_has_jac (t : torch . Tensor ) -> None :
10+ def assert_has_jac (t : Tensor ) -> None :
1011 assert is_tensor_with_jac (t )
1112 assert t .jac is not None and t .jac .shape [1 :] == t .shape
1213
1314
14- def assert_has_no_jac (t : torch . Tensor ) -> None :
15+ def assert_has_no_jac (t : Tensor ) -> None :
1516 assert not is_tensor_with_jac (t )
1617
1718
18- def assert_jac_close (t : torch . Tensor , expected_jac : torch . Tensor , ** kwargs ) -> None :
19+ def assert_jac_close (t : Tensor , expected_jac : Tensor , ** kwargs ) -> None :
1920 assert is_tensor_with_jac (t )
2021 assert_close (t .jac , expected_jac , ** kwargs )
2122
2223
23- def assert_has_grad (t : torch . Tensor ) -> None :
24+ def assert_has_grad (t : Tensor ) -> None :
2425 assert (t .grad is not None ) and (t .shape == t .grad .shape )
2526
2627
27- def assert_has_no_grad (t : torch . Tensor ) -> None :
28+ def assert_has_no_grad (t : Tensor ) -> None :
2829 assert t .grad is None
2930
3031
31- def assert_grad_close (t : torch . Tensor , expected_grad : torch . Tensor , ** kwargs ) -> None :
32+ def assert_grad_close (t : Tensor , expected_grad : Tensor , ** kwargs ) -> None :
3233 assert t .grad is not None
3334 assert_close (t .grad , expected_grad , ** kwargs )
3435
3536
36- def assert_psd_matrix (matrix : PSDMatrix , ** kwargs ) -> None :
37+ def assert_is_psd_matrix (matrix : Tensor , ** kwargs ) -> None :
38+ assert is_psd_matrix (matrix )
3739 assert_close (matrix , matrix .mH , ** kwargs )
3840
3941 eig_vals = torch .linalg .eigvalsh (matrix )
@@ -42,6 +44,7 @@ def assert_psd_matrix(matrix: PSDMatrix, **kwargs) -> None:
4244 assert_close (eig_vals , expected_eig_vals , ** kwargs )
4345
4446
45- def assert_psd_generalized_matrix (t : PSDGeneralizedMatrix , ** kwargs ) -> None :
47+ def assert_is_psd_generalized_matrix (t : Tensor , ** kwargs ) -> None :
48+ assert is_psd_generalized_matrix (t )
4649 matrix = flatten (t )
47- assert_psd_matrix (matrix , ** kwargs )
50+ assert_is_psd_matrix (matrix , ** kwargs )
0 commit comments