Skip to content

Commit ea5a1f6

Browse files
committed
Add stationarity properties
1 parent 728413b commit ea5a1f6

File tree

1 file changed

+81
-1
lines changed

1 file changed

+81
-1
lines changed

tests/unit/aggregation/_property_testers.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55

66
from torchjd.aggregation import Aggregator
77

8-
from ._inputs import scaled_matrices, typical_matrices
8+
from ._inputs import (
9+
matrices,
10+
scaled_matrices,
11+
strong_stationary_matrices,
12+
typical_matrices,
13+
weak_stationary_matrices,
14+
)
915

1016

1117
class ExpectedStructureProperty:
@@ -102,3 +108,77 @@ def _assert_linear_under_scaling_property(
102108
expected = alpha * x1 + beta * x2
103109

104110
assert_close(x, expected, atol=8e-03, rtol=0)
111+
112+
113+
class StationarityProperty:
114+
"""
115+
This class tests empirically that a given `Aggregator` satisfies the stationarity property.
116+
"""
117+
118+
@staticmethod
119+
def _assert_stationarity_property(
120+
aggregator: Aggregator,
121+
stationary_matrix: Tensor,
122+
) -> None:
123+
vector = aggregator(stationary_matrix)
124+
norm = vector.norm().item()
125+
assert norm < 8e-02
126+
127+
@staticmethod
128+
def _assert_non_stationarity_property(
129+
aggregator: Aggregator,
130+
non_stationary_matrix: Tensor,
131+
) -> None:
132+
vector = aggregator(non_stationary_matrix)
133+
norm = vector.norm().item()
134+
assert norm > 1e-03
135+
136+
137+
class StrongStationarityProperty(StationarityProperty):
138+
139+
@classmethod
140+
@mark.parametrize("stationary_matrix", strong_stationary_matrices)
141+
def test_stationarity_property(
142+
cls,
143+
aggregator: Aggregator,
144+
stationary_matrix: Tensor,
145+
):
146+
super(StrongStationarityProperty, cls)._assert_stationarity_property(
147+
aggregator, stationary_matrix
148+
)
149+
150+
@classmethod
151+
@mark.parametrize("non_stationary_matrix", weak_stationary_matrices + matrices)
152+
def test_non_stationarity_property(
153+
cls,
154+
aggregator: Aggregator,
155+
non_stationary_matrix: Tensor,
156+
):
157+
super(StrongStationarityProperty, cls)._assert_non_stationarity_property(
158+
aggregator, non_stationary_matrix
159+
)
160+
161+
162+
class WeakStationarityProperty(StationarityProperty):
163+
164+
@classmethod
165+
@mark.parametrize("stationary_matrix", strong_stationary_matrices + weak_stationary_matrices)
166+
def test_stationarity_property(
167+
cls,
168+
aggregator: Aggregator,
169+
stationary_matrix: Tensor,
170+
):
171+
super(WeakStationarityProperty, cls)._assert_stationarity_property(
172+
aggregator, stationary_matrix
173+
)
174+
175+
@classmethod
176+
@mark.parametrize("non_stationary_matrix", matrices)
177+
def test_non_stationarity_property(
178+
cls,
179+
aggregator: Aggregator,
180+
non_stationary_matrix: Tensor,
181+
):
182+
super(WeakStationarityProperty, cls)._assert_non_stationarity_property(
183+
aggregator, non_stationary_matrix
184+
)

0 commit comments

Comments
 (0)