File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 33
44from torchjd .aggregation import AlignedMTL
55
6- from ._property_testers import ExpectedStructureProperty , PermutationInvarianceProperty
6+ from ._property_testers import (
7+ ExpectedStructureProperty ,
8+ PermutationInvarianceProperty ,
9+ StrongStationarityProperty ,
10+ )
711
812
913@mark .parametrize ("aggregator" , [AlignedMTL ()])
10- class TestAlignedMTL (ExpectedStructureProperty , PermutationInvarianceProperty ):
14+ class TestAlignedMTL (
15+ ExpectedStructureProperty , PermutationInvarianceProperty , StrongStationarityProperty
16+ ):
1117 pass
1218
1319
Original file line number Diff line number Diff line change 88from torchjd .aggregation import CAGrad , Mean
99
1010from ._inputs import typical_matrices
11- from ._property_testers import ExpectedStructureProperty , NonConflictingProperty
11+ from ._property_testers import (
12+ ExpectedStructureProperty ,
13+ NonConflictingProperty ,
14+ StrongStationarityProperty ,
15+ )
1216
1317
1418@mark .parametrize ("aggregator" , [CAGrad (c = 0.5 )])
15- class TestCAGrad (ExpectedStructureProperty ):
19+ class TestCAGrad (ExpectedStructureProperty , StrongStationarityProperty ):
1620 pass
1721
1822
Original file line number Diff line number Diff line change 99 ExpectedStructureProperty ,
1010 NonConflictingProperty ,
1111 PermutationInvarianceProperty ,
12+ StrongStationarityProperty ,
1213)
1314
1415
1516@mark .parametrize ("aggregator" , [MGDA ()])
16- class TestMGDA (ExpectedStructureProperty , NonConflictingProperty , PermutationInvarianceProperty ):
17+ class TestMGDA (
18+ ExpectedStructureProperty ,
19+ NonConflictingProperty ,
20+ PermutationInvarianceProperty ,
21+ StrongStationarityProperty ,
22+ ):
1723 pass
1824
1925
You can’t perform that action at this time.
0 commit comments