Skip to content

Commit aba36d8

Browse files
committed
Implements the stationarity properties
1 parent ec6caa7 commit aba36d8

File tree

4 files changed

+97
-5
lines changed

4 files changed

+97
-5
lines changed

tests/unit/aggregation/_inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _generate_weak_stationary_matrix(n_rows: int, n_cols: int) -> Tensor:
114114
_generate_strong_stationary_matrix(n_rows, n_cols)
115115
for n_rows, n_cols in _stationary_matrices_shapes
116116
]
117-
weak_stationary_matrices = strong_stationary_matrices + [
117+
weak_stationary_matrices = [
118118
_generate_weak_stationary_matrix(n_rows, n_cols)
119119
for n_rows, n_cols in _stationary_matrices_shapes
120120
]

tests/unit/aggregation/_property_testers.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55

66
from torchjd.aggregation import Aggregator
77

8-
from ._inputs import matrices, scaled_matrices, weak_stationary_matrices, zero_matrices
8+
from ._inputs import (
9+
matrices,
10+
scaled_matrices,
11+
strong_stationary_matrices,
12+
weak_stationary_matrices,
13+
zero_matrices,
14+
)
915

1016

1117
class ExpectedStructureProperty:
@@ -34,7 +40,7 @@ class NonConflictingProperty:
3440
"""
3541

3642
@classmethod
37-
@mark.parametrize("matrix", weak_stationary_matrices + matrices)
43+
@mark.parametrize("matrix", strong_stationary_matrices + weak_stationary_matrices + matrices)
3844
def test_non_conflicting_property(
3945
cls,
4046
aggregator: Aggregator,
@@ -80,3 +86,77 @@ def _assert_permutation_invariance_property(aggregator: Aggregator, matrix: Tens
8086
def _permute_randomly(matrix: Tensor) -> Tensor:
8187
row_permutation = torch.randperm(matrix.size(dim=0))
8288
return matrix[row_permutation]
89+
90+
91+
class StationarityProperty:
92+
"""
93+
This class tests empirically that a given `Aggregator` satisfies the stationarity property.
94+
"""
95+
96+
@staticmethod
97+
def _assert_stationarity_property(
98+
aggregator: Aggregator,
99+
stationary_matrix: Tensor,
100+
) -> None:
101+
vector = aggregator(stationary_matrix)
102+
norm = vector.norm().item()
103+
assert norm < 8e-02
104+
105+
@staticmethod
106+
def _assert_non_stationarity_property(
107+
aggregator: Aggregator,
108+
non_stationary_matrix: Tensor,
109+
) -> None:
110+
vector = aggregator(non_stationary_matrix)
111+
norm = vector.norm().item()
112+
assert norm > 1e-03
113+
114+
115+
class StrongStationarityProperty(StationarityProperty):
116+
117+
@classmethod
118+
@mark.parametrize("stationary_matrix", strong_stationary_matrices)
119+
def test_stationarity_property(
120+
cls,
121+
aggregator: Aggregator,
122+
stationary_matrix: Tensor,
123+
):
124+
super(StrongStationarityProperty, cls)._assert_stationarity_property(
125+
aggregator, stationary_matrix
126+
)
127+
128+
@classmethod
129+
@mark.parametrize("non_stationary_matrix", weak_stationary_matrices + matrices)
130+
def test_non_stationarity_property(
131+
cls,
132+
aggregator: Aggregator,
133+
non_stationary_matrix: Tensor,
134+
):
135+
super(StrongStationarityProperty, cls)._assert_non_stationarity_property(
136+
aggregator, non_stationary_matrix
137+
)
138+
139+
140+
class WeakStationarityProperty(StationarityProperty):
141+
142+
@classmethod
143+
@mark.parametrize("stationary_matrix", strong_stationary_matrices + weak_stationary_matrices)
144+
def test_stationarity_property(
145+
cls,
146+
aggregator: Aggregator,
147+
stationary_matrix: Tensor,
148+
):
149+
super(WeakStationarityProperty, cls)._assert_stationarity_property(
150+
aggregator, stationary_matrix
151+
)
152+
153+
@classmethod
154+
@mark.parametrize("non_stationary_matrix", matrices)
155+
def test_non_stationarity_property(
156+
cls,
157+
aggregator: Aggregator,
158+
non_stationary_matrix: Tensor,
159+
):
160+
super(WeakStationarityProperty, cls)._assert_non_stationarity_property(
161+
aggregator, non_stationary_matrix
162+
)

tests/unit/aggregation/test_mgda.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,17 @@
99
ExpectedStructureProperty,
1010
NonConflictingProperty,
1111
PermutationInvarianceProperty,
12+
WeakStationarityProperty,
1213
)
1314

1415

1516
@mark.parametrize("aggregator", [MGDA()])
16-
class TestMGDA(ExpectedStructureProperty, NonConflictingProperty, PermutationInvarianceProperty):
17+
class TestMGDA(
18+
ExpectedStructureProperty,
19+
NonConflictingProperty,
20+
PermutationInvarianceProperty,
21+
WeakStationarityProperty,
22+
):
1723
pass
1824

1925

tests/unit/aggregation/test_upgrad.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,17 @@
77
ExpectedStructureProperty,
88
NonConflictingProperty,
99
PermutationInvarianceProperty,
10+
StrongStationarityProperty,
1011
)
1112

1213

1314
@mark.parametrize("aggregator", [UPGrad()])
14-
class TestUPGrad(ExpectedStructureProperty, NonConflictingProperty, PermutationInvarianceProperty):
15+
class TestUPGrad(
16+
ExpectedStructureProperty,
17+
NonConflictingProperty,
18+
PermutationInvarianceProperty,
19+
StrongStationarityProperty,
20+
):
1521
pass
1622

1723

0 commit comments

Comments
 (0)