33import torch
44from torch import Tensor
55from torch .nn .functional import normalize
6+ from unit ._utils import randint_ , randn_ , randperm_ , zeros_
67
78
89class MatrixSampler (ABC ):
@@ -43,7 +44,7 @@ class NormalSampler(MatrixSampler):
4344 def __call__ (self , rng : torch .Generator | None = None ) -> Tensor :
4445 U = _sample_orthonormal_matrix (self .m , dtype = self .dtype , rng = rng )
4546 Vt = _sample_orthonormal_matrix (self .n , dtype = self .dtype , rng = rng )
46- S = torch .diag (torch .abs (torch . randn ([self .rank ], dtype = self .dtype , generator = rng )))
47+ S = torch .diag (torch .abs (randn_ ([self .rank ], dtype = self .dtype , generator = rng )))
4748 A = U [:, : self .rank ] @ S @ Vt [: self .rank , :]
4849 return A
4950
@@ -65,11 +66,11 @@ def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None:
6566 assert 0 < rank <= min (m - 1 , n )
6667
6768 def __call__ (self , rng : torch .Generator | None = None ) -> Tensor :
68- v = torch .abs (torch . randn ([self .m ], dtype = self .dtype , generator = rng ))
69+ v = torch .abs (randn_ ([self .m ], dtype = self .dtype , generator = rng ))
6970 U1 = normalize (v , dim = 0 ).unsqueeze (1 )
7071 U2 = _sample_semi_orthonormal_complement (U1 , rng = rng )
7172 Vt = _sample_orthonormal_matrix (self .n , dtype = self .dtype , rng = rng )
72- S = torch .diag (torch .abs (torch . randn ([self .rank ], dtype = self .dtype , generator = rng )))
73+ S = torch .diag (torch .abs (randn_ ([self .rank ], dtype = self .dtype , generator = rng )))
7374 A = U2 [:, : self .rank ] @ S @ Vt [: self .rank , :]
7475 return A
7576
@@ -101,18 +102,18 @@ def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None:
101102 assert 0 < rank <= min (m - 1 , n )
102103
103104 def __call__ (self , rng : torch .Generator | None = None ) -> Tensor :
104- u = torch .abs (torch . randn ([self .m ], dtype = self .dtype , generator = rng ))
105- split_index = torch . randint (1 , self .m , [], generator = rng ).item ()
106- shuffled_range = torch . randperm (self .m , generator = rng )
107- v = torch . zeros (self .m , dtype = self .dtype )
105+ u = torch .abs (randn_ ([self .m ], dtype = self .dtype , generator = rng ))
106+ split_index = randint_ (1 , self .m , [], generator = rng ).item ()
107+ shuffled_range = randperm_ (self .m , generator = rng )
108+ v = zeros_ (self .m , dtype = self .dtype )
108109 v [shuffled_range [:split_index ]] = normalize (u [shuffled_range [:split_index ]], dim = 0 )
109- v_prime = torch . zeros (self .m , dtype = self .dtype )
110+ v_prime = zeros_ (self .m , dtype = self .dtype )
110111 v_prime [shuffled_range [split_index :]] = normalize (u [shuffled_range [split_index :]], dim = 0 )
111112 U1 = torch .stack ([v , v_prime ]).T
112113 U2 = _sample_semi_orthonormal_complement (U1 , rng = rng )
113114 U = torch .hstack ([U1 , U2 ])
114115 Vt = _sample_orthonormal_matrix (self .n , dtype = self .dtype , rng = rng )
115- S = torch .diag (torch .abs (torch . randn ([self .rank ], dtype = self .dtype , generator = rng )))
116+ S = torch .diag (torch .abs (randn_ ([self .rank ], dtype = self .dtype , generator = rng )))
116117 A = U [:, 1 : self .rank + 1 ] @ S @ Vt [: self .rank , :]
117118 return A
118119
@@ -133,12 +134,12 @@ def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None:
133134 assert 0 < rank
134135
135136 def __call__ (self , rng : torch .Generator | None = None ) -> Tensor :
136- u = torch .abs (torch . randn ([self .m ], dtype = self .dtype , generator = rng ))
137+ u = torch .abs (randn_ ([self .m ], dtype = self .dtype , generator = rng ))
137138 U1 = normalize (u , dim = 0 ).unsqueeze (1 )
138139 U2 = _sample_semi_orthonormal_complement (U1 , rng = rng )
139140 U = torch .hstack ([U1 , U2 ])
140141 Vt = _sample_orthonormal_matrix (self .n , dtype = self .dtype , rng = rng )
141- S = torch .diag (torch .abs (torch . randn ([self .rank ], dtype = self .dtype , generator = rng )))
142+ S = torch .diag (torch .abs (randn_ ([self .rank ], dtype = self .dtype , generator = rng )))
142143 A = U [:, : self .rank ] @ S @ Vt [: self .rank , :]
143144 return A
144145
@@ -148,7 +149,7 @@ def _sample_orthonormal_matrix(
148149) -> Tensor :
149150 """Uniformly samples a random orthonormal matrix of shape [dim, dim]."""
150151
151- return _sample_semi_orthonormal_complement (torch . zeros ([dim , 0 ], dtype = dtype ), rng = rng )
152+ return _sample_semi_orthonormal_complement (zeros_ ([dim , 0 ], dtype = dtype ), rng = rng )
152153
153154
154155def _sample_semi_orthonormal_complement (Q : Tensor , rng : torch .Generator | None = None ) -> Tensor :
@@ -161,7 +162,7 @@ def _sample_semi_orthonormal_complement(Q: Tensor, rng: torch.Generator | None =
161162
162163 dtype = Q .dtype
163164 m , k = Q .shape
164- A = torch . randn ([m , m - k ], dtype = dtype , generator = rng )
165+ A = randn_ ([m , m - k ], dtype = dtype , generator = rng )
165166
166167 # project A onto the orthogonal complement of Q
167168 A_proj = A - Q @ (Q .T @ A )
0 commit comments