@@ -41,39 +41,39 @@ def _generate_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
4141 return matrix
4242
4343
44- def _generate_matrix_with_orthogonal_vector (vector : Tensor , n_cols : int , rank : int ) -> Tensor :
44+ def _generate_matrix_with_orthogonal_vector (vector : Tensor , n_cols : int ) -> Tensor :
4545 """
4646 Generates a random matrix of shape [``len(vector)``, ``n_cols``] with rank
4747 ``min(rank, len(vector)-1)``. Such that `vector @ matrix` is zero.
4848 """
4949
5050 n_rows = len (vector )
51- effective_rank = min (rank , n_rows - 1 )
51+ rank = min (n_cols , n_rows - 1 )
5252 U = _complete_orthogonal_matrix (vector )
5353 Vt = _generate_orthogonal_matrix (n_cols )
54- S = torch .diag (torch .abs (torch .randn ([effective_rank ])))
55- matrix = U [:, 1 : 1 + effective_rank ] @ S @ Vt [:effective_rank , :]
54+ S = torch .diag (torch .abs (torch .randn ([rank ])))
55+ matrix = U [:, 1 : 1 + rank ] @ S @ Vt [:rank , :]
5656 return matrix
5757
5858
59- def _generate_strong_stationary_matrix (n_rows : int , n_cols : int , rank : int ) -> Tensor :
59+ def _generate_strong_stationary_matrix (n_rows : int , n_cols : int ) -> Tensor :
6060 """
6161 Generates a random matrix of shape [``n_rows``, ``n_cols``] with rank
6262 ``min(rank, len(vector)-1)``, such that there exists a vector `0<v` with `v @ matrix=0`.
6363 """
6464 v = torch .abs (torch .randn ([n_rows ]))
65- return _generate_matrix_with_orthogonal_vector (v , n_cols , rank )
65+ return _generate_matrix_with_orthogonal_vector (v , n_cols )
6666
6767
68- def _generate_weak_stationary_matrix (n_rows : int , n_cols : int , rank : int ) -> Tensor :
68+ def _generate_weak_stationary_matrix (n_rows : int , n_cols : int ) -> Tensor :
6969 """
7070 Generates a random matrix of shape [``n_rows``, ``n_cols``] with rank
7171 ``min(rank, len(vector)-1)``, such that there exists a vector `0<=v` with at least one
7272 coordinate equal to `0` and such that `v @ matrix=0`.
7373 """
7474 v = torch .abs (torch .randn ([n_rows ]))
7575 v [torch .randint (0 , n_rows , [])] = 0.0
76- return _generate_matrix_with_orthogonal_vector (v , n_cols , rank )
76+ return _generate_matrix_with_orthogonal_vector (v , n_cols )
7777
7878
7979_matrix_dimension_triples = [
@@ -91,6 +91,12 @@ def _generate_weak_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Ten
9191 (9 , 11 ),
9292]
9393
94+ _stationary_matrices_shapes = [
95+ (1 , 1 ),
96+ (5 , 3 ),
97+ (9 , 11 ),
98+ ]
99+
94100_scales = [0.0 , 1e-10 , 1.0 , 1e3 , 1e5 , 1e10 , 1e15 ]
95101
96102# Fix seed to fix randomness of matrix generation
@@ -106,10 +112,10 @@ def _generate_weak_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Ten
106112 matrix for matrix in scaled_matrices + zero_matrices if matrix .shape [0 ] >= 2
107113]
108114strong_stationary_matrices = [
109- _generate_strong_stationary_matrix (n_rows , n_cols , rank )
110- for n_rows , n_cols , rank in _matrix_dimension_triples
115+ _generate_strong_stationary_matrix (n_rows , n_cols )
116+ for n_rows , n_cols in _stationary_matrices_shapes
111117]
112118weak_stationary_matrices = strong_stationary_matrices + [
113- _generate_weak_stationary_matrix (n_rows , n_cols , rank )
114- for n_rows , n_cols , rank in _matrix_dimension_triples
119+ _generate_weak_stationary_matrix (n_rows , n_cols )
120+ for n_rows , n_cols in _stationary_matrices_shapes
115121]
0 commit comments