Skip to content

Commit e7e1c29

Browse files
PierreQuintonValerianRey
authored andcommitted
Implements the stationarity properties
1 parent 480705b commit e7e1c29

File tree

4 files changed

+91
-3
lines changed

4 files changed

+91
-3
lines changed

tests/unit/aggregation/_inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
108108
strong_stationary_matrices = [
109109
_generate_strong_stationary_matrix(m, n) for m, n in _stationary_matrices_shapes
110110
]
111-
weak_stationary_matrices = strong_stationary_matrices + [
111+
weak_stationary_matrices = [
112112
_generate_weak_stationary_matrix(m, n) for m, n in _stationary_matrices_shapes
113113
]
114114
typical_matrices = zero_matrices + matrices + weak_stationary_matrices + strong_stationary_matrices

tests/unit/aggregation/_property_testers.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55

66
from torchjd.aggregation import Aggregator
77

8-
from ._inputs import scaled_matrices, typical_matrices
8+
from ._inputs import (
9+
matrices,
10+
scaled_matrices,
11+
strong_stationary_matrices,
12+
typical_matrices,
13+
weak_stationary_matrices,
14+
)
915

1016

1117
class ExpectedStructureProperty:
@@ -102,3 +108,77 @@ def _assert_linear_under_scaling_property(
102108
expected = alpha * x1 + beta * x2
103109

104110
assert_close(x, expected, atol=8e-03, rtol=0)
111+
112+
113+
class StationarityProperty:
114+
"""
115+
This class tests empirically that a given `Aggregator` satisfies the stationarity property.
116+
"""
117+
118+
@staticmethod
119+
def _assert_stationarity_property(
120+
aggregator: Aggregator,
121+
stationary_matrix: Tensor,
122+
) -> None:
123+
vector = aggregator(stationary_matrix)
124+
norm = vector.norm().item()
125+
assert norm < 8e-02
126+
127+
@staticmethod
128+
def _assert_non_stationarity_property(
129+
aggregator: Aggregator,
130+
non_stationary_matrix: Tensor,
131+
) -> None:
132+
vector = aggregator(non_stationary_matrix)
133+
norm = vector.norm().item()
134+
assert norm > 1e-03
135+
136+
137+
class StrongStationarityProperty(StationarityProperty):
138+
139+
@classmethod
140+
@mark.parametrize("stationary_matrix", strong_stationary_matrices)
141+
def test_stationarity_property(
142+
cls,
143+
aggregator: Aggregator,
144+
stationary_matrix: Tensor,
145+
):
146+
super(StrongStationarityProperty, cls)._assert_stationarity_property(
147+
aggregator, stationary_matrix
148+
)
149+
150+
@classmethod
151+
@mark.parametrize("non_stationary_matrix", weak_stationary_matrices + matrices)
152+
def test_non_stationarity_property(
153+
cls,
154+
aggregator: Aggregator,
155+
non_stationary_matrix: Tensor,
156+
):
157+
super(StrongStationarityProperty, cls)._assert_non_stationarity_property(
158+
aggregator, non_stationary_matrix
159+
)
160+
161+
162+
class WeakStationarityProperty(StationarityProperty):
163+
164+
@classmethod
165+
@mark.parametrize("stationary_matrix", strong_stationary_matrices + weak_stationary_matrices)
166+
def test_stationarity_property(
167+
cls,
168+
aggregator: Aggregator,
169+
stationary_matrix: Tensor,
170+
):
171+
super(WeakStationarityProperty, cls)._assert_stationarity_property(
172+
aggregator, stationary_matrix
173+
)
174+
175+
@classmethod
176+
@mark.parametrize("non_stationary_matrix", matrices)
177+
def test_non_stationarity_property(
178+
cls,
179+
aggregator: Aggregator,
180+
non_stationary_matrix: Tensor,
181+
):
182+
super(WeakStationarityProperty, cls)._assert_non_stationarity_property(
183+
aggregator, non_stationary_matrix
184+
)

tests/unit/aggregation/test_mgda.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,17 @@
99
ExpectedStructureProperty,
1010
NonConflictingProperty,
1111
PermutationInvarianceProperty,
12+
WeakStationarityProperty,
1213
)
1314

1415

1516
@mark.parametrize("aggregator", [MGDA()])
16-
class TestMGDA(ExpectedStructureProperty, NonConflictingProperty, PermutationInvarianceProperty):
17+
class TestMGDA(
18+
ExpectedStructureProperty,
19+
NonConflictingProperty,
20+
PermutationInvarianceProperty,
21+
WeakStationarityProperty,
22+
):
1723
pass
1824

1925

tests/unit/aggregation/test_upgrad.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
LinearUnderScalingProperty,
99
NonConflictingProperty,
1010
PermutationInvarianceProperty,
11+
StrongStationarityProperty,
1112
)
1213

1314

@@ -17,6 +18,7 @@ class TestUPGrad(
1718
NonConflictingProperty,
1819
PermutationInvarianceProperty,
1920
LinearUnderScalingProperty,
21+
StrongStationarityProperty,
2022
):
2123
pass
2224

0 commit comments

Comments
 (0)