Skip to content

Commit c9bfd09

Browse files
committed
Fix parametrization for aggregators overriding matrix parametrization
1 parent 98897dd commit c9bfd09

5 files changed

Lines changed: 10 additions & 11 deletions

File tree

tests/unit/aggregation/_inputs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,5 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
113113
]
114114
typical_matrices = zero_matrices + matrices + weak_stationary_matrices + strong_stationary_matrices
115115

116-
matrices_2_plus_rows = [matrix for matrix in matrices + zero_matrices if matrix.shape[0] >= 2]
117116
scaled_matrices_2_plus_rows = [matrix for matrix in scaled_matrices if matrix.shape[0] >= 2]
118117
typical_matrices_2_plus_rows = [matrix for matrix in typical_matrices if matrix.shape[0] >= 2]

tests/unit/aggregation/test_cagrad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from torchjd.aggregation import CAGrad, Mean
66

7-
from ._inputs import matrices, strong_stationary_matrices
7+
from ._inputs import typical_matrices
88
from ._property_testers import ExpectedStructureProperty, NonConflictingProperty
99

1010

@@ -20,7 +20,7 @@ class TestCAGradNonConflicting(NonConflictingProperty):
2020
pass
2121

2222

23-
@mark.parametrize("matrix", strong_stationary_matrices + matrices)
23+
@mark.parametrize("matrix", typical_matrices)
2424
def test_equivalence_mean(matrix: Tensor):
2525
"""Tests that CAGrad is equivalent to Mean when c=0."""
2626

tests/unit/aggregation/test_constant.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from torchjd.aggregation import Constant
66

7-
from ._inputs import matrices, scaled_matrices, strong_stationary_matrices, zero_matrices
7+
from ._inputs import scaled_matrices, typical_matrices
88
from ._property_testers import ExpectedStructureProperty, LinearUnderScalingProperty
99

1010
# The weights must be a vector of length equal to the number of rows in the matrix that it will be
@@ -19,10 +19,10 @@ def _make_aggregator(matrix: Tensor) -> Constant:
1919
return Constant(weights)
2020

2121

22-
_matrices_1 = scaled_matrices + zero_matrices
22+
_matrices_1 = scaled_matrices + typical_matrices
2323
_aggregators_1 = [_make_aggregator(matrix) for matrix in _matrices_1]
2424

25-
_matrices_2 = matrices + strong_stationary_matrices
25+
_matrices_2 = typical_matrices
2626
_aggregators_2 = [_make_aggregator(matrix) for matrix in _matrices_2]
2727

2828

tests/unit/aggregation/test_krum.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from torchjd.aggregation import Krum
55

6-
from ._inputs import scaled_matrices_2_plus_rows
6+
from ._inputs import scaled_matrices_2_plus_rows, typical_matrices_2_plus_rows
77
from ._property_testers import ExpectedStructureProperty
88

99

@@ -12,7 +12,7 @@ class TestKrum(ExpectedStructureProperty):
1212
# Override the parametrization of some property-testing methods because Krum only works on
1313
# matrices with >= 2 rows.
1414
@classmethod
15-
@mark.parametrize("matrix", scaled_matrices_2_plus_rows)
15+
@mark.parametrize("matrix", scaled_matrices_2_plus_rows + typical_matrices_2_plus_rows)
1616
def test_expected_structure_property(cls, aggregator: Krum, matrix: Tensor):
1717
cls._assert_expected_structure_property(aggregator, matrix)
1818

tests/unit/aggregation/test_trimmed_mean.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from torchjd.aggregation import Aggregator, TrimmedMean
55

6-
from ._inputs import matrices_2_plus_rows, scaled_matrices_2_plus_rows
6+
from ._inputs import scaled_matrices_2_plus_rows, typical_matrices_2_plus_rows
77
from ._property_testers import ExpectedStructureProperty, PermutationInvarianceProperty
88

99

@@ -12,12 +12,12 @@ class TestTrimmedMean(ExpectedStructureProperty, PermutationInvarianceProperty):
1212
# Override the parametrization of some property-testing methods because `TrimmedMean` with
1313
# `trim_number=1` only works on matrices with >= 2 rows.
1414
@classmethod
15-
@mark.parametrize("matrix", scaled_matrices_2_plus_rows)
15+
@mark.parametrize("matrix", scaled_matrices_2_plus_rows + typical_matrices_2_plus_rows)
1616
def test_expected_structure_property(cls, aggregator: TrimmedMean, matrix: Tensor):
1717
cls._assert_expected_structure_property(aggregator, matrix)
1818

1919
@classmethod
20-
@mark.parametrize("matrix", matrices_2_plus_rows)
20+
@mark.parametrize("matrix", typical_matrices_2_plus_rows)
2121
def test_permutation_invariance_property(cls, aggregator: Aggregator, matrix: Tensor):
2222
cls._assert_permutation_invariance_property(aggregator, matrix)
2323

0 commit comments

Comments
 (0)