55
66from torchjd .aggregation import Aggregator
77
8- from ._inputs import (
9- matrices ,
10- scaled_matrices ,
11- strong_stationary_matrices ,
12- weak_stationary_matrices ,
13- zero_matrices ,
14- )
8+ from ._inputs import scaled_matrices , typical_matrices
159
1610
1711class ExpectedStructureProperty :
@@ -23,7 +17,7 @@ class ExpectedStructureProperty:
2317 """
2418
2519 @classmethod
26- @mark .parametrize ("matrix" , scaled_matrices + zero_matrices )
20+ @mark .parametrize ("matrix" , scaled_matrices + typical_matrices )
2721 def test_expected_structure_property (cls , aggregator : Aggregator , matrix : Tensor ):
2822 cls ._assert_expected_structure_property (aggregator , matrix )
2923
@@ -40,7 +34,7 @@ class NonConflictingProperty:
4034 """
4135
4236 @classmethod
43- @mark .parametrize ("matrix" , weak_stationary_matrices + matrices )
37+ @mark .parametrize ("matrix" , typical_matrices )
4438 def test_non_conflicting_property (cls , aggregator : Aggregator , matrix : Tensor ):
4539 cls ._assert_non_conflicting_property (aggregator , matrix )
4640
@@ -61,7 +55,7 @@ class PermutationInvarianceProperty:
6155 N_PERMUTATIONS = 5
6256
6357 @classmethod
64- @mark .parametrize ("matrix" , matrices )
58+ @mark .parametrize ("matrix" , typical_matrices )
6559 def test_permutation_invariance_property (cls , aggregator : Aggregator , matrix : Tensor ):
6660 cls ._assert_permutation_invariance_property (aggregator , matrix )
6761
@@ -88,7 +82,7 @@ class LinearUnderScalingProperty:
8882 """
8983
9084 @classmethod
91- @mark .parametrize ("matrix" , strong_stationary_matrices + matrices )
85+ @mark .parametrize ("matrix" , typical_matrices )
9286 def test_linear_under_scaling_property (cls , aggregator : Aggregator , matrix : Tensor ):
9387 cls ._assert_linear_under_scaling_property (aggregator , matrix )
9488
0 commit comments