Skip to content

Commit 6e84fec

Browse files
committed
Use correct stationary matrices in strong stationary property tester
1 parent 26fc65a commit 6e84fec

1 file changed

Lines changed: 10 additions & 6 deletions

File tree

tests/unit/aggregation/_property_testers.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55

66
from torchjd.aggregation import Aggregator
77

8-
from ._inputs import matrices, scaled_matrices, typical_matrices, weak_stationary_matrices
8+
from ._inputs import (
9+
non_strong_stationary_matrices,
10+
scaled_matrices,
11+
typical_matrices,
12+
)
913

1014

1115
class ExpectedStructureProperty:
@@ -113,14 +117,14 @@ class StrongStationarityProperty:
113117
"""
114118

115119
@classmethod
116-
@mark.parametrize("stationary_matrix", weak_stationary_matrices + matrices)
117-
def test_stationarity_property(cls, aggregator: Aggregator, stationary_matrix: Tensor):
118-
cls._assert_stationarity_property(aggregator, stationary_matrix)
120+
@mark.parametrize("stationary_matrix", non_strong_stationary_matrices)
121+
def test_stationarity_property(cls, aggregator: Aggregator, matrix: Tensor):
122+
cls._assert_stationarity_property(aggregator, matrix)
119123

120124
@staticmethod
121125
def _assert_stationarity_property(
122-
aggregator: Aggregator, non_stationary_matrix: Tensor
126+
aggregator: Aggregator, matrix: Tensor
123127
) -> None:
124-
vector = aggregator(non_stationary_matrix)
128+
vector = aggregator(matrix)
125129
norm = vector.norm().item()
126130
assert norm > 1e-03

0 commit comments

Comments
 (0)