Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tests/unit/aggregation/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,19 @@
def test_check_is_matrix(shape: Sequence[int], expectation: ExceptionContext):
with expectation:
Aggregator._check_is_matrix(torch.randn(shape))


@mark.parametrize(
["value", "expectation"],
[
(0.0, does_not_raise()),
(torch.nan, raises(ValueError)),
(torch.inf, raises(ValueError)),
(-torch.inf, raises(ValueError)),
],
)
def test_check_is_finite(value: float, expectation: ExceptionContext):
matrix = torch.ones([5, 5])
matrix[1, 2] = value
with expectation:
Aggregator._check_is_finite(matrix)
20 changes: 19 additions & 1 deletion tests/unit/aggregation/test_cagrad.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from pytest import mark
from contextlib import nullcontext as does_not_raise

from pytest import mark, raises
from torch import Tensor
from torch.testing import assert_close
from unit._utils import ExceptionContext

from torchjd.aggregation import CAGrad, Mean

Expand Down Expand Up @@ -33,6 +36,21 @@ def test_equivalence_mean(matrix: Tensor):
assert_close(result, expected)


@mark.parametrize(
["c", "expectation"],
[
(-5.0, raises(ValueError)),
(-1.0, raises(ValueError)),
(0.0, does_not_raise()),
(1.0, does_not_raise()),
(50.0, does_not_raise()),
],
)
def test_c_check(c: float, expectation: ExceptionContext):
with expectation:
_ = CAGrad(c=c)


def test_representations():
A = CAGrad(c=0.5, norm_eps=0.0001)
assert repr(A) == "CAGrad(c=0.5, norm_eps=0.0001)"
Expand Down
47 changes: 46 additions & 1 deletion tests/unit/aggregation/test_constant.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from contextlib import nullcontext as does_not_raise

import torch
from pytest import mark
from pytest import mark, raises
from torch import Tensor
from unit._utils import ExceptionContext

from torchjd.aggregation import Constant

Expand Down Expand Up @@ -41,6 +44,48 @@ def test_linear_under_scaling_property(cls, aggregator: Constant, matrix: Tensor
cls._assert_linear_under_scaling_property(aggregator, matrix)


@mark.parametrize(
["weights_shape", "expectation"],
[
([], raises(ValueError)),
([0], does_not_raise()),
([1], does_not_raise()),
([10], does_not_raise()),
([0, 0], raises(ValueError)),
([0, 1], raises(ValueError)),
([1, 1], raises(ValueError)),
([1, 1, 1], raises(ValueError)),
([1, 1, 1, 1], raises(ValueError)),
([1, 1, 1, 1, 1], raises(ValueError)),
],
)
def test_weights_shape_check(weights_shape: list[int], expectation: ExceptionContext):
weights = torch.ones(weights_shape)
with expectation:
_ = Constant(weights=weights)


@mark.parametrize(
["weights_shape", "n_rows", "expectation"],
[
([0], 0, does_not_raise()),
([1], 1, does_not_raise()),
([5], 5, does_not_raise()),
([0], 1, raises(ValueError)),
([1], 0, raises(ValueError)),
([4], 5, raises(ValueError)),
([5], 4, raises(ValueError)),
],
)
def test_matrix_shape_check(weights_shape: list[int], n_rows: int, expectation: ExceptionContext):
matrix = torch.ones([n_rows, 5])
weights = torch.ones(weights_shape)
aggregator = Constant(weights)

with expectation:
_ = aggregator(matrix)


def test_representations():
A = Constant(weights=torch.tensor([1.0, 2.0], device="cpu"))
assert repr(A) == "Constant(weights=tensor([1., 2.]))"
Expand Down
47 changes: 46 additions & 1 deletion tests/unit/aggregation/test_graddrop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from contextlib import nullcontext as does_not_raise

import torch
from pytest import mark
from pytest import mark, raises
from unit._utils import ExceptionContext

from torchjd.aggregation import GradDrop

Expand All @@ -11,6 +14,48 @@ class TestGradDrop(ExpectedStructureProperty):
pass


@mark.parametrize(
["leak_shape", "expectation"],
[
([], raises(ValueError)),
([0], does_not_raise()),
([1], does_not_raise()),
([10], does_not_raise()),
([0, 0], raises(ValueError)),
([0, 1], raises(ValueError)),
([1, 1], raises(ValueError)),
([1, 1, 1], raises(ValueError)),
([1, 1, 1, 1], raises(ValueError)),
([1, 1, 1, 1, 1], raises(ValueError)),
],
)
def test_leak_shape_check(leak_shape: list[int], expectation: ExceptionContext):
leak = torch.ones(leak_shape)
with expectation:
_ = GradDrop(leak=leak)


@mark.parametrize(
["leak_shape", "n_rows", "expectation"],
[
([0], 0, does_not_raise()),
([1], 1, does_not_raise()),
([5], 5, does_not_raise()),
([0], 1, raises(ValueError)),
([1], 0, raises(ValueError)),
([4], 5, raises(ValueError)),
([5], 4, raises(ValueError)),
],
)
def test_matrix_shape_check(leak_shape: list[int], n_rows: int, expectation: ExceptionContext):
matrix = torch.ones([n_rows, 5])
leak = torch.ones(leak_shape)
aggregator = GradDrop(leak=leak)

with expectation:
_ = aggregator(matrix)


def test_representations():
A = GradDrop(leak=torch.tensor([0.0, 1.0], device="cpu"))
assert repr(A) == "GradDrop(leak=tensor([0., 1.]))"
Expand Down
58 changes: 57 additions & 1 deletion tests/unit/aggregation/test_krum.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from pytest import mark
from contextlib import nullcontext as does_not_raise

import torch
from pytest import mark, raises
from torch import Tensor
from unit._utils import ExceptionContext

from torchjd.aggregation import Krum

Expand All @@ -17,6 +21,58 @@ def test_expected_structure_property(cls, aggregator: Krum, matrix: Tensor):
cls._assert_expected_structure_property(aggregator, matrix)


@mark.parametrize(
["n_byzantine", "expectation"],
[
(-5, raises(ValueError)),
(-1, raises(ValueError)),
(0, does_not_raise()),
(1, does_not_raise()),
(5, does_not_raise()),
],
)
def test_n_byzantine_check(n_byzantine: int, expectation: ExceptionContext):
with expectation:
_ = Krum(n_byzantine=n_byzantine, n_selected=1)


@mark.parametrize(
["n_selected", "expectation"],
[
(-5, raises(ValueError)),
(-1, raises(ValueError)),
(0, raises(ValueError)),
(1, does_not_raise()),
(5, does_not_raise()),
],
)
def test_n_selected_check(n_selected: int, expectation: ExceptionContext):
with expectation:
_ = Krum(n_byzantine=1, n_selected=n_selected)


@mark.parametrize(
["n_byzantine", "n_selected", "n_rows", "expectation"],
[
(1, 1, 3, raises(ValueError)),
(1, 1, 4, does_not_raise()),
(1, 4, 4, does_not_raise()),
(12, 4, 14, raises(ValueError)),
(12, 4, 15, does_not_raise()),
(12, 15, 15, does_not_raise()),
(12, 16, 15, raises(ValueError)),
],
)
def test_matrix_shape_check(
n_byzantine: int, n_selected: int, n_rows: int, expectation: ExceptionContext
):
aggregator = Krum(n_byzantine=n_byzantine, n_selected=n_selected)
matrix = torch.ones([n_rows, 5])

with expectation:
_ = aggregator(matrix)


def test_representations():
A = Krum(n_byzantine=1, n_selected=2)
assert repr(A) == "Krum(n_byzantine=1, n_selected=2)"
Expand Down
39 changes: 38 additions & 1 deletion tests/unit/aggregation/test_trimmed_mean.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from pytest import mark
from contextlib import nullcontext as does_not_raise

import torch
from pytest import mark, raises
from torch import Tensor
from unit._utils import ExceptionContext

from torchjd.aggregation import Aggregator, TrimmedMean

Expand All @@ -22,6 +26,39 @@ def test_permutation_invariance_property(cls, aggregator: Aggregator, matrix: Te
cls._assert_permutation_invariance_property(aggregator, matrix)


@mark.parametrize(
["trim_number", "expectation"],
[
(-5, raises(ValueError)),
(-1, raises(ValueError)),
(0, does_not_raise()),
(1, does_not_raise()),
(5, does_not_raise()),
],
)
def test_trim_number_check(trim_number: int, expectation: ExceptionContext):
with expectation:
_ = TrimmedMean(trim_number=trim_number)


@mark.parametrize(
["n_rows", "trim_number", "expectation"],
[
(1, 0, does_not_raise()),
(1, 1, raises(ValueError)),
(10, 0, does_not_raise()),
(10, 4, does_not_raise()),
(10, 5, raises(ValueError)),
],
)
def test_matrix_shape_check(n_rows: int, trim_number: int, expectation: ExceptionContext):
matrix = torch.ones([n_rows, 5])
aggregator = TrimmedMean(trim_number=trim_number)

with expectation:
_ = aggregator(matrix)


def test_representations():
aggregator = TrimmedMean(trim_number=2)
assert repr(aggregator) == "TrimmedMean(trim_number=2)"
Expand Down