Skip to content

Commit 11c2c2c

Browse files
committed
fixes
1 parent bd2476d commit 11c2c2c

3 files changed

Lines changed: 5 additions & 6 deletions

File tree

tests/unit/aggregation/_inputs.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _generate_strong_stationary_matrix(m: int, n: int, rank: int) -> Tensor:
2525
U = torch.hstack([U1, U2])
2626
Vt = _generate_orthonormal_matrix(n)
2727
S = torch.diag(torch.abs(torch.randn([rank])))
28-
A = U[:, 1 : rank + 1] @ S @ Vt[1 : rank + 1, :]
28+
A = U[:, 1 : rank + 1] @ S @ Vt[:rank, :]
2929
return A
3030

3131

@@ -46,7 +46,7 @@ def _generate_weak_non_strong_stationary_matrix(m: int, n: int, rank: int) -> Te
4646
U = torch.hstack([U1, U2])
4747
Vt = _generate_orthonormal_matrix(n)
4848
S = torch.diag(torch.abs(torch.randn([rank])))
49-
A = U[:, 1 : rank + 1] @ S @ Vt[1 : rank + 1, :]
49+
A = U[:, 1 : rank + 1] @ S @ Vt[:rank, :]
5050
return A
5151

5252

@@ -109,7 +109,6 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
109109
(5, 3, 3),
110110
(9, 11, 6),
111111
(7, 13, 2),
112-
(3, 5, 3),
113112
]
114113

115114
_scales = [0.0, 1e-10, 1e3, 1e5, 1e10, 1e15]

tests/unit/aggregation/_property_testers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class StrongStationarityProperty:
113113
"""
114114

115115
@classmethod
116-
@mark.parametrize("stationary_matrix", non_strong_stationary_matrices)
116+
@mark.parametrize("matrix", non_strong_stationary_matrices)
117117
def test_stationarity_property(cls, aggregator: Aggregator, matrix: Tensor):
118118
cls._assert_stationarity_property(aggregator, matrix)
119119

tests/unit/aggregation/test_constant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def test_linear_under_scaling_property(cls, aggregator: Constant, matrix: Tensor
5454

5555
@classmethod
5656
@mark.parametrize(["aggregator", "matrix"], zip(_aggregators_3, _matrices_3))
57-
def test_stationarity_property(cls, aggregator: Constant, non_stationary_matrix: Tensor):
58-
cls._assert_stationarity_property(aggregator, non_stationary_matrix)
57+
def test_stationarity_property(cls, aggregator: Constant, matrix: Tensor):
58+
cls._assert_stationarity_property(aggregator, matrix)
5959

6060

6161
@mark.parametrize(

0 commit comments

Comments
 (0)