Skip to content

Commit ee27d3d

Browse files
committed
Add failing SSProperty to AMTL CAGRAD and MGDA
1 parent e328378 commit ee27d3d

3 files changed

Lines changed: 21 additions & 5 deletions

File tree

tests/unit/aggregation/test_aligned_mtl.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,17 @@
33

44
from torchjd.aggregation import AlignedMTL
55

6-
from ._property_testers import ExpectedStructureProperty, PermutationInvarianceProperty
6+
from ._property_testers import (
7+
ExpectedStructureProperty,
8+
PermutationInvarianceProperty,
9+
StrongStationarityProperty,
10+
)
711

812

913
@mark.parametrize("aggregator", [AlignedMTL()])
10-
class TestAlignedMTL(ExpectedStructureProperty, PermutationInvarianceProperty):
14+
class TestAlignedMTL(
15+
ExpectedStructureProperty, PermutationInvarianceProperty, StrongStationarityProperty
16+
):
1117
pass
1218

1319

tests/unit/aggregation/test_cagrad.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88
from torchjd.aggregation import CAGrad, Mean
99

1010
from ._inputs import typical_matrices
11-
from ._property_testers import ExpectedStructureProperty, NonConflictingProperty
11+
from ._property_testers import (
12+
ExpectedStructureProperty,
13+
NonConflictingProperty,
14+
StrongStationarityProperty,
15+
)
1216

1317

1418
@mark.parametrize("aggregator", [CAGrad(c=0.5)])
15-
class TestCAGrad(ExpectedStructureProperty):
19+
class TestCAGrad(ExpectedStructureProperty, StrongStationarityProperty):
1620
pass
1721

1822

tests/unit/aggregation/test_mgda.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,17 @@
99
ExpectedStructureProperty,
1010
NonConflictingProperty,
1111
PermutationInvarianceProperty,
12+
StrongStationarityProperty,
1213
)
1314

1415

1516
@mark.parametrize("aggregator", [MGDA()])
16-
class TestMGDA(ExpectedStructureProperty, NonConflictingProperty, PermutationInvarianceProperty):
17+
class TestMGDA(
18+
ExpectedStructureProperty,
19+
NonConflictingProperty,
20+
PermutationInvarianceProperty,
21+
StrongStationarityProperty,
22+
):
1723
pass
1824

1925

0 commit comments

Comments
 (0)