Skip to content

Commit 4142861

Browse files
committed
make utility functions protected
1 parent 7b4cfb7 commit 4142861

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

tests/unit/aggregation/_inputs.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torch import Tensor
33

44

5-
def generate_orthogonal_matrix(dim: int) -> Tensor:
5+
def _generate_orthogonal_matrix(dim: int) -> Tensor:
66
"""
77
Uniformly generates a random orthogonal matrix of shape [n, n].
88
"""
@@ -12,7 +12,7 @@ def generate_orthogonal_matrix(dim: int) -> Tensor:
1212
return Q
1313

1414

15-
def complete_orthogonal_matrix(vector: Tensor) -> Tensor:
15+
def _complete_orthogonal_matrix(vector: Tensor) -> Tensor:
1616
"""
1717
Uniformly generates a random orthogonal matrix of shape [len(vector), len(vector)] such that the
1818
first column is the normalization of the provided vector.
@@ -29,51 +29,51 @@ def complete_orthogonal_matrix(vector: Tensor) -> Tensor:
2929
return torch.cat([u.unsqueeze(1), Q], dim=1)
3030

3131

32-
def generate_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
32+
def _generate_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
3333
"""
3434
Generates a random matrix of shape [``n_rows``, ``n_cols``] with provided ``rank``.
3535
"""
3636

37-
U = generate_orthogonal_matrix(n_rows)
38-
Vt = generate_orthogonal_matrix(n_cols)
37+
U = _generate_orthogonal_matrix(n_rows)
38+
Vt = _generate_orthogonal_matrix(n_cols)
3939
S = torch.diag(torch.abs(torch.randn([rank])))
4040
matrix = U[:, :rank] @ S @ Vt[:rank, :]
4141
return matrix
4242

4343

44-
def generate_matrix_with_orthogonal_vector(vector: Tensor, n_cols: int, rank: int) -> Tensor:
44+
def _generate_matrix_with_orthogonal_vector(vector: Tensor, n_cols: int, rank: int) -> Tensor:
4545
"""
4646
Generates a random matrix of shape [``len(vector)``, ``n_cols``] with rank
4747
``min(rank, len(vector)-1)``. Such that `vector @ matrix` is zero.
4848
"""
4949

5050
n_rows = len(vector)
5151
effective_rank = min(rank, n_rows - 1)
52-
U = complete_orthogonal_matrix(vector)
53-
Vt = generate_orthogonal_matrix(n_cols)
52+
U = _complete_orthogonal_matrix(vector)
53+
Vt = _generate_orthogonal_matrix(n_cols)
5454
S = torch.diag(torch.abs(torch.randn([effective_rank])))
5555
matrix = U[:, 1 : 1 + effective_rank] @ S @ Vt[:effective_rank, :]
5656
return matrix
5757

5858

59-
def generate_strong_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
59+
def _generate_strong_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
6060
"""
6161
Generates a random matrix of shape [``n_rows``, ``n_cols``] with rank
6262
``min(rank, len(vector)-1)``, such that there exists a vector `0<v` with `v @ matrix=0`.
6363
"""
6464
v = torch.abs(torch.randn([n_rows]))
65-
return generate_matrix_with_orthogonal_vector(v, n_cols, rank)
65+
return _generate_matrix_with_orthogonal_vector(v, n_cols, rank)
6666

6767

68-
def generate_weak_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
68+
def _generate_weak_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
6969
"""
7070
Generates a random matrix of shape [``n_rows``, ``n_cols``] with rank
7171
``min(rank, len(vector)-1)``, such that there exists a vector `0<=v` with at least one
7272
coordinate equal to `0` and such that `v @ matrix=0`.
7373
"""
7474
v = torch.abs(torch.randn([n_rows]))
7575
v[torch.randint(0, n_rows, [])] = 0.0
76-
return generate_matrix_with_orthogonal_vector(v, n_cols, rank)
76+
return _generate_matrix_with_orthogonal_vector(v, n_cols, rank)
7777

7878

7979
_matrix_dimension_triples = [
@@ -97,7 +97,7 @@ def generate_weak_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tens
9797
torch.manual_seed(0)
9898

9999
matrices = [
100-
generate_matrix(n_rows, n_cols, rank) for n_rows, n_cols, rank in _matrix_dimension_triples
100+
_generate_matrix(n_rows, n_cols, rank) for n_rows, n_cols, rank in _matrix_dimension_triples
101101
]
102102
scaled_matrices = [scale * matrix for scale in _scales for matrix in matrices]
103103
zero_matrices = [torch.zeros([n_rows, n_cols]) for n_rows, n_cols in _zero_matrices_shapes]
@@ -106,10 +106,10 @@ def generate_weak_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tens
106106
matrix for matrix in scaled_matrices + zero_matrices if matrix.shape[0] >= 2
107107
]
108108
strong_stationary_matrices = [
109-
generate_strong_stationary_matrix(n_rows, n_cols, rank)
109+
_generate_strong_stationary_matrix(n_rows, n_cols, rank)
110110
for n_rows, n_cols, rank in _matrix_dimension_triples
111111
]
112112
weak_stationary_matrices = strong_stationary_matrices + [
113-
generate_weak_stationary_matrix(n_rows, n_cols, rank)
113+
_generate_weak_stationary_matrix(n_rows, n_cols, rank)
114114
for n_rows, n_cols, rank in _matrix_dimension_triples
115115
]

0 commit comments

Comments
 (0)