33from torch .nn .functional import normalize
44
55
6- def _generate_matrix (m : int , n : int , rank : int ) -> Tensor :
7- """Generates a random matrix A of shape [m, n] with provided rank."""
6+ def _sample_matrix (m : int , n : int , rank : int ) -> Tensor :
7+ """Samples a random matrix A of shape [m, n] with provided rank."""
88
9- U = _generate_orthonormal_matrix (m )
10- Vt = _generate_orthonormal_matrix (n )
9+ U = _sample_orthonormal_matrix (m )
10+ Vt = _sample_orthonormal_matrix (n )
1111 S = torch .diag (torch .abs (torch .randn ([rank ])))
1212 A = U [:, :rank ] @ S @ Vt [:rank , :]
1313 return A
1414
1515
16- def _generate_strong_stationary_matrix (m : int , n : int ) -> Tensor :
16+ def _sample_strong_matrix (m : int , n : int , rank : int ) -> Tensor :
1717 """
18- Generates a random matrix A of shape [m, n] with rank min(n, m - 1), such that there exists a
19- vector 0<v with v^T A = 0.
18+ Samples a random strongly stationary matrix A of shape [m, n] with provided rank.
19+
20+ Definition: A matrix A is said to be strongly stationary if there exists a vector 0 < v such
21+ that v^T A = 0.
22+
23+ This is done by sampling a positive v, and by then sampling a matrix orthogonal to v.
2024 """
2125
26+ assert 1 < m
27+ assert 0 < rank <= min (m - 1 , n )
28+
2229 v = torch .abs (torch .randn ([m ]))
23- return _generate_matrix_orthogonal_to_vector (v , n )
30+ U1 = normalize (v , dim = 0 ).unsqueeze (1 )
31+ U2 = _sample_semi_orthonormal_complement (U1 )
32+ Vt = _sample_orthonormal_matrix (n )
33+ S = torch .diag (torch .abs (torch .randn ([rank ])))
34+ A = U2 [:, :rank ] @ S @ Vt [:rank , :]
35+ return A
2436
2537
26- def _generate_weak_stationary_matrix (m : int , n : int ) -> Tensor :
38+ def _sample_strictly_weak_matrix (m : int , n : int , rank : int ) -> Tensor :
2739 """
28- Generates a random matrix A of shape [m, n] with rank min(n, m - 1), such that there exists a
29- vector 0<=v with one coordinate equal to 0 and such that v^T A = 0.
40+ Samples a random strictly weakly stationary matrix A of shape [m, n] with provided rank.
41+
42+ Definition: A matrix A is said to be weakly stationary if there exists a vector 0 <= v, v != 0,
43+ such that v^T A = 0.
3044
31- Note that if multiple coordinates of v were equal to 0, the generated matrix would still be weak
32- stationary, but here we only set one of them to 0 for simplicity.
45+ Definition: A matrix A is said to be strictly weakly stationary if it is weakly stationary and
46+ not strongly stationary, i.e. if there exists a vector 0 <= v, v != 0, such that v^T A = 0 and
47+ there exists no vector 0 < w with w^T A = 0.
48+
49+ This is done by sampling two unit-norm vectors v, v', whose sum u is a positive vector. These
50+ two vectors are also non-negative and non-zero, and are furthermore orthogonal. Then, a matrix
51+ A, orthogonal to v, is sampled. By its orthogonality to v, A is weakly stationary. Moreover,
52+ since v' is a non-negative left-singular vector of A with positive singular value s, any 0 < w
53+ satisfies w^T A != 0. Otherwise, we would have 0 = w^T A A^T v' = s w^T v' > 0, which is a
54+ contradiction. A is thus also not strongly stationary.
3355 """
3456
35- v = torch .abs (torch .randn ([m ]))
36- i = torch .randint (0 , m , [])
37- v [i ] = 0.0
38- return _generate_matrix_orthogonal_to_vector (v , n )
57+ assert 1 < m
58+ assert 0 < rank <= min (m - 1 , n )
59+
60+ u = torch .abs (torch .randn ([m ]))
61+ split_index = torch .randint (1 , m , []).item ()
62+ shuffled_range = torch .randperm (m )
63+ v = torch .zeros (m )
64+ v [shuffled_range [:split_index ]] = normalize (u [shuffled_range [:split_index ]], dim = 0 )
65+ v_prime = torch .zeros (m )
66+ v_prime [shuffled_range [split_index :]] = normalize (u [shuffled_range [split_index :]], dim = 0 )
67+ U1 = torch .stack ([v , v_prime ]).T
68+ U2 = _sample_semi_orthonormal_complement (U1 )
69+ U = torch .hstack ([U1 , U2 ])
70+ Vt = _sample_orthonormal_matrix (n )
71+ S = torch .diag (torch .abs (torch .randn ([rank ])))
72+ A = U [:, 1 : rank + 1 ] @ S @ Vt [:rank , :]
73+ return A
3974
4075
41- def _generate_matrix_orthogonal_to_vector ( v : Tensor , n : int ) -> Tensor :
76+ def _sample_non_weak_matrix ( m : int , n : int , rank : int ) -> Tensor :
4277 """
43- Generates a random matrix A of shape [len(v), n] with rank min(n, len(v) - 1) such that
44- v^T A = 0.
78+ Samples a random non weakly-stationary matrix A of shape [m, n] with provided rank.
79+
80+ This is done by sampling a positive u, and by then sampling a matrix A that has u as one of its
81+ left-singular vectors, with positive singular value s. Any 0 <= v, v != 0, satisfies v^T A != 0.
82+ Otherwise, we would have 0 = v^T A A^T u = s v^T u > 0, which is a contradiction. A is thus not
83+ weakly stationary.
4584 """
4685
47- rank = min (n , len (v ) - 1 )
48- Q = normalize (v , dim = 0 ).unsqueeze (1 )
49- U = _generate_semi_orthonormal_complement (Q )
50- Vt = _generate_orthonormal_matrix (n )
86+ assert 0 < rank <= min (m , n )
87+
88+ u = torch .abs (torch .randn ([m ]))
89+ U1 = normalize (u , dim = 0 ).unsqueeze (1 )
90+ U2 = _sample_semi_orthonormal_complement (U1 )
91+ U = torch .hstack ([U1 , U2 ])
92+ Vt = _sample_orthonormal_matrix (n )
5193 S = torch .diag (torch .abs (torch .randn ([rank ])))
5294 A = U [:, :rank ] @ S @ Vt [:rank , :]
5395 return A
5496
5597
56- def _generate_orthonormal_matrix (dim : int ) -> Tensor :
57- """Uniformly generates a random orthonormal matrix of shape [dim, dim]."""
98+ def _sample_orthonormal_matrix (dim : int ) -> Tensor :
99+ """Uniformly samples a random orthonormal matrix of shape [dim, dim]."""
58100
59- return _generate_semi_orthonormal_complement (torch .zeros ([dim , 0 ]))
101+ return _sample_semi_orthonormal_complement (torch .zeros ([dim , 0 ]))
60102
61103
62- def _generate_semi_orthonormal_complement (Q : Tensor ) -> Tensor :
104+ def _sample_semi_orthonormal_complement (Q : Tensor ) -> Tensor :
63105 """
64- Uniformly generates a random semi-orthonormal matrix Q' (i.e. Q'^T Q' = I) of shape [m, m-k]
106+ Uniformly samples a random semi-orthonormal matrix Q' (i.e. Q'^T Q' = I) of shape [m, m-k]
65107 orthogonal to Q, i.e. such that the concatenation [Q, Q'] is an orthonormal matrix.
66108
67109 :param Q: A semi-orthonormal matrix (i.e. Q^T Q = I) of shape [m, k], with k <= m.
@@ -77,7 +119,7 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
77119 return Q_prime
78120
79121
80- _matrix_dimension_triples = [
122+ _normal_dims = [
81123 (1 , 1 , 1 ),
82124 (4 , 3 , 1 ),
83125 (4 , 3 , 2 ),
@@ -86,32 +128,35 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
86128 (9 , 11 , 9 ),
87129]
88130
89- _zero_matrices_shapes = [
90- (1 , 1 ),
91- (4 , 3 ),
92- (9 , 11 ),
131+ _zero_dims = [
132+ (1 , 1 , 0 ),
133+ (4 , 3 , 0 ),
134+ (9 , 11 , 0 ),
93135]
94136
95- _stationary_matrices_shapes = [
96- (5 , 3 ),
97- (9 , 11 ),
137+ _stationarity_dims = [
138+ (20 , 10 , 10 ),
139+ (20 , 10 , 5 ),
140+ (20 , 10 , 1 ),
141+ (20 , 100 , 1 ),
142+ (20 , 100 , 19 ),
98143]
99144
100145_scales = [0.0 , 1e-10 , 1e3 , 1e5 , 1e10 , 1e15 ]
101146
102- # Fix seed to fix randomness of matrix generation
147+ # Fix seed to fix randomness of matrix sampling
103148torch .manual_seed (0 )
104149
105- matrices = [_generate_matrix (m , n , rank ) for m , n , rank in _matrix_dimension_triples ]
150+ matrices = [_sample_matrix (m , n , r ) for m , n , r in _normal_dims ]
151+ zero_matrices = [torch .zeros ([m , n ]) for m , n , _ in _zero_dims ]
152+ strong_matrices = [_sample_strong_matrix (m , n , r ) for m , n , r in _stationarity_dims ]
153+ strictly_weak_matrices = [_sample_strictly_weak_matrix (m , n , r ) for m , n , r in _stationarity_dims ]
154+ non_weak_matrices = [_sample_non_weak_matrix (m , n , r ) for m , n , r in _stationarity_dims ]
155+
106156scaled_matrices = [scale * matrix for scale in _scales for matrix in matrices ]
107- zero_matrices = [torch .zeros ([m , n ]) for m , n in _zero_matrices_shapes ]
108- strong_stationary_matrices = [
109- _generate_strong_stationary_matrix (m , n ) for m , n in _stationary_matrices_shapes
110- ]
111- weak_stationary_matrices = [
112- _generate_weak_stationary_matrix (m , n ) for m , n in _stationary_matrices_shapes
113- ]
114- typical_matrices = zero_matrices + matrices + weak_stationary_matrices + strong_stationary_matrices
157+
158+ non_strong_matrices = strictly_weak_matrices + non_weak_matrices
159+ typical_matrices = zero_matrices + matrices + strong_matrices + non_strong_matrices
115160
116161scaled_matrices_2_plus_rows = [matrix for matrix in scaled_matrices if matrix .shape [0 ] >= 2 ]
117162typical_matrices_2_plus_rows = [matrix for matrix in typical_matrices if matrix .shape [0 ] >= 2 ]
0 commit comments