Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 1 addition & 17 deletions tests/unit/aggregation/test_cagrad.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from contextlib import nullcontext as does_not_raise

from pytest import mark, raises
from torch import Tensor
from torch.testing import assert_close
from unit._utils import ExceptionContext

from torchjd.aggregation import CAGrad, Mean
from torchjd.aggregation import CAGrad

from ._inputs import typical_matrices
from ._property_testers import ExpectedStructureProperty, NonConflictingProperty


Expand All @@ -23,19 +20,6 @@ class TestCAGradNonConflicting(NonConflictingProperty):
pass


@mark.parametrize("matrix", typical_matrices)
def test_equivalence_mean(matrix: Tensor):
"""Tests that CAGrad is equivalent to Mean when c=0."""

ca_grad = CAGrad(c=0.0)
mean = Mean()

result = ca_grad(matrix)
expected = mean(matrix)

assert_close(result, expected, atol=2e-1, rtol=0)


@mark.parametrize(
["c", "expectation"],
[
Expand Down