File tree Expand file tree Collapse file tree 1 file changed +1
-17
lines changed
Expand file tree Collapse file tree 1 file changed +1
-17
lines changed Original file line number Diff line number Diff line change 11from contextlib import nullcontext as does_not_raise
22
33from pytest import mark , raises
4- from torch import Tensor
5- from torch .testing import assert_close
64from unit ._utils import ExceptionContext
75
8- from torchjd .aggregation import CAGrad , Mean
6+ from torchjd .aggregation import CAGrad
97
10- from ._inputs import typical_matrices
118from ._property_testers import ExpectedStructureProperty , NonConflictingProperty
129
1310
@@ -23,19 +20,6 @@ class TestCAGradNonConflicting(NonConflictingProperty):
2320 pass
2421
2522
26- @mark .parametrize ("matrix" , typical_matrices )
27- def test_equivalence_mean (matrix : Tensor ):
28- """Tests that CAGrad is equivalent to Mean when c=0."""
29-
30- ca_grad = CAGrad (c = 0.0 )
31- mean = Mean ()
32-
33- result = ca_grad (matrix )
34- expected = mean (matrix )
35-
36- assert_close (result , expected , atol = 2e-1 , rtol = 0 )
37-
38-
3923@mark .parametrize (
4024 ["c" , "expectation" ],
4125 [
You can’t perform that action at this time.
0 commit comments