-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathtest_constant.py
More file actions
42 lines (29 loc) · 1.58 KB
/
test_constant.py
File metadata and controls
42 lines (29 loc) · 1.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from pytest import mark
from torch import Tensor
from torchjd.aggregation import Constant
from ._inputs import matrices, scaled_matrices, strong_stationary_matrices, zero_matrices
from ._property_testers import ExpectedStructureProperty
# The weights must be a vector of length equal to the number of rows in the matrix that it will be
# applied to. Thus, each `Constant` instance is specific to matrices of a given number of rows. To
# test properties on all possible matrices, we have to create one `Constant` with the right number
# of weights for each matrix.
def _make_aggregator(matrix: Tensor) -> Constant:
    """Build a `Constant` aggregator with uniform weights sized for `matrix`.

    The weight vector has one entry per row of `matrix`, and every entry is equal to
    `1 / number_of_rows`.
    """
    row_count = matrix.shape[0]
    uniform_weight = 1.0 / row_count
    return Constant(torch.tensor([uniform_weight for _ in range(row_count)]))
# First collection of matrices and, at the same index, the aggregator built for each of them.
# These two lists are zipped together to parametrize
# `TestConstant.test_expected_structure_property`.
_matrices_1 = scaled_matrices + zero_matrices
_aggregators_1 = list(map(_make_aggregator, _matrices_1))

# Second collection of matrices, again with one aggregator built per matrix (same ordering).
_matrices_2 = matrices + strong_stationary_matrices
_aggregators_2 = list(map(_make_aggregator, _matrices_2))
class TestConstant(ExpectedStructureProperty):
    """Check the expected-structure property for `Constant` aggregators."""

    # Override the parametrization of `test_expected_structure_property` so that each matrix is
    # tested with the aggregator that was built for it (its weights match the matrix's number of
    # rows). The pairs are materialized into a list because `zip` returns a one-shot iterator,
    # which would yield nothing if the parametrization were iterated a second time.
    @classmethod
    @mark.parametrize(["aggregator", "matrix"], list(zip(_aggregators_1, _matrices_1)))
    def test_expected_structure_property(cls, aggregator: Constant, matrix: Tensor):
        cls._assert_expected_structure_property(aggregator, matrix)
def test_representations():
    """Check the `repr` and `str` output of a `Constant` built from a two-element CPU tensor."""
    weights = torch.tensor([1.0, 2.0], device="cpu")
    aggregator = Constant(weights=weights)

    expected_repr = "Constant(weights=tensor([1., 2.]))"
    expected_str = "Constant([1., 2.])"

    assert repr(aggregator) == expected_repr
    assert str(aggregator) == expected_str