From 3b842cf218360b5be5c67f70789edab5b13d48c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 22 Mar 2025 13:22:10 +0100 Subject: [PATCH 1/6] Add missing TrimmedMean tests --- tests/unit/aggregation/test_trimmed_mean.py | 39 ++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/tests/unit/aggregation/test_trimmed_mean.py b/tests/unit/aggregation/test_trimmed_mean.py index 609561ae7..d463bf114 100644 --- a/tests/unit/aggregation/test_trimmed_mean.py +++ b/tests/unit/aggregation/test_trimmed_mean.py @@ -1,5 +1,9 @@ -from pytest import mark +from contextlib import nullcontext as does_not_raise + +import torch +from pytest import mark, raises from torch import Tensor +from unit._utils import ExceptionContext from torchjd.aggregation import Aggregator, TrimmedMean @@ -22,6 +26,39 @@ def test_permutation_invariance_property(cls, aggregator: Aggregator, matrix: Te cls._assert_permutation_invariance_property(aggregator, matrix) +@mark.parametrize( + ["trim_number", "expectation"], + [ + (-5, raises(ValueError)), + (-1, raises(ValueError)), + (0, does_not_raise()), + (1, does_not_raise()), + (5, does_not_raise()), + ], +) +def test_trim_number_check(trim_number: int, expectation: ExceptionContext): + with expectation: + _ = TrimmedMean(trim_number=trim_number) + + +@mark.parametrize( + ["n_rows", "trim_number", "expectation"], + [ + (1, 0, does_not_raise()), + (1, 1, raises(ValueError)), + (10, 0, does_not_raise()), + (10, 4, does_not_raise()), + (10, 5, raises(ValueError)), + ], +) +def test_matrix_shape_check(n_rows: int, trim_number: int, expectation: ExceptionContext): + matrix = torch.ones([n_rows, 5]) + aggregator = TrimmedMean(trim_number=trim_number) + + with expectation: + _ = aggregator(matrix) + + def test_representations(): aggregator = TrimmedMean(trim_number=2) assert repr(aggregator) == "TrimmedMean(trim_number=2)" From 34dcd58587b26587fd4b8f157036e48d722c0b0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 22 Mar 2025 20:42:19 +0100 Subject: [PATCH 2/6] Add missing Krum tests --- tests/unit/aggregation/test_krum.py | 58 ++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/tests/unit/aggregation/test_krum.py b/tests/unit/aggregation/test_krum.py index 6c1850f6c..e215a77de 100644 --- a/tests/unit/aggregation/test_krum.py +++ b/tests/unit/aggregation/test_krum.py @@ -1,5 +1,9 @@ -from pytest import mark +from contextlib import nullcontext as does_not_raise + +import torch +from pytest import mark, raises from torch import Tensor +from unit._utils import ExceptionContext from torchjd.aggregation import Krum @@ -17,6 +21,58 @@ def test_expected_structure_property(cls, aggregator: Krum, matrix: Tensor): cls._assert_expected_structure_property(aggregator, matrix) +@mark.parametrize( + ["n_byzantine", "expectation"], + [ + (-5, raises(ValueError)), + (-1, raises(ValueError)), + (0, does_not_raise()), + (1, does_not_raise()), + (5, does_not_raise()), + ], +) +def test_n_byzantine_check(n_byzantine: int, expectation: ExceptionContext): + with expectation: + _ = Krum(n_byzantine=n_byzantine, n_selected=1) + + +@mark.parametrize( + ["n_selected", "expectation"], + [ + (-5, raises(ValueError)), + (-1, raises(ValueError)), + (0, raises(ValueError)), + (1, does_not_raise()), + (5, does_not_raise()), + ], +) +def test_n_selected_check(n_selected: int, expectation: ExceptionContext): + with expectation: + _ = Krum(n_byzantine=1, n_selected=n_selected) + + +@mark.parametrize( + ["n_byzantine", "n_selected", "n_rows", "expectation"], + [ + (1, 1, 3, raises(ValueError)), + (1, 1, 4, does_not_raise()), + (1, 4, 4, does_not_raise()), + (12, 4, 14, raises(ValueError)), + (12, 4, 15, does_not_raise()), + (12, 15, 15, does_not_raise()), + (12, 16, 15, raises(ValueError)), + ], +) +def test_matrix_shape_check( + n_byzantine: int, n_selected: int, n_rows: int, expectation: ExceptionContext +): + aggregator = Krum(n_byzantine=n_byzantine, n_selected=n_selected) + matrix = torch.ones([n_rows, 5]) + + with expectation: + _ = aggregator(matrix) + + def test_representations(): A = Krum(n_byzantine=1, n_selected=2) assert repr(A) == "Krum(n_byzantine=1, n_selected=2)" From a3106c9d6e2f2cabe1a5acb73a93835740f8d1b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 22 Mar 2025 20:53:13 +0100 Subject: [PATCH 3/6] Add missing Constant tests --- tests/unit/aggregation/test_constant.py | 47 ++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/tests/unit/aggregation/test_constant.py b/tests/unit/aggregation/test_constant.py index 0b3d2ac9c..e1dca61a8 100644 --- a/tests/unit/aggregation/test_constant.py +++ b/tests/unit/aggregation/test_constant.py @@ -1,6 +1,9 @@ +from contextlib import nullcontext as does_not_raise + import torch -from pytest import mark +from pytest import mark, raises from torch import Tensor +from unit._utils import ExceptionContext from torchjd.aggregation import Constant @@ -41,6 +44,48 @@ def test_linear_under_scaling_property(cls, aggregator: Constant, matrix: Tensor cls._assert_linear_under_scaling_property(aggregator, matrix) +@mark.parametrize( + ["weights_shape", "expectation"], + [ + ([], raises(ValueError)), + ([0], does_not_raise()), + ([1], does_not_raise()), + ([10], does_not_raise()), + ([0, 0], raises(ValueError)), + ([0, 1], raises(ValueError)), + ([1, 1], raises(ValueError)), + ([1, 1, 1], raises(ValueError)), + ([1, 1, 1, 1], raises(ValueError)), + ([1, 1, 1, 1, 1], raises(ValueError)), + ], +) +def test_weights_shape_check(weights_shape: list[int], expectation: ExceptionContext): + weights = torch.ones(weights_shape) + with expectation: + _ = Constant(weights=weights) + + +@mark.parametrize( + ["weights_shape", "n_rows", "expectation"], + [ + ([0], 0, does_not_raise()), + ([1], 1, does_not_raise()), + ([5], 5, does_not_raise()), + ([0], 1, raises(ValueError)), + ([1], 0, raises(ValueError)), + ([4], 5, raises(ValueError)), + ([5], 4, raises(ValueError)), + ], +) +def test_matrix_shape_check(weights_shape: list[int], n_rows: int, expectation: ExceptionContext): + matrix = torch.ones([n_rows, 5]) + weights = torch.ones(weights_shape) + aggregator = Constant(weights) + + with expectation: + _ = aggregator(matrix) + + def test_representations(): A = Constant(weights=torch.tensor([1.0, 2.0], device="cpu")) assert repr(A) == "Constant(weights=tensor([1., 2.]))" From 774668ceb6687ded937d44c3b664887a33fa188d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 22 Mar 2025 20:56:06 +0100 Subject: [PATCH 4/6] Add missing CAGrad tests --- tests/unit/aggregation/test_cagrad.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/unit/aggregation/test_cagrad.py b/tests/unit/aggregation/test_cagrad.py index 739682c7c..a0315d048 100644 --- a/tests/unit/aggregation/test_cagrad.py +++ b/tests/unit/aggregation/test_cagrad.py @@ -1,6 +1,9 @@ -from pytest import mark +from contextlib import nullcontext as does_not_raise + +from pytest import mark, raises from torch import Tensor from torch.testing import assert_close +from unit._utils import ExceptionContext from torchjd.aggregation import CAGrad, Mean @@ -33,6 +36,21 @@ def test_equivalence_mean(matrix: Tensor): assert_close(result, expected) +@mark.parametrize( + ["c", "expectation"], + [ + (-5.0, raises(ValueError)), + (-1.0, raises(ValueError)), + (0.0, does_not_raise()), + (1.0, does_not_raise()), + (50.0, does_not_raise()), + ], +) +def test_c_check(c: float, expectation: ExceptionContext): + with expectation: + _ = CAGrad(c=c) + + def test_representations(): A = CAGrad(c=0.5, norm_eps=0.0001) assert repr(A) == "CAGrad(c=0.5, norm_eps=0.0001)" From de661bf045e2c730228b4435f4d0f78bf280164c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 23 Mar 2025 11:29:48 +0100 Subject: [PATCH 5/6] Add missing GradDrop tests --- tests/unit/aggregation/test_graddrop.py | 47 ++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/tests/unit/aggregation/test_graddrop.py b/tests/unit/aggregation/test_graddrop.py index e1ec46011..1197f2302 100644 --- a/tests/unit/aggregation/test_graddrop.py +++ b/tests/unit/aggregation/test_graddrop.py @@ -1,5 +1,8 @@ +from contextlib import nullcontext as does_not_raise + import torch -from pytest import mark +from pytest import mark, raises +from unit._utils import ExceptionContext from torchjd.aggregation import GradDrop @@ -11,6 +14,48 @@ class TestGradDrop(ExpectedStructureProperty): pass +@mark.parametrize( + ["leak_shape", "expectation"], + [ + ([], raises(ValueError)), + ([0], does_not_raise()), + ([1], does_not_raise()), + ([10], does_not_raise()), + ([0, 0], raises(ValueError)), + ([0, 1], raises(ValueError)), + ([1, 1], raises(ValueError)), + ([1, 1, 1], raises(ValueError)), + ([1, 1, 1, 1], raises(ValueError)), + ([1, 1, 1, 1, 1], raises(ValueError)), + ], +) +def test_leak_shape_check(leak_shape: list[int], expectation: ExceptionContext): + leak = torch.ones(leak_shape) + with expectation: + _ = GradDrop(leak=leak) + + +@mark.parametrize( + ["leak_shape", "n_rows", "expectation"], + [ + ([0], 0, does_not_raise()), + ([1], 1, does_not_raise()), + ([5], 5, does_not_raise()), + ([0], 1, raises(ValueError)), + ([1], 0, raises(ValueError)), + ([4], 5, raises(ValueError)), + ([5], 4, raises(ValueError)), + ], +) +def test_matrix_shape_check(leak_shape: list[int], n_rows: int, expectation: ExceptionContext): + matrix = torch.ones([n_rows, 5]) + leak = torch.ones(leak_shape) + aggregator = GradDrop(leak=leak) + + with expectation: + _ = aggregator(matrix) + + def test_representations(): A = GradDrop(leak=torch.tensor([0.0, 1.0], device="cpu")) assert repr(A) == "GradDrop(leak=tensor([0., 1.]))" From ff8af5aa3856c22d704376d42b04d9897ccbd91d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 23 Mar 2025 11:37:06 +0100 Subject: [PATCH 6/6] Add missing (base) Aggregator test --- tests/unit/aggregation/test_base.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/unit/aggregation/test_base.py b/tests/unit/aggregation/test_base.py index 20f1613b6..52cadca43 100644 --- a/tests/unit/aggregation/test_base.py +++ b/tests/unit/aggregation/test_base.py @@ -21,3 +21,19 @@ def test_check_is_matrix(shape: Sequence[int], expectation: ExceptionContext): with expectation: Aggregator._check_is_matrix(torch.randn(shape)) + + +@mark.parametrize( + ["value", "expectation"], + [ + (0.0, does_not_raise()), + (torch.nan, raises(ValueError)), + (torch.inf, raises(ValueError)), + (-torch.inf, raises(ValueError)), + ], +) +def test_check_is_finite(value: float, expectation: ExceptionContext): + matrix = torch.ones([5, 5]) + matrix[1, 2] = value + with expectation: + Aggregator._check_is_finite(matrix)