11from abc import ABC , abstractmethod
22
33import torch
4- from settings import DTYPE
54from torch import Tensor
65from torch .nn .functional import normalize
76from utils .tensors import randint_ , randn_ , randperm_ , zeros_
87
98
109class MatrixSampler (ABC ):
11- """Abstract base class for sampling matrices of a given shape, rank and dtype ."""
10+ """Abstract base class for sampling matrices of a given shape, rank."""
1211
13- def __init__ (self , m : int , n : int , rank : int , dtype : torch . dtype = DTYPE ):
14- self ._check_params (m , n , rank , dtype )
12+ def __init__ (self , m : int , n : int , rank : int ):
13+ self ._check_params (m , n , rank )
1514 self .m = m
1615 self .n = n
1716 self .rank = rank
18- self .dtype = dtype
1917
20- def _check_params (self , m : int , n : int , rank : int , dtype : torch . dtype ) -> None :
18+ def _check_params (self , m : int , n : int , rank : int ) -> None :
2119 """Checks that the provided __init__ parameters are acceptable."""
2220
2321 assert m >= 0
2422 assert n >= 0
2523 assert 0 <= rank <= min (m , n )
26- assert dtype in {torch .float32 , torch .float64 }
2724
2825 @abstractmethod
2926 def __call__ (self , rng : torch .Generator | None = None ) -> Tensor :
@@ -35,24 +32,24 @@ def __repr__(self) -> str:
3532 def __str__ (self ) -> str :
3633 return (
3734 f"{ self .__class__ .__name__ .replace ('MatrixSampler' , '' )} "
38- f"({ self .m } x{ self .n } r{ self .rank } : { str ( self . dtype )[ 6 :] } )"
35+ f"({ self .m } x{ self .n } r{ self .rank } )"
3936 )
4037
4138
4239class NormalSampler (MatrixSampler ):
43- """Sampler for random normal matrices of shape [m, n] with provided rank and dtype ."""
40+ """Sampler for random normal matrices of shape [m, n] with provided rank."""
4441
4542 def __call__ (self , rng : torch .Generator | None = None ) -> Tensor :
46- U = _sample_orthonormal_matrix (self .m , dtype = self . dtype , rng = rng )
47- Vt = _sample_orthonormal_matrix (self .n , dtype = self . dtype , rng = rng )
48- S = torch .diag (torch .abs (randn_ ([self .rank ], dtype = self . dtype , generator = rng )))
43+ U = _sample_orthonormal_matrix (self .m , rng = rng )
44+ Vt = _sample_orthonormal_matrix (self .n , rng = rng )
45+ S = torch .diag (torch .abs (randn_ ([self .rank ], generator = rng )))
4946 A = U [:, : self .rank ] @ S @ Vt [: self .rank , :]
5047 return A
5148
5249
5350class StrongSampler (MatrixSampler ):
5451 """
55- Sampler for random strongly stationary matrices of shape [m, n] with provided rank and dtype .
52+ Sampler for random strongly stationary matrices of shape [m, n] with provided rank.
5653
5754 Definition: A matrix A is said to be strongly stationary if there exists a vector 0 < v such
5855 that v^T A = 0.
@@ -61,25 +58,24 @@ class StrongSampler(MatrixSampler):
6158 orthogonal to v.
6259 """
6360
64- def _check_params (self , m : int , n : int , rank : int , dtype : torch . dtype ) -> None :
65- super ()._check_params (m , n , rank , dtype )
61+ def _check_params (self , m : int , n : int , rank : int ) -> None :
62+ super ()._check_params (m , n , rank )
6663 assert 1 < m
6764 assert 0 < rank <= min (m - 1 , n )
6865
6966 def __call__ (self , rng : torch .Generator | None = None ) -> Tensor :
70- v = torch .abs (randn_ ([self .m ], dtype = self . dtype , generator = rng ))
67+ v = torch .abs (randn_ ([self .m ], generator = rng ))
7168 U1 = normalize (v , dim = 0 ).unsqueeze (1 )
7269 U2 = _sample_semi_orthonormal_complement (U1 , rng = rng )
73- Vt = _sample_orthonormal_matrix (self .n , dtype = self . dtype , rng = rng )
74- S = torch .diag (torch .abs (randn_ ([self .rank ], dtype = self . dtype , generator = rng )))
70+ Vt = _sample_orthonormal_matrix (self .n , rng = rng )
71+ S = torch .diag (torch .abs (randn_ ([self .rank ], generator = rng )))
7572 A = U2 [:, : self .rank ] @ S @ Vt [: self .rank , :]
7673 return A
7774
7875
7976class StrictlyWeakSampler (MatrixSampler ):
8077 """
81- Sampler for random strictly weakly stationary matrices of shape [m, n] with provided rank and
82- dtype.
78+ Sampler for random strictly weakly stationary matrices of shape [m, n] with provided rank.
8379
8480 Definition: A matrix A is said to be weakly stationary if there exists a vector 0 <= v, v != 0,
8581 such that v^T A = 0.
@@ -97,60 +93,57 @@ class StrictlyWeakSampler(MatrixSampler):
9793 stationary.
9894 """
9995
100- def _check_params (self , m : int , n : int , rank : int , dtype : torch . dtype ) -> None :
101- super ()._check_params (m , n , rank , dtype )
96+ def _check_params (self , m : int , n : int , rank : int ) -> None :
97+ super ()._check_params (m , n , rank )
10298 assert 1 < m
10399 assert 0 < rank <= min (m - 1 , n )
104100
105101 def __call__ (self , rng : torch .Generator | None = None ) -> Tensor :
106- u = torch .abs (randn_ ([self .m ], dtype = self . dtype , generator = rng ))
102+ u = torch .abs (randn_ ([self .m ], generator = rng ))
107103 split_index = randint_ (1 , self .m , [], generator = rng ).item ()
108104 shuffled_range = randperm_ (self .m , generator = rng )
109- v = zeros_ (self .m , dtype = self . dtype )
105+ v = zeros_ (self .m )
110106 v [shuffled_range [:split_index ]] = normalize (u [shuffled_range [:split_index ]], dim = 0 )
111- v_prime = zeros_ (self .m , dtype = self . dtype )
107+ v_prime = zeros_ (self .m )
112108 v_prime [shuffled_range [split_index :]] = normalize (u [shuffled_range [split_index :]], dim = 0 )
113109 U1 = torch .stack ([v , v_prime ]).T
114110 U2 = _sample_semi_orthonormal_complement (U1 , rng = rng )
115111 U = torch .hstack ([U1 , U2 ])
116- Vt = _sample_orthonormal_matrix (self .n , dtype = self . dtype , rng = rng )
117- S = torch .diag (torch .abs (randn_ ([self .rank ], dtype = self . dtype , generator = rng )))
112+ Vt = _sample_orthonormal_matrix (self .n , rng = rng )
113+ S = torch .diag (torch .abs (randn_ ([self .rank ], generator = rng )))
118114 A = U [:, 1 : self .rank + 1 ] @ S @ Vt [: self .rank , :]
119115 return A
120116
121117
122118class NonWeakSampler (MatrixSampler ):
123119 """
124- Sampler for a random non weakly-stationary matrices of shape [m, n] with provided rank and
125- dtype.
120+ Sampler for a random non weakly-stationary matrices of shape [m, n] with provided rank.
126121
127122 Obtaining such a matrix is done by sampling a positive u, and by then sampling a matrix A that
128123 has u as one of its left-singular vectors, with positive singular value s. Any 0 <= v, v != 0,
129124 satisfies v^T A != 0. Otherwise, we would have 0 = v^T A A^T u = s v^T u > 0, which is a
130125 contradiction. A is thus not weakly stationary.
131126 """
132127
133- def _check_params (self , m : int , n : int , rank : int , dtype : torch . dtype ) -> None :
134- super ()._check_params (m , n , rank , dtype )
128+ def _check_params (self , m : int , n : int , rank : int ) -> None :
129+ super ()._check_params (m , n , rank )
135130 assert 0 < rank
136131
137132 def __call__ (self , rng : torch .Generator | None = None ) -> Tensor :
138- u = torch .abs (randn_ ([self .m ], dtype = self . dtype , generator = rng ))
133+ u = torch .abs (randn_ ([self .m ], generator = rng ))
139134 U1 = normalize (u , dim = 0 ).unsqueeze (1 )
140135 U2 = _sample_semi_orthonormal_complement (U1 , rng = rng )
141136 U = torch .hstack ([U1 , U2 ])
142- Vt = _sample_orthonormal_matrix (self .n , dtype = self . dtype , rng = rng )
143- S = torch .diag (torch .abs (randn_ ([self .rank ], dtype = self . dtype , generator = rng )))
137+ Vt = _sample_orthonormal_matrix (self .n , rng = rng )
138+ S = torch .diag (torch .abs (randn_ ([self .rank ], generator = rng )))
144139 A = U [:, : self .rank ] @ S @ Vt [: self .rank , :]
145140 return A
146141
147142
148- def _sample_orthonormal_matrix (
149- dim : int , dtype : torch .dtype , rng : torch .Generator | None = None
150- ) -> Tensor :
143+ def _sample_orthonormal_matrix (dim : int , rng : torch .Generator | None = None ) -> Tensor :
151144 """Uniformly samples a random orthonormal matrix of shape [dim, dim]."""
152145
153- return _sample_semi_orthonormal_complement (zeros_ ([dim , 0 ], dtype = dtype ), rng = rng )
146+ return _sample_semi_orthonormal_complement (zeros_ ([dim , 0 ]), rng = rng )
154147
155148
156149def _sample_semi_orthonormal_complement (Q : Tensor , rng : torch .Generator | None = None ) -> Tensor :
@@ -161,9 +154,8 @@ def _sample_semi_orthonormal_complement(Q: Tensor, rng: torch.Generator | None =
161154 :param Q: A semi-orthonormal matrix (i.e. Q^T Q = I) of shape [m, k], with k <= m.
162155 """
163156
164- dtype = Q .dtype
165157 m , k = Q .shape
166- A = randn_ ([m , m - k ], dtype = dtype , generator = rng )
158+ A = randn_ ([m , m - k ], generator = rng )
167159
168160 # project A onto the orthogonal complement of Q
169161 A_proj = A - Q @ (Q .T @ A )
0 commit comments