-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathtest_dualproj.py
More file actions
41 lines (34 loc) · 1.05 KB
/
test_dualproj.py
File metadata and controls
41 lines (34 loc) · 1.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from pytest import mark
from torchjd.aggregation import DualProj
from ._property_testers import (
ExpectedStructureProperty,
NonConflictingProperty,
PermutationInvarianceProperty,
StrongStationarityProperty,
)
@mark.parametrize("aggregator", [DualProj()])
class TestDualProj(
    ExpectedStructureProperty,
    NonConflictingProperty,
    PermutationInvarianceProperty,
    StrongStationarityProperty,
):
    """Run the inherited property test suites on a default-constructed ``DualProj``.

    The class defines no tests of its own: every test comes from the property
    base classes, and the ``aggregator`` argument is supplied by the
    parametrization above.
    """
def test_representations():
    """Check ``repr`` and ``str`` of ``DualProj`` with and without a preference vector."""
    # Hyper-parameters shared by both aggregators under test.
    common_kwargs = dict(norm_eps=0.0001, reg_eps=0.0001, solver="quadprog")

    # Case 1: no preference vector.
    without_pref = DualProj(pref_vector=None, **common_kwargs)
    expected_repr = (
        "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')"
    )
    assert repr(without_pref) == expected_repr
    assert str(without_pref) == "DualProj"

    # Case 2: explicit preference vector on the CPU.
    preference = torch.tensor([1.0, 2.0, 3.0], device="cpu")
    with_pref = DualProj(pref_vector=preference, **common_kwargs)
    expected_repr = (
        "DualProj(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, "
        "solver='quadprog')"
    )
    assert repr(with_pref) == expected_repr
    assert str(with_pref) == "DualProj([1., 2., 3.])"