Skip to content

Commit 26fc65a

Browse files
committed
Improve stationary matrices generation
1 parent b96a9a1 commit 26fc65a

1 file changed

Lines changed: 49 additions & 26 deletions

File tree

tests/unit/aggregation/_inputs.py

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,40 +13,53 @@ def _generate_matrix(m: int, n: int, rank: int) -> Tensor:
1313
return A
1414

1515

16-
def _generate_strong_stationary_matrix(m: int, n: int) -> Tensor:
16+
def _generate_strong_stationary_matrix(m: int, n: int, rank: int) -> Tensor:
1717
"""
18-
Generates a random matrix A of shape [m, n] with rank min(n, m - 1), such that there exists a
19-
vector 0<v with v^T A = 0.
18+
Generates a random matrix A of shape [m, n] with provided rank, such that there exists a vector
19+
0<v with v^T A = 0.
2020
"""
2121

2222
v = torch.abs(torch.randn([m]))
23-
return _generate_matrix_orthogonal_to_vector(v, n)
23+
U1 = normalize(v, dim=0).unsqueeze(1)
24+
U2 = _generate_semi_orthonormal_complement(U1)
25+
U = torch.hstack([U1, U2])
26+
Vt = _generate_orthonormal_matrix(n)
27+
S = torch.diag(torch.abs(torch.randn([rank])))
28+
A = U[:, 1 : rank + 1] @ S @ Vt[1 : rank + 1, :]
29+
return A
2430

2531

26-
def _generate_weak_stationary_matrix(m: int, n: int) -> Tensor:
32+
def _generate_weak_non_strong_stationary_matrix(m: int, n: int, rank: int) -> Tensor:
2733
"""
28-
Generates a random matrix A of shape [m, n] with rank min(n, m - 1), such that there exists a
29-
vector 0<=v with one coordinate equal to 0 and such that v^T A = 0.
30-
31-
Note that if multiple coordinates of v were equal to 0, the generated matrix would still be weak
32-
stationary, but here we only set one of them to 0 for simplicity.
34+
Generates a random matrix A of shape [m, n] with provided rank, such that there exists a vector
35+
0<=v with one coordinate equal to 0 and such that v^T A = 0, and there is no vector 0<w with
36+
w^T A = 0.
3337
"""
3438

3539
v = torch.abs(torch.randn([m]))
36-
i = torch.randint(0, m, [])
37-
v[i] = 0.0
38-
return _generate_matrix_orthogonal_to_vector(v, n)
40+
split_index = torch.randint(0, m, []).item()
41+
shuffled_range = torch.randperm(m)
42+
U1 = torch.zeros([m, 2])
43+
U1[shuffled_range[:split_index], 0] = normalize(v[shuffled_range[:split_index]], dim=0)
44+
U1[shuffled_range[split_index:], 1] = normalize(v[shuffled_range[split_index:]], dim=0)
45+
U2 = _generate_semi_orthonormal_complement(U1)
46+
U = torch.hstack([U1, U2])
47+
Vt = _generate_orthonormal_matrix(n)
48+
S = torch.diag(torch.abs(torch.randn([rank])))
49+
A = U[:, 1 : rank + 1] @ S @ Vt[1 : rank + 1, :]
50+
return A
3951

4052

41-
def _generate_matrix_orthogonal_to_vector(v: Tensor, n: int) -> Tensor:
53+
def _generate_non_weak_stationary_matrix(m: int, n: int, rank: int) -> Tensor:
4254
"""
43-
Generates a random matrix A of shape [len(v), n] with rank min(n, len(v) - 1) such that
44-
v^T A = 0.
55+
Generates a random matrix A of shape [m, n] with provided rank, such that there is no vector
56+
0<=w with w^T A = 0.
4557
"""
4658

47-
rank = min(n, len(v) - 1)
48-
Q = normalize(v, dim=0).unsqueeze(1)
49-
U = _generate_semi_orthonormal_complement(Q)
59+
v = torch.abs(torch.randn([m]))
60+
U1 = normalize(v, dim=0).unsqueeze(1)
61+
U2 = _generate_semi_orthonormal_complement(U1)
62+
U = torch.hstack([U1, U2])
5063
Vt = _generate_orthonormal_matrix(n)
5164
S = torch.diag(torch.abs(torch.randn([rank])))
5265
A = U[:, :rank] @ S @ Vt[:rank, :]
@@ -92,9 +105,11 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
92105
(9, 11),
93106
]
94107

95-
_stationary_matrices_shapes = [
96-
(5, 3),
97-
(9, 11),
108+
_stationary_matrices_triples = [
109+
(5, 3, 3),
110+
(9, 11, 6),
111+
(7, 13, 2),
112+
(3, 5, 3),
98113
]
99114

100115
_scales = [0.0, 1e-10, 1e3, 1e5, 1e10, 1e15]
@@ -106,12 +121,20 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
106121
scaled_matrices = [scale * matrix for scale in _scales for matrix in matrices]
107122
zero_matrices = [torch.zeros([m, n]) for m, n in _zero_matrices_shapes]
108123
strong_stationary_matrices = [
109-
_generate_strong_stationary_matrix(m, n) for m, n in _stationary_matrices_shapes
124+
_generate_strong_stationary_matrix(m, n, rank) for m, n, rank in _stationary_matrices_triples
110125
]
111-
weak_stationary_matrices = [
112-
_generate_weak_stationary_matrix(m, n) for m, n in _stationary_matrices_shapes
126+
weak_non_strong_stationary_matrices = [
127+
_generate_weak_non_strong_stationary_matrix(m, n, rank)
128+
for m, n, rank in _stationary_matrices_triples
113129
]
114-
typical_matrices = zero_matrices + matrices + weak_stationary_matrices + strong_stationary_matrices
130+
non_weak_stationary_matrices = [
131+
_generate_non_weak_stationary_matrix(m, n, rank) for m, n, rank in _stationary_matrices_triples
132+
]
133+
non_strong_stationary_matrices = weak_non_strong_stationary_matrices + non_weak_stationary_matrices
134+
typical_matrices = (
135+
zero_matrices + matrices + strong_stationary_matrices + non_strong_stationary_matrices
136+
)
137+
115138

116139
scaled_matrices_2_plus_rows = [matrix for matrix in scaled_matrices if matrix.shape[0] >= 2]
117140
typical_matrices_2_plus_rows = [matrix for matrix in typical_matrices if matrix.shape[0] >= 2]

0 commit comments

Comments
 (0)