Skip to content

Commit a55afca

Browse files
test(linalg): Add more tests for compute_gramian (#541)
Co-authored-by: Pierre Quinton <pierre.quinton@gmail.com>
1 parent 6eddc91 commit a55afca

File tree

1 file changed

+55
-1
lines changed

1 file changed

+55
-1
lines changed

tests/unit/linalg/test_gramian.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pytest import mark
2+
from torch.testing import assert_close
23
from utils.asserts import assert_is_psd_matrix
3-
from utils.tensors import randn_
4+
from utils.tensors import randn_, tensor_
45

56
from torchjd._linalg import compute_gramian, is_matrix, normalize, regularize
67

@@ -25,6 +26,59 @@ def test_gramian_is_psd(shape: list[int]):
2526
assert_is_psd_matrix(gramian)
2627

2728

29+
def test_compute_gramian_scalar_input_0():
30+
t = tensor_(5.0)
31+
gramian = compute_gramian(t, contracted_dims=0)
32+
expected = tensor_(25.0)
33+
34+
assert_close(gramian, expected)
35+
36+
37+
def test_compute_gramian_vector_input_0():
38+
t = tensor_([2.0, 3.0])
39+
gramian = compute_gramian(t, contracted_dims=0)
40+
expected = tensor_([[4.0, 6.0], [6.0, 9.0]])
41+
42+
assert_close(gramian, expected)
43+
44+
45+
def test_compute_gramian_vector_input_1():
46+
t = tensor_([2.0, 3.0])
47+
gramian = compute_gramian(t, contracted_dims=1)
48+
expected = tensor_(13.0)
49+
50+
assert_close(gramian, expected)
51+
52+
53+
def test_compute_gramian_matrix_input_0():
54+
t = tensor_([[1.0, 2.0], [3.0, 4.0]])
55+
gramian = compute_gramian(t, contracted_dims=0)
56+
expected = tensor_(
57+
[
58+
[[[1.0, 3.0], [2.0, 4.0]], [[2.0, 6.0], [4.0, 8.0]]],
59+
[[[3.0, 9.0], [6.0, 12.0]], [[4.0, 12.0], [8.0, 16.0]]],
60+
]
61+
)
62+
63+
assert_close(gramian, expected)
64+
65+
66+
def test_compute_gramian_matrix_input_1():
67+
t = tensor_([[1.0, 2.0], [3.0, 4.0]])
68+
gramian = compute_gramian(t, contracted_dims=1)
69+
expected = tensor_([[5.0, 11.0], [11.0, 25.0]])
70+
71+
assert_close(gramian, expected)
72+
73+
74+
def test_compute_gramian_matrix_input_2():
75+
t = tensor_([[1.0, 2.0], [3.0, 4.0]])
76+
gramian = compute_gramian(t, contracted_dims=2)
77+
expected = tensor_(30.0)
78+
79+
assert_close(gramian, expected)
80+
81+
2882
@mark.parametrize(
2983
"shape",
3084
[

0 commit comments

Comments
 (0)