-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathtest_upgrad.py
More file actions
69 lines (51 loc) · 2.44 KB
/
test_upgrad.py
File metadata and controls
69 lines (51 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
from pytest import mark
from torch import Tensor
from tests.utils.tensors import ones_
from torchjd.aggregation import UPGrad
from ._asserts import (
assert_expected_structure,
assert_linear_under_scaling,
assert_non_conflicting,
assert_non_differentiable,
assert_permutation_invariant,
assert_strongly_stationary,
)
from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices
def _upgrad_pairs(matrices):
    """Return ``(UPGrad(), matrix)`` tuples, one fresh aggregator per matrix."""
    return [(UPGrad(), m) for m in matrices]


# Parametrization data: every case gets its own default-constructed UPGrad.
scaled_pairs = _upgrad_pairs(scaled_matrices)
typical_pairs = _upgrad_pairs(typical_matrices)
non_strong_pairs = _upgrad_pairs(non_strong_matrices)

# Single case whose input matrix is built with requires_grad=True; consumed by
# the non-differentiability test.
requires_grad_pairs = [(UPGrad(), ones_(3, 5, requires_grad=True))]
@mark.parametrize(("aggregator", "matrix"), [*scaled_pairs, *typical_pairs])
def test_expected_structure(aggregator: UPGrad, matrix: Tensor):
    """Run ``assert_expected_structure`` on both the scaled and the typical matrices."""
    assert_expected_structure(aggregator, matrix)
@mark.parametrize(("aggregator", "matrix"), typical_pairs)
def test_non_conflicting(aggregator: UPGrad, matrix: Tensor):
    """Run ``assert_non_conflicting`` on the typical matrices (atol = rtol = 3e-4)."""
    assert_non_conflicting(aggregator, matrix, atol=3e-4, rtol=3e-4)
@mark.parametrize(("aggregator", "matrix"), typical_pairs)
def test_permutation_invariant(aggregator: UPGrad, matrix: Tensor):
    """Run ``assert_permutation_invariant`` 5 times per typical matrix (tol 4e-7)."""
    assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=4e-7, rtol=4e-7)
@mark.parametrize(("aggregator", "matrix"), typical_pairs)
def test_linear_under_scaling(aggregator: UPGrad, matrix: Tensor):
    """Run ``assert_linear_under_scaling`` 5 times per typical matrix (tol 3e-2)."""
    assert_linear_under_scaling(aggregator, matrix, n_runs=5, atol=3e-2, rtol=3e-2)
@mark.parametrize(("aggregator", "matrix"), non_strong_pairs)
def test_strongly_stationary(aggregator: UPGrad, matrix: Tensor):
    """Run ``assert_strongly_stationary`` on the non-strong matrices (threshold 5e-3)."""
    assert_strongly_stationary(aggregator, matrix, threshold=5e-3)
@mark.parametrize(("aggregator", "matrix"), requires_grad_pairs)
def test_non_differentiable(aggregator: UPGrad, matrix: Tensor):
    """Run ``assert_non_differentiable`` on an input created with requires_grad=True."""
    assert_non_differentiable(aggregator, matrix)
def test_representations():
    """Check ``repr`` and ``str`` of UPGrad, without and with a preference vector."""
    # Case 1: no preference vector.
    no_pref = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog")
    assert repr(no_pref) == (
        "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')"
    )
    assert str(no_pref) == "UPGrad"

    # Case 2: explicit preference vector. It is placed on CPU explicitly;
    # NOTE(review): presumably so the tensor repr carries no device suffix.
    pref = torch.tensor([1.0, 2.0, 3.0], device="cpu")
    with_pref = UPGrad(pref_vector=pref, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog")
    expected_repr = (
        "UPGrad(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, "
        "solver='quadprog')"
    )
    assert repr(with_pref) == expected_repr
    assert str(with_pref) == "UPGrad([1., 2., 3.])"