@@ -13,40 +13,53 @@ def _generate_matrix(m: int, n: int, rank: int) -> Tensor:
1313 return A
1414
1515
16- def _generate_strong_stationary_matrix (m : int , n : int ) -> Tensor :
16+ def _generate_strong_stationary_matrix (m : int , n : int , rank : int ) -> Tensor :
1717 """
18- Generates a random matrix A of shape [m, n] with rank min(n, m - 1), such that there exists a
19- vector 0<v with v^T A = 0.
18+ Generates a random matrix A of shape [m, n] with provided rank, such that there exists a vector
19+ 0<v with v^T A = 0.
2020 """
2121
2222 v = torch .abs (torch .randn ([m ]))
23- return _generate_matrix_orthogonal_to_vector (v , n )
23+ U1 = normalize (v , dim = 0 ).unsqueeze (1 )
24+ U2 = _generate_semi_orthonormal_complement (U1 )
25+ U = torch .hstack ([U1 , U2 ])
26+ Vt = _generate_orthonormal_matrix (n )
27+ S = torch .diag (torch .abs (torch .randn ([rank ])))
28+ A = U [:, 1 : rank + 1 ] @ S @ Vt [1 : rank + 1 , :]
29+ return A
2430
2531
26- def _generate_weak_stationary_matrix (m : int , n : int ) -> Tensor :
32+ def _generate_weak_non_strong_stationary_matrix (m : int , n : int , rank : int ) -> Tensor :
2733 """
28- Generates a random matrix A of shape [m, n] with rank min(n, m - 1), such that there exists a
29- vector 0<=v with one coordinate equal to 0 and such that v^T A = 0.
30-
31- Note that if multiple coordinates of v were equal to 0, the generated matrix would still be weak
32- stationary, but here we only set one of them to 0 for simplicity.
34+ Generates a random matrix A of shape [m, n] with provided rank, such that there exists a vector
35+ 0<=v with one coordinate equal to 0 and such that v^T A = 0, and there is no vector 0<w with
36+ w^T A = 0.
3337 """
3438
3539 v = torch .abs (torch .randn ([m ]))
36- i = torch .randint (0 , m , [])
37- v [i ] = 0.0
38- return _generate_matrix_orthogonal_to_vector (v , n )
40+ split_index = torch .randint (0 , m , []).item ()
41+ shuffled_range = torch .randperm (m )
42+ U1 = torch .zeros ([m , 2 ])
43+ U1 [shuffled_range [:split_index ], 0 ] = normalize (v [shuffled_range [:split_index ]], dim = 0 )
44+ U1 [shuffled_range [split_index :], 1 ] = normalize (v [shuffled_range [split_index :]], dim = 0 )
45+ U2 = _generate_semi_orthonormal_complement (U1 )
46+ U = torch .hstack ([U1 , U2 ])
47+ Vt = _generate_orthonormal_matrix (n )
48+ S = torch .diag (torch .abs (torch .randn ([rank ])))
49+ A = U [:, 1 : rank + 1 ] @ S @ Vt [1 : rank + 1 , :]
50+ return A
3951
4052
41- def _generate_matrix_orthogonal_to_vector ( v : Tensor , n : int ) -> Tensor :
53+ def _generate_non_weak_stationary_matrix ( m : int , n : int , rank : int ) -> Tensor :
4254 """
43- Generates a random matrix A of shape [len(v) , n] with rank min(n, len(v) - 1) such that
44- v ^T A = 0.
55+ Generates a random matrix A of shape [m , n] with provided rank, such that there is no vector
56+ 0<=w with w ^T A = 0.
4557 """
4658
47- rank = min (n , len (v ) - 1 )
48- Q = normalize (v , dim = 0 ).unsqueeze (1 )
49- U = _generate_semi_orthonormal_complement (Q )
59+ v = torch .abs (torch .randn ([m ]))
60+ U1 = normalize (v , dim = 0 ).unsqueeze (1 )
61+ U2 = _generate_semi_orthonormal_complement (U1 )
62+ U = torch .hstack ([U1 , U2 ])
5063 Vt = _generate_orthonormal_matrix (n )
5164 S = torch .diag (torch .abs (torch .randn ([rank ])))
5265 A = U [:, :rank ] @ S @ Vt [:rank , :]
@@ -92,9 +105,11 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
92105 (9 , 11 ),
93106]
94107
95- _stationary_matrices_shapes = [
96- (5 , 3 ),
97- (9 , 11 ),
108+ _stationary_matrices_triples = [
109+ (5 , 3 , 3 ),
110+ (9 , 11 , 6 ),
111+ (7 , 13 , 2 ),
112+ (3 , 5 , 3 ),
98113]
99114
100115_scales = [0.0 , 1e-10 , 1e3 , 1e5 , 1e10 , 1e15 ]
@@ -106,12 +121,20 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
106121scaled_matrices = [scale * matrix for scale in _scales for matrix in matrices ]
107122zero_matrices = [torch .zeros ([m , n ]) for m , n in _zero_matrices_shapes ]
108123strong_stationary_matrices = [
109- _generate_strong_stationary_matrix (m , n ) for m , n in _stationary_matrices_shapes
124+ _generate_strong_stationary_matrix (m , n , rank ) for m , n , rank in _stationary_matrices_triples
110125]
111- weak_stationary_matrices = [
112- _generate_weak_stationary_matrix (m , n ) for m , n in _stationary_matrices_shapes
126+ weak_non_strong_stationary_matrices = [
127+ _generate_weak_non_strong_stationary_matrix (m , n , rank )
128+ for m , n , rank in _stationary_matrices_triples
113129]
114- typical_matrices = zero_matrices + matrices + weak_stationary_matrices + strong_stationary_matrices
130+ non_weak_stationary_matrices = [
131+ _generate_non_weak_stationary_matrix (m , n , rank ) for m , n , rank in _stationary_matrices_triples
132+ ]
133+ non_strong_stationary_matrices = weak_non_strong_stationary_matrices + non_weak_stationary_matrices
134+ typical_matrices = (
135+ zero_matrices + matrices + strong_stationary_matrices + non_strong_stationary_matrices
136+ )
137+
115138
116139scaled_matrices_2_plus_rows = [matrix for matrix in scaled_matrices if matrix .shape [0 ] >= 2 ]
117140typical_matrices_2_plus_rows = [matrix for matrix in typical_matrices if matrix .shape [0 ] >= 2 ]
0 commit comments