Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 38 additions & 10 deletions tests/unit/aggregation/_property_testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@

from torchjd.aggregation import Aggregator

from ._inputs import matrices, scaled_matrices, weak_stationary_matrices, zero_matrices
from ._inputs import (
matrices,
scaled_matrices,
strong_stationary_matrices,
weak_stationary_matrices,
zero_matrices,
)


class ExpectedStructureProperty:
Expand Down Expand Up @@ -35,18 +41,11 @@ class NonConflictingProperty:

@classmethod
@mark.parametrize("matrix", weak_stationary_matrices + matrices)
def test_non_conflicting_property(
cls,
aggregator: Aggregator,
matrix: Tensor,
):
def test_non_conflicting_property(cls, aggregator: Aggregator, matrix: Tensor):
cls._assert_non_conflicting_property(aggregator, matrix)

@staticmethod
def _assert_non_conflicting_property(
aggregator: Aggregator,
matrix: Tensor,
) -> None:
def _assert_non_conflicting_property(aggregator: Aggregator, matrix: Tensor) -> None:
vector = aggregator(matrix)
output_direction = matrix @ vector
positive_directions = output_direction[output_direction >= 0]
Expand Down Expand Up @@ -80,3 +79,32 @@ def _assert_permutation_invariance_property(aggregator: Aggregator, matrix: Tens
def _permute_randomly(matrix: Tensor) -> Tensor:
row_permutation = torch.randperm(matrix.size(dim=0))
return matrix[row_permutation]


class LinearUnderScalingProperty:
"""
This class tests empirically that a given `Aggregator` satisfies the linear under scaling
property.
"""

@classmethod
@mark.parametrize("matrix", strong_stationary_matrices + matrices)
def test_linear_under_scaling_property(cls, aggregator: Aggregator, matrix: Tensor):
cls._assert_linear_under_scaling_property(aggregator, matrix)

@staticmethod
def _assert_linear_under_scaling_property(
aggregator: Aggregator,
matrix: Tensor,
) -> None:
c1 = torch.rand(matrix.shape[0])
c2 = torch.rand(matrix.shape[0])
alpha = torch.rand([])
beta = torch.rand([])

x1 = aggregator(torch.diag(c1) @ matrix)
x2 = aggregator(torch.diag(c2) @ matrix)
x = aggregator(torch.diag(alpha * c1 + beta * c2) @ matrix)
expected = alpha * x1 + beta * x2

assert_close(x, expected, atol=8e-03, rtol=0)
4 changes: 2 additions & 2 deletions tests/unit/aggregation/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

from torchjd.aggregation import ConFIG

from ._property_testers import ExpectedStructureProperty
from ._property_testers import ExpectedStructureProperty, LinearUnderScalingProperty


# For some reason, some permutation-invariance property tests fail with the pinv-based
# implementation.
@mark.parametrize("aggregator", [ConFIG()])
class TestConfig(ExpectedStructureProperty):
class TestConfig(ExpectedStructureProperty, LinearUnderScalingProperty):
pass


Expand Down
9 changes: 7 additions & 2 deletions tests/unit/aggregation/test_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torchjd.aggregation import Constant

from ._inputs import matrices, scaled_matrices, strong_stationary_matrices, zero_matrices
from ._property_testers import ExpectedStructureProperty
from ._property_testers import ExpectedStructureProperty, LinearUnderScalingProperty

# The weights must be a vector of length equal to the number of rows in the matrix that it will be
# applied to. Thus, each `Constant` instance is specific to matrices of a given number of rows. To
Expand All @@ -26,7 +26,7 @@ def _make_aggregator(matrix: Tensor) -> Constant:
_aggregators_2 = [_make_aggregator(matrix) for matrix in _matrices_2]


class TestConstant(ExpectedStructureProperty):
class TestConstant(ExpectedStructureProperty, LinearUnderScalingProperty):
# Override the parametrization of `test_expected_structure_property` to make the test use the
# right aggregator with each matrix.

Expand All @@ -35,6 +35,11 @@ class TestConstant(ExpectedStructureProperty):
def test_expected_structure_property(cls, aggregator: Constant, matrix: Tensor):
cls._assert_expected_structure_property(aggregator, matrix)

@classmethod
@mark.parametrize(["aggregator", "matrix"], zip(_aggregators_2, _matrices_2))
def test_linear_under_scaling_property(cls, aggregator: Constant, matrix: Tensor):
cls._assert_linear_under_scaling_property(aggregator, matrix)


def test_representations():
A = Constant(weights=torch.tensor([1.0, 2.0], device="cpu"))
Expand Down
10 changes: 8 additions & 2 deletions tests/unit/aggregation/test_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@

from torchjd.aggregation import Mean

from ._property_testers import ExpectedStructureProperty, PermutationInvarianceProperty
from ._property_testers import (
ExpectedStructureProperty,
LinearUnderScalingProperty,
PermutationInvarianceProperty,
)


@mark.parametrize("aggregator", [Mean()])
class TestMean(ExpectedStructureProperty, PermutationInvarianceProperty):
class TestMean(
ExpectedStructureProperty, PermutationInvarianceProperty, LinearUnderScalingProperty
):
pass


Expand Down
8 changes: 6 additions & 2 deletions tests/unit/aggregation/test_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

from torchjd.aggregation import Sum

from ._property_testers import ExpectedStructureProperty, PermutationInvarianceProperty
from ._property_testers import (
ExpectedStructureProperty,
LinearUnderScalingProperty,
PermutationInvarianceProperty,
)


@mark.parametrize("aggregator", [Sum()])
class TestSum(ExpectedStructureProperty, PermutationInvarianceProperty):
class TestSum(ExpectedStructureProperty, PermutationInvarianceProperty, LinearUnderScalingProperty):
pass


Expand Down
8 changes: 7 additions & 1 deletion tests/unit/aggregation/test_upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@

from ._property_testers import (
ExpectedStructureProperty,
LinearUnderScalingProperty,
NonConflictingProperty,
PermutationInvarianceProperty,
)


@mark.parametrize("aggregator", [UPGrad()])
class TestUPGrad(ExpectedStructureProperty, NonConflictingProperty, PermutationInvarianceProperty):
class TestUPGrad(
ExpectedStructureProperty,
NonConflictingProperty,
PermutationInvarianceProperty,
LinearUnderScalingProperty,
):
pass


Expand Down