Skip to content

Commit 480705b

Browse files
authored
test(aggregation): Add tests for aggregator checks (#264)
* Add missing TrimmedMean check tests * Add missing Krum check tests * Add missing Constant check tests * Add missing CAGrad check test * Add missing GradDrop check tests * Add missing (base) Aggregator check test
1 parent 66ca700 commit 480705b

File tree

6 files changed

+222
-5
lines changed

6 files changed

+222
-5
lines changed

tests/unit/aggregation/test_base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,19 @@
2121
def test_check_is_matrix(shape: Sequence[int], expectation: ExceptionContext):
2222
with expectation:
2323
Aggregator._check_is_matrix(torch.randn(shape))
24+
25+
26+
@mark.parametrize(
27+
["value", "expectation"],
28+
[
29+
(0.0, does_not_raise()),
30+
(torch.nan, raises(ValueError)),
31+
(torch.inf, raises(ValueError)),
32+
(-torch.inf, raises(ValueError)),
33+
],
34+
)
35+
def test_check_is_finite(value: float, expectation: ExceptionContext):
36+
matrix = torch.ones([5, 5])
37+
matrix[1, 2] = value
38+
with expectation:
39+
Aggregator._check_is_finite(matrix)

tests/unit/aggregation/test_cagrad.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
from pytest import mark
1+
from contextlib import nullcontext as does_not_raise
2+
3+
from pytest import mark, raises
24
from torch import Tensor
35
from torch.testing import assert_close
6+
from unit._utils import ExceptionContext
47

58
from torchjd.aggregation import CAGrad, Mean
69

@@ -33,6 +36,21 @@ def test_equivalence_mean(matrix: Tensor):
3336
assert_close(result, expected)
3437

3538

39+
@mark.parametrize(
40+
["c", "expectation"],
41+
[
42+
(-5.0, raises(ValueError)),
43+
(-1.0, raises(ValueError)),
44+
(0.0, does_not_raise()),
45+
(1.0, does_not_raise()),
46+
(50.0, does_not_raise()),
47+
],
48+
)
49+
def test_c_check(c: float, expectation: ExceptionContext):
50+
with expectation:
51+
_ = CAGrad(c=c)
52+
53+
3654
def test_representations():
3755
A = CAGrad(c=0.5, norm_eps=0.0001)
3856
assert repr(A) == "CAGrad(c=0.5, norm_eps=0.0001)"

tests/unit/aggregation/test_constant.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from contextlib import nullcontext as does_not_raise
2+
13
import torch
2-
from pytest import mark
4+
from pytest import mark, raises
35
from torch import Tensor
6+
from unit._utils import ExceptionContext
47

58
from torchjd.aggregation import Constant
69

@@ -41,6 +44,48 @@ def test_linear_under_scaling_property(cls, aggregator: Constant, matrix: Tensor
4144
cls._assert_linear_under_scaling_property(aggregator, matrix)
4245

4346

47+
@mark.parametrize(
48+
["weights_shape", "expectation"],
49+
[
50+
([], raises(ValueError)),
51+
([0], does_not_raise()),
52+
([1], does_not_raise()),
53+
([10], does_not_raise()),
54+
([0, 0], raises(ValueError)),
55+
([0, 1], raises(ValueError)),
56+
([1, 1], raises(ValueError)),
57+
([1, 1, 1], raises(ValueError)),
58+
([1, 1, 1, 1], raises(ValueError)),
59+
([1, 1, 1, 1, 1], raises(ValueError)),
60+
],
61+
)
62+
def test_weights_shape_check(weights_shape: list[int], expectation: ExceptionContext):
63+
weights = torch.ones(weights_shape)
64+
with expectation:
65+
_ = Constant(weights=weights)
66+
67+
68+
@mark.parametrize(
69+
["weights_shape", "n_rows", "expectation"],
70+
[
71+
([0], 0, does_not_raise()),
72+
([1], 1, does_not_raise()),
73+
([5], 5, does_not_raise()),
74+
([0], 1, raises(ValueError)),
75+
([1], 0, raises(ValueError)),
76+
([4], 5, raises(ValueError)),
77+
([5], 4, raises(ValueError)),
78+
],
79+
)
80+
def test_matrix_shape_check(weights_shape: list[int], n_rows: int, expectation: ExceptionContext):
81+
matrix = torch.ones([n_rows, 5])
82+
weights = torch.ones(weights_shape)
83+
aggregator = Constant(weights)
84+
85+
with expectation:
86+
_ = aggregator(matrix)
87+
88+
4489
def test_representations():
4590
A = Constant(weights=torch.tensor([1.0, 2.0], device="cpu"))
4691
assert repr(A) == "Constant(weights=tensor([1., 2.]))"

tests/unit/aggregation/test_graddrop.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from contextlib import nullcontext as does_not_raise
2+
13
import torch
2-
from pytest import mark
4+
from pytest import mark, raises
5+
from unit._utils import ExceptionContext
36

47
from torchjd.aggregation import GradDrop
58

@@ -11,6 +14,48 @@ class TestGradDrop(ExpectedStructureProperty):
1114
pass
1215

1316

17+
@mark.parametrize(
18+
["leak_shape", "expectation"],
19+
[
20+
([], raises(ValueError)),
21+
([0], does_not_raise()),
22+
([1], does_not_raise()),
23+
([10], does_not_raise()),
24+
([0, 0], raises(ValueError)),
25+
([0, 1], raises(ValueError)),
26+
([1, 1], raises(ValueError)),
27+
([1, 1, 1], raises(ValueError)),
28+
([1, 1, 1, 1], raises(ValueError)),
29+
([1, 1, 1, 1, 1], raises(ValueError)),
30+
],
31+
)
32+
def test_leak_shape_check(leak_shape: list[int], expectation: ExceptionContext):
33+
leak = torch.ones(leak_shape)
34+
with expectation:
35+
_ = GradDrop(leak=leak)
36+
37+
38+
@mark.parametrize(
39+
["leak_shape", "n_rows", "expectation"],
40+
[
41+
([0], 0, does_not_raise()),
42+
([1], 1, does_not_raise()),
43+
([5], 5, does_not_raise()),
44+
([0], 1, raises(ValueError)),
45+
([1], 0, raises(ValueError)),
46+
([4], 5, raises(ValueError)),
47+
([5], 4, raises(ValueError)),
48+
],
49+
)
50+
def test_matrix_shape_check(leak_shape: list[int], n_rows: int, expectation: ExceptionContext):
51+
matrix = torch.ones([n_rows, 5])
52+
leak = torch.ones(leak_shape)
53+
aggregator = GradDrop(leak=leak)
54+
55+
with expectation:
56+
_ = aggregator(matrix)
57+
58+
1459
def test_representations():
1560
A = GradDrop(leak=torch.tensor([0.0, 1.0], device="cpu"))
1661
assert repr(A) == "GradDrop(leak=tensor([0., 1.]))"

tests/unit/aggregation/test_krum.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
from pytest import mark
1+
from contextlib import nullcontext as does_not_raise
2+
3+
import torch
4+
from pytest import mark, raises
25
from torch import Tensor
6+
from unit._utils import ExceptionContext
37

48
from torchjd.aggregation import Krum
59

@@ -17,6 +21,58 @@ def test_expected_structure_property(cls, aggregator: Krum, matrix: Tensor):
1721
cls._assert_expected_structure_property(aggregator, matrix)
1822

1923

24+
@mark.parametrize(
25+
["n_byzantine", "expectation"],
26+
[
27+
(-5, raises(ValueError)),
28+
(-1, raises(ValueError)),
29+
(0, does_not_raise()),
30+
(1, does_not_raise()),
31+
(5, does_not_raise()),
32+
],
33+
)
34+
def test_n_byzantine_check(n_byzantine: int, expectation: ExceptionContext):
35+
with expectation:
36+
_ = Krum(n_byzantine=n_byzantine, n_selected=1)
37+
38+
39+
@mark.parametrize(
40+
["n_selected", "expectation"],
41+
[
42+
(-5, raises(ValueError)),
43+
(-1, raises(ValueError)),
44+
(0, raises(ValueError)),
45+
(1, does_not_raise()),
46+
(5, does_not_raise()),
47+
],
48+
)
49+
def test_n_selected_check(n_selected: int, expectation: ExceptionContext):
50+
with expectation:
51+
_ = Krum(n_byzantine=1, n_selected=n_selected)
52+
53+
54+
@mark.parametrize(
55+
["n_byzantine", "n_selected", "n_rows", "expectation"],
56+
[
57+
(1, 1, 3, raises(ValueError)),
58+
(1, 1, 4, does_not_raise()),
59+
(1, 4, 4, does_not_raise()),
60+
(12, 4, 14, raises(ValueError)),
61+
(12, 4, 15, does_not_raise()),
62+
(12, 15, 15, does_not_raise()),
63+
(12, 16, 15, raises(ValueError)),
64+
],
65+
)
66+
def test_matrix_shape_check(
67+
n_byzantine: int, n_selected: int, n_rows: int, expectation: ExceptionContext
68+
):
69+
aggregator = Krum(n_byzantine=n_byzantine, n_selected=n_selected)
70+
matrix = torch.ones([n_rows, 5])
71+
72+
with expectation:
73+
_ = aggregator(matrix)
74+
75+
2076
def test_representations():
2177
A = Krum(n_byzantine=1, n_selected=2)
2278
assert repr(A) == "Krum(n_byzantine=1, n_selected=2)"

tests/unit/aggregation/test_trimmed_mean.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
from pytest import mark
1+
from contextlib import nullcontext as does_not_raise
2+
3+
import torch
4+
from pytest import mark, raises
25
from torch import Tensor
6+
from unit._utils import ExceptionContext
37

48
from torchjd.aggregation import Aggregator, TrimmedMean
59

@@ -22,6 +26,39 @@ def test_permutation_invariance_property(cls, aggregator: Aggregator, matrix: Te
2226
cls._assert_permutation_invariance_property(aggregator, matrix)
2327

2428

29+
@mark.parametrize(
30+
["trim_number", "expectation"],
31+
[
32+
(-5, raises(ValueError)),
33+
(-1, raises(ValueError)),
34+
(0, does_not_raise()),
35+
(1, does_not_raise()),
36+
(5, does_not_raise()),
37+
],
38+
)
39+
def test_trim_number_check(trim_number: int, expectation: ExceptionContext):
40+
with expectation:
41+
_ = TrimmedMean(trim_number=trim_number)
42+
43+
44+
@mark.parametrize(
45+
["n_rows", "trim_number", "expectation"],
46+
[
47+
(1, 0, does_not_raise()),
48+
(1, 1, raises(ValueError)),
49+
(10, 0, does_not_raise()),
50+
(10, 4, does_not_raise()),
51+
(10, 5, raises(ValueError)),
52+
],
53+
)
54+
def test_matrix_shape_check(n_rows: int, trim_number: int, expectation: ExceptionContext):
55+
matrix = torch.ones([n_rows, 5])
56+
aggregator = TrimmedMean(trim_number=trim_number)
57+
58+
with expectation:
59+
_ = aggregator(matrix)
60+
61+
2562
def test_representations():
2663
aggregator = TrimmedMean(trim_number=2)
2764
assert repr(aggregator) == "TrimmedMean(trim_number=2)"

0 commit comments

Comments
 (0)