|
1 | 1 | import torch |
| 2 | +from numpy.ma.testutils import assert_allclose |
2 | 3 | from pytest import raises |
3 | 4 | from torch import Tensor |
4 | 5 | from torch.testing import assert_close |
5 | 6 | from utils.tensors import rand_, randperm_ |
6 | 7 |
|
7 | | -from torchjd.aggregation import Aggregator |
| 8 | +from torchjd.aggregation import Aggregator, Stateful |
8 | 9 | from torchjd.aggregation._utils.non_differentiable import NonDifferentiableError |
9 | 10 |
|
10 | 11 |
|
@@ -110,3 +111,33 @@ def assert_non_differentiable(aggregator: Aggregator, matrix: Tensor) -> None: |
110 | 111 | vector = aggregator(matrix) |
111 | 112 | with raises(NonDifferentiableError): |
112 | 113 | vector.backward(torch.ones_like(vector)) |
| 114 | + |
| 115 | + |
| 116 | +def assert_stateful(aggregator: Aggregator, matrix: Tensor) -> None: |
| 117 | + """ |
| 118 | + Test that a given `Aggregator` is stateful. Specifically: |
| 119 | + - For a fixed state, the aggregator is determinist on the matrix |
| 120 | + - The reset method and the constructor both set the state to the initial state |
| 121 | + """ |
| 122 | + |
| 123 | + assert isinstance(aggregator, Stateful) |
| 124 | + |
| 125 | + first_pair = (aggregator(matrix), aggregator(matrix)) |
| 126 | + aggregator.reset() |
| 127 | + second_pair = (aggregator(matrix), aggregator(matrix)) |
| 128 | + |
| 129 | + assert_allclose(first_pair[0], second_pair[0], atol=0.0, rtol=0.0) |
| 130 | + assert_allclose(first_pair[1], second_pair[1], atol=0.0, rtol=0.0) |
| 131 | + |
| 132 | + |
| 133 | +def assert_stateless(aggregator: Aggregator, matrix: Tensor) -> None: |
| 134 | + """ |
| 135 | + Test that a given `Aggregator` is stateless. Specifically, it must be deterministic. |
| 136 | + """ |
| 137 | + |
| 138 | + assert not isinstance(aggregator, Stateful) |
| 139 | + |
| 140 | + first = aggregator(matrix) |
| 141 | + second = aggregator(matrix) |
| 142 | + |
| 143 | + assert_allclose(first, second, atol=0.0, rtol=0.0) |
0 commit comments