Skip to content

Commit 56ea02a

Browse files
committed
Add test of stateful/stateless (all must be one or the other).
1 parent b291573 commit 56ea02a

18 files changed

+139
-10
lines changed

tests/unit/aggregation/_asserts.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import torch
2+
from numpy.ma.testutils import assert_allclose
23
from pytest import raises
34
from torch import Tensor
45
from torch.testing import assert_close
56
from utils.tensors import rand_, randperm_
67

7-
from torchjd.aggregation import Aggregator
8+
from torchjd.aggregation import Aggregator, Stateful
89
from torchjd.aggregation._utils.non_differentiable import NonDifferentiableError
910

1011

@@ -110,3 +111,33 @@ def assert_non_differentiable(aggregator: Aggregator, matrix: Tensor) -> None:
110111
vector = aggregator(matrix)
111112
with raises(NonDifferentiableError):
112113
vector.backward(torch.ones_like(vector))
114+
115+
116+
def assert_stateful(aggregator: Aggregator, matrix: Tensor) -> None:
117+
"""
118+
Test that a given `Aggregator` is stateful. Specifically:
119+
- For a fixed state, the aggregator is determinist on the matrix
120+
- The reset method and the constructor both set the state to the initial state
121+
"""
122+
123+
assert isinstance(aggregator, Stateful)
124+
125+
first_pair = (aggregator(matrix), aggregator(matrix))
126+
aggregator.reset()
127+
second_pair = (aggregator(matrix), aggregator(matrix))
128+
129+
assert_allclose(first_pair[0], second_pair[0], atol=0.0, rtol=0.0)
130+
assert_allclose(first_pair[1], second_pair[1], atol=0.0, rtol=0.0)
131+
132+
133+
def assert_stateless(aggregator: Aggregator, matrix: Tensor) -> None:
134+
"""
135+
Test that a given `Aggregator` is stateless. Specifically, it must be deterministic.
136+
"""
137+
138+
assert not isinstance(aggregator, Stateful)
139+
140+
first = aggregator(matrix)
141+
second = aggregator(matrix)
142+
143+
assert_allclose(first, second, atol=0.0, rtol=0.0)

tests/unit/aggregation/test_aligned_mtl.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from torchjd.aggregation import AlignedMTL
77

8-
from ._asserts import assert_expected_structure, assert_permutation_invariant
8+
from ._asserts import assert_expected_structure, assert_permutation_invariant, assert_stateless
99
from ._inputs import scaled_matrices, typical_matrices
1010

1111
aggregators = [
@@ -28,6 +28,11 @@ def test_permutation_invariant(aggregator: AlignedMTL, matrix: Tensor) -> None:
2828
assert_permutation_invariant(aggregator, matrix)
2929

3030

31+
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
32+
def test_stateless(aggregator: AlignedMTL, matrix: Tensor) -> None:
33+
assert_stateless(aggregator, matrix)
34+
35+
3136
def test_representations() -> None:
3237
A = AlignedMTL(pref_vector=None)
3338
assert repr(A) == "AlignedMTL(pref_vector=None, scale_mode='min')"

tests/unit/aggregation/test_cagrad.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212

1313
pytest.skip("CAGrad dependencies not installed", allow_module_level=True)
1414

15-
from ._asserts import assert_expected_structure, assert_non_conflicting, assert_non_differentiable
15+
from ._asserts import (
16+
assert_expected_structure,
17+
assert_non_conflicting,
18+
assert_non_differentiable,
19+
assert_stateless,
20+
)
1621
from ._inputs import scaled_matrices, typical_matrices
1722

1823
scaled_pairs = [(CAGrad(c=0.5), matrix) for matrix in scaled_matrices]
@@ -38,6 +43,11 @@ def test_non_conflicting(aggregator: CAGrad, matrix: Tensor) -> None:
3843
assert_non_conflicting(aggregator, matrix)
3944

4045

46+
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
47+
def test_stateless(aggregator: CAGrad, matrix: Tensor) -> None:
48+
assert_stateless(aggregator, matrix)
49+
50+
4151
@mark.parametrize(
4252
["c", "expectation"],
4353
[

tests/unit/aggregation/test_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
assert_linear_under_scaling,
1111
assert_non_differentiable,
1212
assert_permutation_invariant,
13+
assert_stateless,
1314
)
1415
from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
1516

@@ -39,6 +40,11 @@ def test_non_differentiable(aggregator: ConFIG, matrix: Tensor) -> None:
3940
assert_non_differentiable(aggregator, matrix)
4041

4142

43+
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
44+
def test_stateless(aggregator: ConFIG, matrix: Tensor) -> None:
45+
assert_stateless(aggregator, matrix)
46+
47+
4248
def test_representations() -> None:
4349
A = ConFIG()
4450
assert repr(A) == "ConFIG(pref_vector=None)"

tests/unit/aggregation/test_constant.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ._asserts import (
1212
assert_expected_structure,
1313
assert_linear_under_scaling,
14+
assert_stateless,
1415
assert_strongly_stationary,
1516
)
1617
from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
@@ -42,6 +43,11 @@ def test_strongly_stationary(aggregator: Constant, matrix: Tensor) -> None:
4243
assert_strongly_stationary(aggregator, matrix)
4344

4445

46+
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
47+
def test_stateless(aggregator: Constant, matrix: Tensor) -> None:
48+
assert_stateless(aggregator, matrix)
49+
50+
4551
@mark.parametrize(
4652
["weights_shape", "expectation"],
4753
[

tests/unit/aggregation/test_dualproj.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
assert_non_conflicting,
1111
assert_non_differentiable,
1212
assert_permutation_invariant,
13+
assert_stateless,
1314
assert_strongly_stationary,
1415
)
1516
from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
@@ -45,6 +46,11 @@ def test_non_differentiable(aggregator: DualProj, matrix: Tensor) -> None:
4546
assert_non_differentiable(aggregator, matrix)
4647

4748

49+
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
50+
def test_stateless(aggregator: DualProj, matrix: Tensor) -> None:
51+
assert_stateless(aggregator, matrix)
52+
53+
4854
def test_representations() -> None:
4955
A = DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog")
5056
assert (

tests/unit/aggregation/test_graddrop.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from torchjd.aggregation import GradDrop
1111

12-
from ._asserts import assert_expected_structure, assert_non_differentiable
12+
from ._asserts import assert_expected_structure, assert_non_differentiable, assert_stateful
1313
from ._inputs import scaled_matrices, typical_matrices
1414

1515
scaled_pairs = [(GradDrop(), matrix) for matrix in scaled_matrices]
@@ -27,6 +27,11 @@ def test_non_differentiable(aggregator: GradDrop, matrix: Tensor) -> None:
2727
assert_non_differentiable(aggregator, matrix)
2828

2929

30+
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
31+
def test_stateful(aggregator: GradDrop, matrix: Tensor) -> None:
32+
assert_stateful(aggregator, matrix)
33+
34+
3035
@mark.parametrize(
3136
["leak_shape", "expectation"],
3237
[

tests/unit/aggregation/test_gradvac.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from torchjd.aggregation import GradVac, GradVacWeighting
88

9-
from ._asserts import assert_expected_structure, assert_non_differentiable
9+
from ._asserts import assert_expected_structure, assert_non_differentiable, assert_stateful
1010
from ._inputs import scaled_matrices, typical_matrices, typical_matrices_2_plus_rows
1111

1212
scaled_pairs = [(GradVac(), m) for m in scaled_matrices]
@@ -104,6 +104,11 @@ def test_non_differentiable(aggregator: GradVac, matrix: Tensor) -> None:
104104
assert_non_differentiable(aggregator, matrix)
105105

106106

107+
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
108+
def test_stateful(aggregator: GradVac, matrix: Tensor) -> None:
109+
assert_stateful(aggregator, matrix)
110+
111+
107112
def test_weighting_beta_out_of_range() -> None:
108113
with raises(ValueError, match="beta"):
109114
GradVacWeighting(beta=-0.1)

tests/unit/aggregation/test_imtl_g.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
assert_expected_structure,
1010
assert_non_differentiable,
1111
assert_permutation_invariant,
12+
assert_stateless,
1213
)
1314
from ._inputs import scaled_matrices, typical_matrices
1415

@@ -32,6 +33,11 @@ def test_non_differentiable(aggregator: IMTLG, matrix: Tensor) -> None:
3233
assert_non_differentiable(aggregator, matrix)
3334

3435

36+
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
37+
def test_stateless(aggregator: IMTLG, matrix: Tensor) -> None:
38+
assert_stateless(aggregator, matrix)
39+
40+
3541
def test_imtlg_zero() -> None:
3642
"""
3743
Tests that IMTLG correctly returns the 0 vector in the special case where input matrix only

tests/unit/aggregation/test_krum.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from torchjd.aggregation import Krum
99

10-
from ._asserts import assert_expected_structure
10+
from ._asserts import assert_expected_structure, assert_stateless
1111
from ._inputs import scaled_matrices_2_plus_rows, typical_matrices_2_plus_rows
1212

1313
scaled_pairs = [(Krum(n_byzantine=1), matrix) for matrix in scaled_matrices_2_plus_rows]
@@ -19,6 +19,11 @@ def test_expected_structure(aggregator: Krum, matrix: Tensor) -> None:
1919
assert_expected_structure(aggregator, matrix)
2020

2121

22+
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
23+
def test_stateless(aggregator: Krum, matrix: Tensor) -> None:
24+
assert_stateless(aggregator, matrix)
25+
26+
2227
@mark.parametrize(
2328
["n_byzantine", "expectation"],
2429
[

0 commit comments

Comments
 (0)