Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/unit/aggregation/_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
(9, 11),
]

_scales = [0.0, 1e-10, 1.0, 1e3, 1e5, 1e10, 1e15]
_scales = [0.0, 1e-10, 1e3, 1e5, 1e10, 1e15]

# Fix seed to fix randomness of matrix generation
torch.manual_seed(0)
Expand All @@ -115,3 +115,4 @@ def _generate_semi_orthonormal_complement(Q: Tensor) -> Tensor:
weak_stationary_matrices = strong_stationary_matrices + [
_generate_weak_stationary_matrix(m, n) for m, n in _stationary_matrices_shapes
]
typical_matrices = zero_matrices + matrices + weak_stationary_matrices + strong_stationary_matrices
16 changes: 5 additions & 11 deletions tests/unit/aggregation/_property_testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,7 @@

from torchjd.aggregation import Aggregator

from ._inputs import (
matrices,
scaled_matrices,
strong_stationary_matrices,
weak_stationary_matrices,
zero_matrices,
)
from ._inputs import scaled_matrices, typical_matrices


class ExpectedStructureProperty:
Expand All @@ -23,7 +17,7 @@ class ExpectedStructureProperty:
"""

@classmethod
@mark.parametrize("matrix", scaled_matrices + zero_matrices)
@mark.parametrize("matrix", scaled_matrices + typical_matrices)
def test_expected_structure_property(cls, aggregator: Aggregator, matrix: Tensor):
cls._assert_expected_structure_property(aggregator, matrix)

Expand All @@ -40,7 +34,7 @@ class NonConflictingProperty:
"""

@classmethod
@mark.parametrize("matrix", weak_stationary_matrices + matrices)
@mark.parametrize("matrix", typical_matrices)
def test_non_conflicting_property(cls, aggregator: Aggregator, matrix: Tensor):
cls._assert_non_conflicting_property(aggregator, matrix)

Expand All @@ -61,7 +55,7 @@ class PermutationInvarianceProperty:
N_PERMUTATIONS = 5

@classmethod
@mark.parametrize("matrix", matrices)
@mark.parametrize("matrix", typical_matrices)
def test_permutation_invariance_property(cls, aggregator: Aggregator, matrix: Tensor):
cls._assert_permutation_invariance_property(aggregator, matrix)

Expand All @@ -88,7 +82,7 @@ class LinearUnderScalingProperty:
"""

@classmethod
@mark.parametrize("matrix", strong_stationary_matrices + matrices)
@mark.parametrize("matrix", typical_matrices)
def test_linear_under_scaling_property(cls, aggregator: Aggregator, matrix: Tensor):
cls._assert_linear_under_scaling_property(aggregator, matrix)

Expand Down
5 changes: 3 additions & 2 deletions tests/unit/aggregation/test_imtl_g.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

from torchjd.aggregation import IMTLG

from ._property_testers import ExpectedStructureProperty, PermutationInvarianceProperty
from ._property_testers import ExpectedStructureProperty


# For some reason, a permutation-invariance property test fails on GPU
@mark.parametrize("aggregator", [IMTLG()])
class TestIMTLG(ExpectedStructureProperty, PermutationInvarianceProperty):
class TestIMTLG(ExpectedStructureProperty):
pass


Expand Down