Skip to content

Commit 90dc110

Browse files
committed
Uniformize inputs parametrizing the property testers
1 parent 7d192db commit 90dc110

1 file changed

Lines changed: 5 additions & 11 deletions

File tree

tests/unit/aggregation/_property_testers.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,7 @@
55

66
from torchjd.aggregation import Aggregator
77

8-
from ._inputs import (
9-
matrices,
10-
scaled_matrices,
11-
strong_stationary_matrices,
12-
weak_stationary_matrices,
13-
zero_matrices,
14-
)
8+
from ._inputs import scaled_matrices, typical_matrices
159

1610

1711
class ExpectedStructureProperty:
@@ -23,7 +17,7 @@ class ExpectedStructureProperty:
2317
"""
2418

2519
@classmethod
26-
@mark.parametrize("matrix", scaled_matrices + zero_matrices)
20+
@mark.parametrize("matrix", scaled_matrices + typical_matrices)
2721
def test_expected_structure_property(cls, aggregator: Aggregator, matrix: Tensor):
2822
cls._assert_expected_structure_property(aggregator, matrix)
2923

@@ -40,7 +34,7 @@ class NonConflictingProperty:
4034
"""
4135

4236
@classmethod
43-
@mark.parametrize("matrix", weak_stationary_matrices + matrices)
37+
@mark.parametrize("matrix", typical_matrices)
4438
def test_non_conflicting_property(cls, aggregator: Aggregator, matrix: Tensor):
4539
cls._assert_non_conflicting_property(aggregator, matrix)
4640

@@ -61,7 +55,7 @@ class PermutationInvarianceProperty:
6155
N_PERMUTATIONS = 5
6256

6357
@classmethod
64-
@mark.parametrize("matrix", matrices)
58+
@mark.parametrize("matrix", typical_matrices)
6559
def test_permutation_invariance_property(cls, aggregator: Aggregator, matrix: Tensor):
6660
cls._assert_permutation_invariance_property(aggregator, matrix)
6761

@@ -88,7 +82,7 @@ class LinearUnderScalingProperty:
8882
"""
8983

9084
@classmethod
91-
@mark.parametrize("matrix", strong_stationary_matrices + matrices)
85+
@mark.parametrize("matrix", typical_matrices)
9286
def test_linear_under_scaling_property(cls, aggregator: Aggregator, matrix: Tensor):
9387
cls._assert_linear_under_scaling_property(aggregator, matrix)
9488

0 commit comments

Comments
 (0)