Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions tests/unit/aggregation/_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,13 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
matrices = [_generate_matrix(m, n, rank) for m, n, rank in _matrix_dimension_triples]
scaled_matrices = [scale * matrix for scale in _scales for matrix in matrices]
zero_matrices = [torch.zeros([m, n]) for m, n in _zero_matrices_shapes]
matrices_2_plus_rows = [matrix for matrix in matrices + zero_matrices if matrix.shape[0] >= 2]
scaled_matrices_2_plus_rows = [
matrix for matrix in scaled_matrices + zero_matrices if matrix.shape[0] >= 2
]
strong_stationary_matrices = [
_generate_strong_stationary_matrix(m, n) for m, n in _stationary_matrices_shapes
]
weak_stationary_matrices = strong_stationary_matrices + [
_generate_weak_stationary_matrix(m, n) for m, n in _stationary_matrices_shapes
]
typical_matrices = zero_matrices + matrices + weak_stationary_matrices + strong_stationary_matrices

scaled_matrices_2_plus_rows = [matrix for matrix in scaled_matrices if matrix.shape[0] >= 2]
typical_matrices_2_plus_rows = [matrix for matrix in typical_matrices if matrix.shape[0] >= 2]
4 changes: 2 additions & 2 deletions tests/unit/aggregation/test_cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from torchjd.aggregation import CAGrad, Mean

from ._inputs import matrices, strong_stationary_matrices
from ._inputs import typical_matrices
from ._property_testers import ExpectedStructureProperty, NonConflictingProperty


Expand All @@ -20,7 +20,7 @@ class TestCAGradNonConflicting(NonConflictingProperty):
pass


@mark.parametrize("matrix", strong_stationary_matrices + matrices)
@mark.parametrize("matrix", typical_matrices)
def test_equivalence_mean(matrix: Tensor):
"""Tests that CAGrad is equivalent to Mean when c=0."""

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/aggregation/test_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from torchjd.aggregation import Constant

from ._inputs import matrices, scaled_matrices, strong_stationary_matrices, zero_matrices
from ._inputs import scaled_matrices, typical_matrices
from ._property_testers import ExpectedStructureProperty, LinearUnderScalingProperty

# The weights must be a vector of length equal to the number of rows in the matrix that it will be
Expand All @@ -19,10 +19,10 @@ def _make_aggregator(matrix: Tensor) -> Constant:
return Constant(weights)


_matrices_1 = scaled_matrices + zero_matrices
_matrices_1 = scaled_matrices + typical_matrices
_aggregators_1 = [_make_aggregator(matrix) for matrix in _matrices_1]

_matrices_2 = matrices + strong_stationary_matrices
_matrices_2 = typical_matrices
_aggregators_2 = [_make_aggregator(matrix) for matrix in _matrices_2]


Expand Down
4 changes: 2 additions & 2 deletions tests/unit/aggregation/test_krum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from torchjd.aggregation import Krum

from ._inputs import scaled_matrices_2_plus_rows
from ._inputs import scaled_matrices_2_plus_rows, typical_matrices_2_plus_rows
from ._property_testers import ExpectedStructureProperty


Expand All @@ -12,7 +12,7 @@ class TestKrum(ExpectedStructureProperty):
# Override the parametrization of some property-testing methods because Krum only works on
# matrices with >= 2 rows.
@classmethod
@mark.parametrize("matrix", scaled_matrices_2_plus_rows)
@mark.parametrize("matrix", scaled_matrices_2_plus_rows + typical_matrices_2_plus_rows)
def test_expected_structure_property(cls, aggregator: Krum, matrix: Tensor):
cls._assert_expected_structure_property(aggregator, matrix)

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/aggregation/test_trimmed_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from torchjd.aggregation import Aggregator, TrimmedMean

from ._inputs import matrices_2_plus_rows, scaled_matrices_2_plus_rows
from ._inputs import scaled_matrices_2_plus_rows, typical_matrices_2_plus_rows
from ._property_testers import ExpectedStructureProperty, PermutationInvarianceProperty


Expand All @@ -12,12 +12,12 @@ class TestTrimmedMean(ExpectedStructureProperty, PermutationInvarianceProperty):
# Override the parametrization of some property-testing methods because `TrimmedMean` with
# `trim_number=1` only works on matrices with >= 2 rows.
@classmethod
@mark.parametrize("matrix", scaled_matrices_2_plus_rows)
@mark.parametrize("matrix", scaled_matrices_2_plus_rows + typical_matrices_2_plus_rows)
def test_expected_structure_property(cls, aggregator: TrimmedMean, matrix: Tensor):
cls._assert_expected_structure_property(aggregator, matrix)

@classmethod
@mark.parametrize("matrix", matrices_2_plus_rows)
@mark.parametrize("matrix", typical_matrices_2_plus_rows)
def test_permutation_invariance_property(cls, aggregator: Aggregator, matrix: Tensor):
cls._assert_permutation_invariance_property(aggregator, matrix)

Expand Down