11from pytest import mark
2+ from torch .testing import assert_close
23from utils .asserts import assert_is_psd_matrix
3- from utils .tensors import randn_
4+ from utils .tensors import randn_ , tensor_
45
56from torchjd ._linalg import compute_gramian , is_matrix , normalize , regularize
67
@@ -25,6 +26,59 @@ def test_gramian_is_psd(shape: list[int]):
2526 assert_is_psd_matrix (gramian )
2627
2728
29+ def test_compute_gramian_scalar_input_0 ():
30+ t = tensor_ (5.0 )
31+ gramian = compute_gramian (t , contracted_dims = 0 )
32+ expected = tensor_ (25.0 )
33+
34+ assert_close (gramian , expected )
35+
36+
37+ def test_compute_gramian_vector_input_0 ():
38+ t = tensor_ ([2.0 , 3.0 ])
39+ gramian = compute_gramian (t , contracted_dims = 0 )
40+ expected = tensor_ ([[4.0 , 6.0 ], [6.0 , 9.0 ]])
41+
42+ assert_close (gramian , expected )
43+
44+
45+ def test_compute_gramian_vector_input_1 ():
46+ t = tensor_ ([2.0 , 3.0 ])
47+ gramian = compute_gramian (t , contracted_dims = 1 )
48+ expected = tensor_ (13.0 )
49+
50+ assert_close (gramian , expected )
51+
52+
53+ def test_compute_gramian_matrix_input_0 ():
54+ t = tensor_ ([[1.0 , 2.0 ], [3.0 , 4.0 ]])
55+ gramian = compute_gramian (t , contracted_dims = 0 )
56+ expected = tensor_ (
57+ [
58+ [[[1.0 , 3.0 ], [2.0 , 4.0 ]], [[2.0 , 6.0 ], [4.0 , 8.0 ]]],
59+ [[[3.0 , 9.0 ], [6.0 , 12.0 ]], [[4.0 , 12.0 ], [8.0 , 16.0 ]]],
60+ ]
61+ )
62+
63+ assert_close (gramian , expected )
64+
65+
66+ def test_compute_gramian_matrix_input_1 ():
67+ t = tensor_ ([[1.0 , 2.0 ], [3.0 , 4.0 ]])
68+ gramian = compute_gramian (t , contracted_dims = 1 )
69+ expected = tensor_ ([[5.0 , 11.0 ], [11.0 , 25.0 ]])
70+
71+ assert_close (gramian , expected )
72+
73+
74+ def test_compute_gramian_matrix_input_2 ():
75+ t = tensor_ ([[1.0 , 2.0 ], [3.0 , 4.0 ]])
76+ gramian = compute_gramian (t , contracted_dims = 2 )
77+ expected = tensor_ (30.0 )
78+
79+ assert_close (gramian , expected )
80+
81+
2882@mark .parametrize (
2983 "shape" ,
3084 [
0 commit comments