Skip to content

Commit 13345f4

Browse files
committed
Remove support for random aggregators in LinearUnderScalingProperty
- Remove seed resetting - Remove LinearUnderScalingProperty from TestPCGrad and TestRGW
1 parent e651972 commit 13345f4

3 files changed

Lines changed: 4 additions & 10 deletions

File tree

tests/unit/aggregation/_property_testers.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,8 @@ def _assert_linear_under_scaling_property(
102102
alpha = torch.rand([])
103103
beta = torch.rand([])
104104

105-
seed = torch.seed()
106-
107105
x1 = aggregator(torch.diag(c1) @ matrix)
108-
109-
torch.manual_seed(seed)
110106
x2 = aggregator(torch.diag(c2) @ matrix)
111-
112-
torch.manual_seed(seed)
113107
x = aggregator(torch.diag(alpha * c1 + beta * c2) @ matrix)
114108
expected = alpha * x1 + beta * x2
115109

tests/unit/aggregation/test_pcgrad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
from torchjd.aggregation.sum import _SumWeighting
88
from torchjd.aggregation.upgrad import _UPGradWrapper
99

10-
from ._property_testers import ExpectedStructureProperty, LinearUnderScalingProperty
10+
from ._property_testers import ExpectedStructureProperty
1111

1212

1313
@mark.parametrize("aggregator", [PCGrad()])
14-
class TestPCGrad(ExpectedStructureProperty, LinearUnderScalingProperty):
14+
class TestPCGrad(ExpectedStructureProperty):
1515
pass
1616

1717

tests/unit/aggregation/test_random.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
from torchjd.aggregation import Random
44

5-
from ._property_testers import ExpectedStructureProperty, LinearUnderScalingProperty
5+
from ._property_testers import ExpectedStructureProperty
66

77

88
@mark.parametrize("aggregator", [Random()])
9-
class TestRGW(ExpectedStructureProperty, LinearUnderScalingProperty):
9+
class TestRGW(ExpectedStructureProperty):
1010
pass
1111

1212

0 commit comments

Comments
 (0)