Skip to content

Commit d60e9fa

Browse files
committed
improve asserts
1 parent 5eafa74 commit d60e9fa

3 files changed

Lines changed: 21 additions & 18 deletions

File tree

tests/unit/autogram/test_gramian_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from pytest import mark
22
from torch.testing import assert_close
3-
from utils.asserts import assert_psd_generalized_matrix, assert_psd_matrix
3+
from utils.asserts import assert_is_psd_generalized_matrix, assert_is_psd_matrix
44
from utils.forward_backwards import compute_gramian
55
from utils.tensors import randn_
66

@@ -60,7 +60,7 @@ def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]):
6060
matrix = randn_(original_shape + [2])
6161
gramian = compute_gramian(matrix)
6262
reshaped_gramian = reshape(gramian, target_shape)
63-
assert_psd_generalized_matrix(reshaped_gramian, atol=1e-04, rtol=0.0)
63+
assert_is_psd_generalized_matrix(reshaped_gramian, atol=1e-04, rtol=0.0)
6464

6565

6666
@mark.parametrize(
@@ -94,7 +94,7 @@ def test_flatten_yields_psd(shape: list[int]):
9494
matrix = randn_(shape + [2])
9595
gramian = compute_gramian(matrix)
9696
flattened_gramian = flatten(gramian)
97-
assert_psd_matrix(flattened_gramian, atol=1e-04, rtol=0.0)
97+
assert_is_psd_matrix(flattened_gramian, atol=1e-04, rtol=0.0)
9898

9999

100100
@mark.parametrize(
@@ -151,4 +151,4 @@ def test_movedim_yields_psd(shape: list[int], source: list[int], destination: li
151151
matrix = randn_(shape + [2])
152152
gramian = compute_gramian(matrix)
153153
moveddim_gramian = movedim(gramian, source, destination)
154-
assert_psd_generalized_matrix(moveddim_gramian)
154+
assert_is_psd_generalized_matrix(moveddim_gramian)

tests/unit/linalg/test_gramian.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pytest import mark
2-
from utils.asserts import assert_psd_matrix
2+
from utils.asserts import assert_is_psd_matrix
33
from utils.tensors import randn_
44

55
from torchjd._linalg import compute_gramian, is_generalized_matrix, is_matrix
@@ -24,7 +24,7 @@ def test_gramian_is_psd(shape: list[int]):
2424
matrix = randn_(shape)
2525
assert is_generalized_matrix(matrix)
2626
gramian = compute_gramian(matrix)
27-
assert_psd_matrix(gramian)
27+
assert_is_psd_matrix(gramian)
2828

2929

3030
@mark.parametrize(
@@ -42,7 +42,7 @@ def test_normalize_yields_psd(shape: list[int]):
4242
assert is_matrix(matrix)
4343
gramian = compute_gramian(matrix)
4444
normalized_gramian = normalize(gramian, 1e-05)
45-
assert_psd_matrix(normalized_gramian)
45+
assert_is_psd_matrix(normalized_gramian)
4646

4747

4848
@mark.parametrize(
@@ -60,4 +60,4 @@ def test_regularize_yields_psd(shape: list[int]):
6060
assert is_matrix(matrix)
6161
gramian = compute_gramian(matrix)
6262
normalized_gramian = regularize(gramian, 1e-05)
63-
assert_psd_matrix(normalized_gramian)
63+
assert_is_psd_matrix(normalized_gramian)

tests/utils/asserts.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,41 @@
11
import torch
2+
from torch import Tensor
23
from torch.testing import assert_close
34

4-
from torchjd._linalg import PSDGeneralizedMatrix, PSDMatrix
5+
from torchjd._linalg import is_psd_generalized_matrix, is_psd_matrix
56
from torchjd.autogram._gramian_utils import flatten
67
from torchjd.autojac._accumulation import is_tensor_with_jac
78

89

9-
def assert_has_jac(t: torch.Tensor) -> None:
10+
def assert_has_jac(t: Tensor) -> None:
1011
assert is_tensor_with_jac(t)
1112
assert t.jac is not None and t.jac.shape[1:] == t.shape
1213

1314

14-
def assert_has_no_jac(t: torch.Tensor) -> None:
15+
def assert_has_no_jac(t: Tensor) -> None:
1516
assert not is_tensor_with_jac(t)
1617

1718

18-
def assert_jac_close(t: torch.Tensor, expected_jac: torch.Tensor, **kwargs) -> None:
19+
def assert_jac_close(t: Tensor, expected_jac: Tensor, **kwargs) -> None:
1920
assert is_tensor_with_jac(t)
2021
assert_close(t.jac, expected_jac, **kwargs)
2122

2223

23-
def assert_has_grad(t: torch.Tensor) -> None:
24+
def assert_has_grad(t: Tensor) -> None:
2425
assert (t.grad is not None) and (t.shape == t.grad.shape)
2526

2627

27-
def assert_has_no_grad(t: torch.Tensor) -> None:
28+
def assert_has_no_grad(t: Tensor) -> None:
2829
assert t.grad is None
2930

3031

31-
def assert_grad_close(t: torch.Tensor, expected_grad: torch.Tensor, **kwargs) -> None:
32+
def assert_grad_close(t: Tensor, expected_grad: Tensor, **kwargs) -> None:
3233
assert t.grad is not None
3334
assert_close(t.grad, expected_grad, **kwargs)
3435

3536

36-
def assert_psd_matrix(matrix: PSDMatrix, **kwargs) -> None:
37+
def assert_is_psd_matrix(matrix: Tensor, **kwargs) -> None:
38+
assert is_psd_matrix(matrix)
3739
assert_close(matrix, matrix.mH, **kwargs)
3840

3941
eig_vals = torch.linalg.eigvalsh(matrix)
@@ -42,6 +44,7 @@ def assert_psd_matrix(matrix: PSDMatrix, **kwargs) -> None:
4244
assert_close(eig_vals, expected_eig_vals, **kwargs)
4345

4446

45-
def assert_psd_generalized_matrix(t: PSDGeneralizedMatrix, **kwargs) -> None:
47+
def assert_is_psd_generalized_matrix(t: Tensor, **kwargs) -> None:
48+
assert is_psd_generalized_matrix(t)
4649
matrix = flatten(t)
47-
assert_psd_matrix(matrix, **kwargs)
50+
assert_is_psd_matrix(matrix, **kwargs)

0 commit comments

Comments
 (0)