diff --git a/tests/unit/aggregation/test_cagrad.py b/tests/unit/aggregation/test_cagrad.py index eee7c1797..26c8b29f5 100644 --- a/tests/unit/aggregation/test_cagrad.py +++ b/tests/unit/aggregation/test_cagrad.py @@ -1,13 +1,10 @@ from contextlib import nullcontext as does_not_raise from pytest import mark, raises -from torch import Tensor -from torch.testing import assert_close from unit._utils import ExceptionContext -from torchjd.aggregation import CAGrad, Mean +from torchjd.aggregation import CAGrad -from ._inputs import typical_matrices from ._property_testers import ExpectedStructureProperty, NonConflictingProperty @@ -23,19 +20,6 @@ class TestCAGradNonConflicting(NonConflictingProperty): pass -@mark.parametrize("matrix", typical_matrices) -def test_equivalence_mean(matrix: Tensor): - """Tests that CAGrad is equivalent to Mean when c=0.""" - - ca_grad = CAGrad(c=0.0) - mean = Mean() - - result = ca_grad(matrix) - expected = mean(matrix) - - assert_close(result, expected, atol=2e-1, rtol=0) - - @mark.parametrize( ["c", "expectation"], [