-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathasserts.py
More file actions
51 lines (31 loc) · 1.36 KB
/
asserts.py
File metadata and controls
51 lines (31 loc) · 1.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from typing import Any
import torch
from torch import Tensor
from torch.testing import assert_close
from torchjd._linalg import flatten, is_psd_matrix, is_psd_tensor
from torchjd.autojac._accumulation import is_tensor_with_jac
def assert_has_jac(t: Tensor) -> None:
    """Assert that ``t`` carries a Jacobian whose trailing shape matches ``t``.

    The check passes when ``is_tensor_with_jac(t)`` is true, ``t.jac`` is not
    ``None``, and every dimension of ``t.jac`` after the first equals the
    corresponding dimension of ``t``.

    Raises:
        AssertionError: If any of the conditions above does not hold.
    """
    assert is_tensor_with_jac(t)

    jac = t.jac
    assert jac is not None
    # The first dimension of the jac is skipped; only the remaining ones are
    # compared against the shape of the tensor itself.
    assert jac.shape[1:] == t.shape
def assert_has_no_jac(t: Tensor) -> None:
    """Assert that ``is_tensor_with_jac`` reports no Jacobian on ``t``.

    Raises:
        AssertionError: If ``is_tensor_with_jac(t)`` is truthy.
    """
    has_jac = is_tensor_with_jac(t)
    assert not has_jac
def assert_jac_close(t: Tensor, expected_jac: Tensor, **kwargs: Any) -> None:
    """Assert that the Jacobian stored on ``t`` is close to ``expected_jac``.

    Args:
        t: Tensor that must satisfy ``is_tensor_with_jac``.
        expected_jac: Reference value that ``t.jac`` is compared against.
        **kwargs: Forwarded unchanged to ``torch.testing.assert_close``.

    Raises:
        AssertionError: If ``t`` has no jac according to
            ``is_tensor_with_jac``, or if ``assert_close`` rejects the
            comparison.
    """
    assert is_tensor_with_jac(t)

    actual_jac = t.jac
    assert_close(actual_jac, expected_jac, **kwargs)
def assert_has_grad(t: Tensor) -> None:
    """Assert that ``t.grad`` is populated and has the same shape as ``t``.

    Raises:
        AssertionError: If ``t.grad`` is ``None`` or its shape differs from
            ``t.shape``.
    """
    grad = t.grad
    assert grad is not None
    assert t.shape == grad.shape
def assert_has_no_grad(t: Tensor) -> None:
    """Assert that no gradient is stored on ``t`` (``t.grad`` is ``None``).

    Raises:
        AssertionError: If ``t.grad`` is anything other than ``None``.
    """
    grad = t.grad
    assert grad is None
def assert_grad_close(t: Tensor, expected_grad: Tensor, **kwargs: Any) -> None:
    """Assert that the gradient stored on ``t`` is close to ``expected_grad``.

    Args:
        t: Tensor whose ``grad`` attribute is checked.
        expected_grad: Reference value that ``t.grad`` is compared against.
        **kwargs: Forwarded unchanged to ``torch.testing.assert_close``.

    Raises:
        AssertionError: If ``t.grad`` is ``None`` or if ``assert_close``
            rejects the comparison.
    """
    actual_grad = t.grad
    assert actual_grad is not None
    assert_close(actual_grad, expected_grad, **kwargs)
def assert_is_psd_matrix(matrix: Tensor, **kwargs: Any) -> None:
    """Assert that ``matrix`` is positive semi-definite, up to tolerance.

    Three checks are run in order:

    1. ``is_psd_matrix(matrix)`` must be truthy.
    2. ``matrix`` must be close to its conjugate transpose (``matrix.mH``).
    3. The eigenvalues returned by ``torch.linalg.eigvalsh`` must be close to
       the same eigenvalues clamped to be at least ``0.0``.

    Args:
        matrix: The tensor to check.
        **kwargs: Forwarded unchanged to both ``torch.testing.assert_close``
            calls.

    Raises:
        AssertionError: If any of the three checks fails.
    """
    assert is_psd_matrix(matrix)

    # Hermitian check: the matrix must match its own conjugate transpose.
    conjugate_transpose = matrix.mH
    assert_close(matrix, conjugate_transpose, **kwargs)

    # Non-negativity check: clamping at zero must leave the spectrum unchanged
    # within the tolerances carried by kwargs.
    spectrum = torch.linalg.eigvalsh(matrix)
    non_negative_spectrum = torch.clamp(spectrum, min=0.0)
    assert_close(spectrum, non_negative_spectrum, **kwargs)
def assert_is_psd_tensor(t: Tensor, **kwargs: Any) -> None:
    """Assert that ``t`` is a PSD tensor and that its flattened form is a PSD matrix.

    ``is_psd_tensor(t)`` must be truthy; the result of ``flatten(t)`` is then
    handed to ``assert_is_psd_matrix``.

    Args:
        t: The tensor to check.
        **kwargs: Forwarded unchanged to ``assert_is_psd_matrix``.

    Raises:
        AssertionError: If ``is_psd_tensor(t)`` is falsy or if the flattened
            matrix fails ``assert_is_psd_matrix``.
    """
    assert is_psd_tensor(t)
    assert_is_psd_matrix(flatten(t), **kwargs)