Skip to content

Commit bd2476d

Browse files
committed
Add strong stationarity property to constant, mean, random and sum
1 parent 6e84fec commit bd2476d

5 files changed

Lines changed: 33 additions & 15 deletions

File tree

tests/unit/aggregation/_property_testers.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,7 @@
55

66
from torchjd.aggregation import Aggregator
77

8-
from ._inputs import (
9-
non_strong_stationary_matrices,
10-
scaled_matrices,
11-
typical_matrices,
12-
)
8+
from ._inputs import non_strong_stationary_matrices, scaled_matrices, typical_matrices
139

1410

1511
class ExpectedStructureProperty:
@@ -122,9 +118,7 @@ def test_stationarity_property(cls, aggregator: Aggregator, matrix: Tensor):
122118
cls._assert_stationarity_property(aggregator, matrix)
123119

124120
@staticmethod
125-
def _assert_stationarity_property(
126-
aggregator: Aggregator, matrix: Tensor
127-
) -> None:
121+
def _assert_stationarity_property(aggregator: Aggregator, matrix: Tensor) -> None:
128122
vector = aggregator(matrix)
129123
norm = vector.norm().item()
130124
assert norm > 1e-03

tests/unit/aggregation/test_constant.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@
77

88
from torchjd.aggregation import Constant
99

10-
from ._inputs import scaled_matrices, typical_matrices
11-
from ._property_testers import ExpectedStructureProperty, LinearUnderScalingProperty
10+
from ._inputs import non_strong_stationary_matrices, scaled_matrices, typical_matrices
11+
from ._property_testers import (
12+
ExpectedStructureProperty,
13+
LinearUnderScalingProperty,
14+
StrongStationarityProperty,
15+
)
1216

1317
# The weights must be a vector of length equal to the number of rows in the matrix that it will be
1418
# applied to. Thus, each `Constant` instance is specific to matrices of a given number of rows. To
@@ -28,8 +32,13 @@ def _make_aggregator(matrix: Tensor) -> Constant:
2832
_matrices_2 = typical_matrices
2933
_aggregators_2 = [_make_aggregator(matrix) for matrix in _matrices_2]
3034

35+
_matrices_3 = non_strong_stationary_matrices
36+
_aggregators_3 = [_make_aggregator(matrix) for matrix in _matrices_3]
37+
3138

32-
class TestConstant(ExpectedStructureProperty, LinearUnderScalingProperty):
39+
class TestConstant(
40+
ExpectedStructureProperty, LinearUnderScalingProperty, StrongStationarityProperty
41+
):
3342
# Override the parametrization of `test_expected_structure_property` to make the test use the
3443
# right aggregator with each matrix.
3544

@@ -43,6 +52,11 @@ def test_expected_structure_property(cls, aggregator: Constant, matrix: Tensor):
4352
def test_linear_under_scaling_property(cls, aggregator: Constant, matrix: Tensor):
4453
cls._assert_linear_under_scaling_property(aggregator, matrix)
4554

55+
@classmethod
56+
@mark.parametrize(["aggregator", "matrix"], zip(_aggregators_3, _matrices_3))
57+
def test_stationarity_property(cls, aggregator: Constant, non_stationary_matrix: Tensor):
58+
cls._assert_stationarity_property(aggregator, non_stationary_matrix)
59+
4660

4761
@mark.parametrize(
4862
["weights_shape", "expectation"],

tests/unit/aggregation/test_mean.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66
ExpectedStructureProperty,
77
LinearUnderScalingProperty,
88
PermutationInvarianceProperty,
9+
StrongStationarityProperty,
910
)
1011

1112

1213
@mark.parametrize("aggregator", [Mean()])
1314
class TestMean(
14-
ExpectedStructureProperty, PermutationInvarianceProperty, LinearUnderScalingProperty
15+
ExpectedStructureProperty,
16+
PermutationInvarianceProperty,
17+
LinearUnderScalingProperty,
18+
StrongStationarityProperty,
1519
):
1620
pass
1721

tests/unit/aggregation/test_random.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
from torchjd.aggregation import Random
44

5-
from ._property_testers import ExpectedStructureProperty
5+
from ._property_testers import ExpectedStructureProperty, StrongStationarityProperty
66

77

88
@mark.parametrize("aggregator", [Random()])
9-
class TestRandom(ExpectedStructureProperty):
9+
class TestRandom(ExpectedStructureProperty, StrongStationarityProperty):
1010
pass
1111

1212

tests/unit/aggregation/test_sum.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,17 @@
66
ExpectedStructureProperty,
77
LinearUnderScalingProperty,
88
PermutationInvarianceProperty,
9+
StrongStationarityProperty,
910
)
1011

1112

1213
@mark.parametrize("aggregator", [Sum()])
13-
class TestSum(ExpectedStructureProperty, PermutationInvarianceProperty, LinearUnderScalingProperty):
14+
class TestSum(
15+
ExpectedStructureProperty,
16+
PermutationInvarianceProperty,
17+
LinearUnderScalingProperty,
18+
StrongStationarityProperty,
19+
):
1420
pass
1521

1622

0 commit comments

Comments
 (0)