Skip to content

Commit 8914b16

Browse files
test(aggregation): Remove test_equivalence_mean in CAGrad (#278)
1 parent f227121 commit 8914b16

File tree

1 file changed

+1
-17
lines changed

1 file changed

+1
-17
lines changed

tests/unit/aggregation/test_cagrad.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
from contextlib import nullcontext as does_not_raise
22

33
from pytest import mark, raises
4-
from torch import Tensor
5-
from torch.testing import assert_close
64
from unit._utils import ExceptionContext
75

8-
from torchjd.aggregation import CAGrad, Mean
6+
from torchjd.aggregation import CAGrad
97

10-
from ._inputs import typical_matrices
118
from ._property_testers import ExpectedStructureProperty, NonConflictingProperty
129

1310

@@ -23,19 +20,6 @@ class TestCAGradNonConflicting(NonConflictingProperty):
2320
pass
2421

2522

26-
@mark.parametrize("matrix", typical_matrices)
27-
def test_equivalence_mean(matrix: Tensor):
28-
"""Tests that CAGrad is equivalent to Mean when c=0."""
29-
30-
ca_grad = CAGrad(c=0.0)
31-
mean = Mean()
32-
33-
result = ca_grad(matrix)
34-
expected = mean(matrix)
35-
36-
assert_close(result, expected, atol=2e-1, rtol=0)
37-
38-
3923
@mark.parametrize(
4024
["c", "expectation"],
4125
[

0 commit comments

Comments
 (0)