11import torch
2- from torch import Tensor
3- from torch .nn .functional import normalize
4-
5-
6- def _sample_matrix (m : int , n : int , rank : int ) -> Tensor :
7- """Samples a random matrix A of shape [m, n] with provided rank."""
8-
9- U = _sample_orthonormal_matrix (m )
10- Vt = _sample_orthonormal_matrix (n )
11- S = torch .diag (torch .abs (torch .randn ([rank ])))
12- A = U [:, :rank ] @ S @ Vt [:rank , :]
13- return A
14-
15-
16- def _sample_strong_matrix (m : int , n : int , rank : int ) -> Tensor :
17- """
18- Samples a random strongly stationary matrix A of shape [m, n] with provided rank.
19-
20- Definition: A matrix A is said to be strongly stationary if there exists a vector 0 < v such
21- that v^T A = 0.
22-
23- This is done by sampling a positive v, and by then sampling a matrix orthogonal to v.
24- """
25-
26- assert 1 < m
27- assert 0 < rank <= min (m - 1 , n )
28-
29- v = torch .abs (torch .randn ([m ]))
30- U1 = normalize (v , dim = 0 ).unsqueeze (1 )
31- U2 = _sample_semi_orthonormal_complement (U1 )
32- Vt = _sample_orthonormal_matrix (n )
33- S = torch .diag (torch .abs (torch .randn ([rank ])))
34- A = U2 [:, :rank ] @ S @ Vt [:rank , :]
35- return A
36-
37-
38- def _sample_strictly_weak_matrix (m : int , n : int , rank : int ) -> Tensor :
39- """
40- Samples a random strictly weakly stationary matrix A of shape [m, n] with provided rank.
41-
42- Definition: A matrix A is said to be weakly stationary if there exists a vector 0 <= v, v != 0,
43- such that v^T A = 0.
44-
45- Definition: A matrix A is said to be strictly weakly stationary if it is weakly stationary and
46- not strongly stationary, i.e. if there exists a vector 0 <= v, v != 0, such that v^T A = 0 and
47- there exists no vector 0 < w with w^T A = 0.
48-
49- This is done by sampling two unit-norm vectors v, v', whose sum u is a positive vector. These
50- two vectors are also non-negative and non-zero, and are furthermore orthogonal. Then, a matrix
51- A, orthogonal to v, is sampled. By its orthogonality to v, A is weakly stationary. Moreover,
52- since v' is a non-negative left-singular vector of A with positive singular value s, any 0 < w
53- satisfies w^T A != 0. Otherwise, we would have 0 = w^T A A^T v' = s w^T v' > 0, which is a
54- contradiction. A is thus also not strongly stationary.
55- """
56-
57- assert 1 < m
58- assert 0 < rank <= min (m - 1 , n )
59-
60- u = torch .abs (torch .randn ([m ]))
61- split_index = torch .randint (1 , m , []).item ()
62- shuffled_range = torch .randperm (m )
63- v = torch .zeros (m )
64- v [shuffled_range [:split_index ]] = normalize (u [shuffled_range [:split_index ]], dim = 0 )
65- v_prime = torch .zeros (m )
66- v_prime [shuffled_range [split_index :]] = normalize (u [shuffled_range [split_index :]], dim = 0 )
67- U1 = torch .stack ([v , v_prime ]).T
68- U2 = _sample_semi_orthonormal_complement (U1 )
69- U = torch .hstack ([U1 , U2 ])
70- Vt = _sample_orthonormal_matrix (n )
71- S = torch .diag (torch .abs (torch .randn ([rank ])))
72- A = U [:, 1 : rank + 1 ] @ S @ Vt [:rank , :]
73- return A
74-
75-
76- def _sample_non_weak_matrix (m : int , n : int , rank : int ) -> Tensor :
77- """
78- Samples a random non weakly-stationary matrix A of shape [m, n] with provided rank.
79-
80- This is done by sampling a positive u, and by then sampling a matrix A that has u as one of its
81- left-singular vectors, with positive singular value s. Any 0 <= v, v != 0, satisfies v^T A != 0.
82- Otherwise, we would have 0 = v^T A A^T u = s v^T u > 0, which is a contradiction. A is thus not
83- weakly stationary.
84- """
85-
86- assert 0 < rank <= min (m , n )
87-
88- u = torch .abs (torch .randn ([m ]))
89- U1 = normalize (u , dim = 0 ).unsqueeze (1 )
90- U2 = _sample_semi_orthonormal_complement (U1 )
91- U = torch .hstack ([U1 , U2 ])
92- Vt = _sample_orthonormal_matrix (n )
93- S = torch .diag (torch .abs (torch .randn ([rank ])))
94- A = U [:, :rank ] @ S @ Vt [:rank , :]
95- return A
96-
97-
98- def _sample_orthonormal_matrix (dim : int ) -> Tensor :
99- """Uniformly samples a random orthonormal matrix of shape [dim, dim]."""
100-
101- return _sample_semi_orthonormal_complement (torch .zeros ([dim , 0 ]))
102-
103-
104- def _sample_semi_orthonormal_complement (Q : Tensor ) -> Tensor :
105- """
106- Uniformly samples a random semi-orthonormal matrix Q' (i.e. Q'^T Q' = I) of shape [m, m-k]
107- orthogonal to Q, i.e. such that the concatenation [Q, Q'] is an orthonormal matrix.
108-
109- :param Q: A semi-orthonormal matrix (i.e. Q^T Q = I) of shape [m, k], with k <= m.
110- """
111-
112- m , k = Q .shape
113- A = torch .randn ([m , m - k ])
114-
115- # project A onto the orthogonal complement of Q
116- A_proj = A - Q @ (Q .T @ A )
117-
118- Q_prime , _ = torch .linalg .qr (A_proj )
119- return Q_prime
2+ from unit .conftest import DEVICE
1203
4+ from ._matrix_samplers import NonWeakSampler , NormalSampler , StrictlyWeakSampler , StrongSampler
1215
1226_normal_dims = [
1237 (1 , 1 , 1 ),
@@ -144,14 +28,13 @@ def _sample_semi_orthonormal_complement(Q: Tensor) -> Tensor:
14428
14529_scales = [0.0 , 1e-10 , 1e3 , 1e5 , 1e10 , 1e15 ]
14630
147- # Fix seed to fix randomness of matrix sampling
148- torch .manual_seed (0 )
31+ _rng = torch .Generator (device = DEVICE ).manual_seed (0 )
14932
150- matrices = [_sample_matrix (m , n , r ) for m , n , r in _normal_dims ]
33+ matrices = [NormalSampler (m , n , r )( _rng ) for m , n , r in _normal_dims ]
15134zero_matrices = [torch .zeros ([m , n ]) for m , n , _ in _zero_dims ]
152- strong_matrices = [_sample_strong_matrix (m , n , r ) for m , n , r in _stationarity_dims ]
153- strictly_weak_matrices = [_sample_strictly_weak_matrix (m , n , r ) for m , n , r in _stationarity_dims ]
154- non_weak_matrices = [_sample_non_weak_matrix (m , n , r ) for m , n , r in _stationarity_dims ]
35+ strong_matrices = [StrongSampler (m , n , r )( _rng ) for m , n , r in _stationarity_dims ]
36+ strictly_weak_matrices = [StrictlyWeakSampler (m , n , r )( _rng ) for m , n , r in _stationarity_dims ]
37+ non_weak_matrices = [NonWeakSampler (m , n , r )( _rng ) for m , n , r in _stationarity_dims ]
15538
15639scaled_matrices = [scale * matrix for scale in _scales for matrix in matrices ]
15740
@@ -170,4 +53,4 @@ def _sample_semi_orthonormal_complement(Q: Tensor) -> Tensor:
17053 (9 , 11 , 5 ),
17154 (9 , 11 , 9 ),
17255]
173- nash_mtl_matrices = [_sample_matrix (m , n , r ) for m , n , r in _nashmtl_dims ]
56+ nash_mtl_matrices = [NormalSampler (m , n , r )( _rng ) for m , n , r in _nashmtl_dims ]
0 commit comments