Skip to content

Commit 088c32f

Browse files
committed
Stationary matrices are now of rank exactly of rank=min(n_cols, n_rows - 1).
The reason for that is we want to make sure for weak stationary that there is no 0<w with J^Tw=0.
1 parent 4142861 commit 088c32f

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

tests/unit/aggregation/_inputs.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,39 +41,39 @@ def _generate_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
4141
return matrix
4242

4343

44-
def _generate_matrix_with_orthogonal_vector(vector: Tensor, n_cols: int, rank: int) -> Tensor:
44+
def _generate_matrix_with_orthogonal_vector(vector: Tensor, n_cols: int) -> Tensor:
4545
"""
4646
Generates a random matrix of shape [``len(vector)``, ``n_cols``] with rank
4747
``min(rank, len(vector)-1)``. Such that `vector @ matrix` is zero.
4848
"""
4949

5050
n_rows = len(vector)
51-
effective_rank = min(rank, n_rows - 1)
51+
rank = min(n_cols, n_rows - 1)
5252
U = _complete_orthogonal_matrix(vector)
5353
Vt = _generate_orthogonal_matrix(n_cols)
54-
S = torch.diag(torch.abs(torch.randn([effective_rank])))
55-
matrix = U[:, 1 : 1 + effective_rank] @ S @ Vt[:effective_rank, :]
54+
S = torch.diag(torch.abs(torch.randn([rank])))
55+
matrix = U[:, 1 : 1 + rank] @ S @ Vt[:rank, :]
5656
return matrix
5757

5858

59-
def _generate_strong_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
59+
def _generate_strong_stationary_matrix(n_rows: int, n_cols: int) -> Tensor:
6060
"""
6161
Generates a random matrix of shape [``n_rows``, ``n_cols``] with rank
6262
``min(rank, len(vector)-1)``, such that there exists a vector `0<v` with `v @ matrix=0`.
6363
"""
6464
v = torch.abs(torch.randn([n_rows]))
65-
return _generate_matrix_with_orthogonal_vector(v, n_cols, rank)
65+
return _generate_matrix_with_orthogonal_vector(v, n_cols)
6666

6767

68-
def _generate_weak_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
68+
def _generate_weak_stationary_matrix(n_rows: int, n_cols: int) -> Tensor:
6969
"""
7070
Generates a random matrix of shape [``n_rows``, ``n_cols``] with rank
7171
``min(rank, len(vector)-1)``, such that there exists a vector `0<=v` with at least one
7272
coordinate equal to `0` and such that `v @ matrix=0`.
7373
"""
7474
v = torch.abs(torch.randn([n_rows]))
7575
v[torch.randint(0, n_rows, [])] = 0.0
76-
return _generate_matrix_with_orthogonal_vector(v, n_cols, rank)
76+
return _generate_matrix_with_orthogonal_vector(v, n_cols)
7777

7878

7979
_matrix_dimension_triples = [
@@ -91,6 +91,12 @@ def _generate_weak_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Ten
9191
(9, 11),
9292
]
9393

94+
_stationary_matrices_shapes = [
95+
(1, 1),
96+
(5, 3),
97+
(9, 11),
98+
]
99+
94100
_scales = [0.0, 1e-10, 1.0, 1e3, 1e5, 1e10, 1e15]
95101

96102
# Fix seed to fix randomness of matrix generation
@@ -106,10 +112,10 @@ def _generate_weak_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Ten
106112
matrix for matrix in scaled_matrices + zero_matrices if matrix.shape[0] >= 2
107113
]
108114
strong_stationary_matrices = [
109-
_generate_strong_stationary_matrix(n_rows, n_cols, rank)
110-
for n_rows, n_cols, rank in _matrix_dimension_triples
115+
_generate_strong_stationary_matrix(n_rows, n_cols)
116+
for n_rows, n_cols in _stationary_matrices_shapes
111117
]
112118
weak_stationary_matrices = strong_stationary_matrices + [
113-
_generate_weak_stationary_matrix(n_rows, n_cols, rank)
114-
for n_rows, n_cols, rank in _matrix_dimension_triples
119+
_generate_weak_stationary_matrix(n_rows, n_cols)
120+
for n_rows, n_cols in _stationary_matrices_shapes
115121
]

0 commit comments

Comments
 (0)