diff --git a/tests/unit/aggregation/test_base.py b/tests/unit/aggregation/test_base.py index 20f1613b6..52cadca43 100644 --- a/tests/unit/aggregation/test_base.py +++ b/tests/unit/aggregation/test_base.py @@ -21,3 +21,19 @@ def test_check_is_matrix(shape: Sequence[int], expectation: ExceptionContext): with expectation: Aggregator._check_is_matrix(torch.randn(shape)) + + +@mark.parametrize( + ["value", "expectation"], + [ + (0.0, does_not_raise()), + (torch.nan, raises(ValueError)), + (torch.inf, raises(ValueError)), + (-torch.inf, raises(ValueError)), + ], +) +def test_check_is_finite(value: float, expectation: ExceptionContext): + matrix = torch.ones([5, 5]) + matrix[1, 2] = value + with expectation: + Aggregator._check_is_finite(matrix) diff --git a/tests/unit/aggregation/test_cagrad.py b/tests/unit/aggregation/test_cagrad.py index 739682c7c..a0315d048 100644 --- a/tests/unit/aggregation/test_cagrad.py +++ b/tests/unit/aggregation/test_cagrad.py @@ -1,6 +1,9 @@ -from pytest import mark +from contextlib import nullcontext as does_not_raise + +from pytest import mark, raises from torch import Tensor from torch.testing import assert_close +from unit._utils import ExceptionContext from torchjd.aggregation import CAGrad, Mean @@ -33,6 +36,21 @@ def test_equivalence_mean(matrix: Tensor): assert_close(result, expected) +@mark.parametrize( + ["c", "expectation"], + [ + (-5.0, raises(ValueError)), + (-1.0, raises(ValueError)), + (0.0, does_not_raise()), + (1.0, does_not_raise()), + (50.0, does_not_raise()), + ], +) +def test_c_check(c: float, expectation: ExceptionContext): + with expectation: + _ = CAGrad(c=c) + + def test_representations(): A = CAGrad(c=0.5, norm_eps=0.0001) assert repr(A) == "CAGrad(c=0.5, norm_eps=0.0001)" diff --git a/tests/unit/aggregation/test_constant.py b/tests/unit/aggregation/test_constant.py index 0b3d2ac9c..e1dca61a8 100644 --- a/tests/unit/aggregation/test_constant.py +++ b/tests/unit/aggregation/test_constant.py @@ -1,6 +1,9 @@ +from contextlib import nullcontext as does_not_raise + import torch -from pytest import mark +from pytest import mark, raises from torch import Tensor +from unit._utils import ExceptionContext from torchjd.aggregation import Constant @@ -41,6 +44,48 @@ def test_linear_under_scaling_property(cls, aggregator: Constant, matrix: Tensor cls._assert_linear_under_scaling_property(aggregator, matrix) +@mark.parametrize( + ["weights_shape", "expectation"], + [ + ([], raises(ValueError)), + ([0], does_not_raise()), + ([1], does_not_raise()), + ([10], does_not_raise()), + ([0, 0], raises(ValueError)), + ([0, 1], raises(ValueError)), + ([1, 1], raises(ValueError)), + ([1, 1, 1], raises(ValueError)), + ([1, 1, 1, 1], raises(ValueError)), + ([1, 1, 1, 1, 1], raises(ValueError)), + ], +) +def test_weights_shape_check(weights_shape: list[int], expectation: ExceptionContext): + weights = torch.ones(weights_shape) + with expectation: + _ = Constant(weights=weights) + + +@mark.parametrize( + ["weights_shape", "n_rows", "expectation"], + [ + ([0], 0, does_not_raise()), + ([1], 1, does_not_raise()), + ([5], 5, does_not_raise()), + ([0], 1, raises(ValueError)), + ([1], 0, raises(ValueError)), + ([4], 5, raises(ValueError)), + ([5], 4, raises(ValueError)), + ], +) +def test_matrix_shape_check(weights_shape: list[int], n_rows: int, expectation: ExceptionContext): + matrix = torch.ones([n_rows, 5]) + weights = torch.ones(weights_shape) + aggregator = Constant(weights) + + with expectation: + _ = aggregator(matrix) + + def test_representations(): A = Constant(weights=torch.tensor([1.0, 2.0], device="cpu")) assert repr(A) == "Constant(weights=tensor([1., 2.]))" diff --git a/tests/unit/aggregation/test_graddrop.py b/tests/unit/aggregation/test_graddrop.py index e1ec46011..1197f2302 100644 --- a/tests/unit/aggregation/test_graddrop.py +++ b/tests/unit/aggregation/test_graddrop.py @@ -1,5 +1,8 @@ +from contextlib import nullcontext as does_not_raise + import torch -from pytest import mark +from pytest import mark, raises +from unit._utils import ExceptionContext from torchjd.aggregation import GradDrop @@ -11,6 +14,48 @@ class TestGradDrop(ExpectedStructureProperty): pass +@mark.parametrize( + ["leak_shape", "expectation"], + [ + ([], raises(ValueError)), + ([0], does_not_raise()), + ([1], does_not_raise()), + ([10], does_not_raise()), + ([0, 0], raises(ValueError)), + ([0, 1], raises(ValueError)), + ([1, 1], raises(ValueError)), + ([1, 1, 1], raises(ValueError)), + ([1, 1, 1, 1], raises(ValueError)), + ([1, 1, 1, 1, 1], raises(ValueError)), + ], +) +def test_leak_shape_check(leak_shape: list[int], expectation: ExceptionContext): + leak = torch.ones(leak_shape) + with expectation: + _ = GradDrop(leak=leak) + + +@mark.parametrize( + ["leak_shape", "n_rows", "expectation"], + [ + ([0], 0, does_not_raise()), + ([1], 1, does_not_raise()), + ([5], 5, does_not_raise()), + ([0], 1, raises(ValueError)), + ([1], 0, raises(ValueError)), + ([4], 5, raises(ValueError)), + ([5], 4, raises(ValueError)), + ], +) +def test_matrix_shape_check(leak_shape: list[int], n_rows: int, expectation: ExceptionContext): + matrix = torch.ones([n_rows, 5]) + leak = torch.ones(leak_shape) + aggregator = GradDrop(leak=leak) + + with expectation: + _ = aggregator(matrix) + + def test_representations(): A = GradDrop(leak=torch.tensor([0.0, 1.0], device="cpu")) assert repr(A) == "GradDrop(leak=tensor([0., 1.]))" diff --git a/tests/unit/aggregation/test_krum.py b/tests/unit/aggregation/test_krum.py index 6c1850f6c..e215a77de 100644 --- a/tests/unit/aggregation/test_krum.py +++ b/tests/unit/aggregation/test_krum.py @@ -1,5 +1,9 @@ -from pytest import mark +from contextlib import nullcontext as does_not_raise + +import torch +from pytest import mark, raises from torch import Tensor +from unit._utils import ExceptionContext from torchjd.aggregation import Krum @@ -17,6 +21,58 @@ def test_expected_structure_property(cls, aggregator: Krum, matrix: Tensor): cls._assert_expected_structure_property(aggregator, matrix) +@mark.parametrize( + ["n_byzantine", "expectation"], + [ + (-5, raises(ValueError)), + (-1, raises(ValueError)), + (0, does_not_raise()), + (1, does_not_raise()), + (5, does_not_raise()), + ], +) +def test_n_byzantine_check(n_byzantine: int, expectation: ExceptionContext): + with expectation: + _ = Krum(n_byzantine=n_byzantine, n_selected=1) + + +@mark.parametrize( + ["n_selected", "expectation"], + [ + (-5, raises(ValueError)), + (-1, raises(ValueError)), + (0, raises(ValueError)), + (1, does_not_raise()), + (5, does_not_raise()), + ], +) +def test_n_selected_check(n_selected: int, expectation: ExceptionContext): + with expectation: + _ = Krum(n_byzantine=1, n_selected=n_selected) + + +@mark.parametrize( + ["n_byzantine", "n_selected", "n_rows", "expectation"], + [ + (1, 1, 3, raises(ValueError)), + (1, 1, 4, does_not_raise()), + (1, 4, 4, does_not_raise()), + (12, 4, 14, raises(ValueError)), + (12, 4, 15, does_not_raise()), + (12, 15, 15, does_not_raise()), + (12, 16, 15, raises(ValueError)), + ], +) +def test_matrix_shape_check( + n_byzantine: int, n_selected: int, n_rows: int, expectation: ExceptionContext +): + aggregator = Krum(n_byzantine=n_byzantine, n_selected=n_selected) + matrix = torch.ones([n_rows, 5]) + + with expectation: + _ = aggregator(matrix) + + def test_representations(): A = Krum(n_byzantine=1, n_selected=2) assert repr(A) == "Krum(n_byzantine=1, n_selected=2)" diff --git a/tests/unit/aggregation/test_trimmed_mean.py b/tests/unit/aggregation/test_trimmed_mean.py index 609561ae7..d463bf114 100644 --- a/tests/unit/aggregation/test_trimmed_mean.py +++ b/tests/unit/aggregation/test_trimmed_mean.py @@ -1,5 +1,9 @@ -from pytest import mark +from contextlib import nullcontext as does_not_raise + +import torch +from pytest import mark, raises from torch import Tensor +from unit._utils import ExceptionContext from torchjd.aggregation import Aggregator, TrimmedMean @@ -22,6 +26,39 @@ def test_permutation_invariance_property(cls, aggregator: Aggregator, matrix: Te cls._assert_permutation_invariance_property(aggregator, matrix) +@mark.parametrize( + ["trim_number", "expectation"], + [ + (-5, raises(ValueError)), + (-1, raises(ValueError)), + (0, does_not_raise()), + (1, does_not_raise()), + (5, does_not_raise()), + ], +) +def test_trim_number_check(trim_number: int, expectation: ExceptionContext): + with expectation: + _ = TrimmedMean(trim_number=trim_number) + + +@mark.parametrize( + ["n_rows", "trim_number", "expectation"], + [ + (1, 0, does_not_raise()), + (1, 1, raises(ValueError)), + (10, 0, does_not_raise()), + (10, 4, does_not_raise()), + (10, 5, raises(ValueError)), + ], +) +def test_matrix_shape_check(n_rows: int, trim_number: int, expectation: ExceptionContext): + matrix = torch.ones([n_rows, 5]) + aggregator = TrimmedMean(trim_number=trim_number) + + with expectation: + _ = aggregator(matrix) + + def test_representations(): aggregator = TrimmedMean(trim_number=2) assert repr(aggregator) == "TrimmedMean(trim_number=2)"