-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathtest_init.py
More file actions
76 lines (53 loc) · 1.98 KB
/
test_init.py
File metadata and controls
76 lines (53 loc) · 1.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from pytest import raises
from tests.utils.dict_assertions import assert_tensor_dicts_are_close
from tests.utils.tensors import tensor_
from torchjd.autojac._transform import Init, RequirementError
def test_single_input():
"""
Tests that when there is a single key to initialize, the Init transform creates a TensorDict
whose value is a tensor full of ones, of the same shape as its key.
"""
key = tensor_([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
input = {}
init = Init({key})
output = init(input)
expected_output = {key: tensor_([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])}
assert_tensor_dicts_are_close(output, expected_output)
def test_multiple_inputs():
"""
Tests that when there are several keys to initialize, the Init transform creates a TensorDict
whose values are tensors full of ones, of the same shape as their corresponding keys.
"""
key1 = tensor_([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
key2 = tensor_([1.0, 3.0, 5.0])
input = {}
init = Init({key1, key2})
output = init(input)
expected = {
key1: tensor_([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]),
key2: tensor_([1.0, 1.0, 1.0]),
}
assert_tensor_dicts_are_close(output, expected)
def test_conjunction_of_inits_is_init():
"""
Tests that the conjunction of 2 Init transforms is equivalent to a single Init transform with
multiple keys.
"""
x1 = tensor_(5.0)
x2 = tensor_(6.0)
input = {}
init1 = Init({x1})
init2 = Init({x2})
conjunction_of_inits = init1 | init2
init = Init({x1, x2})
output = conjunction_of_inits(input)
expected_output = init(input)
assert_tensor_dicts_are_close(output, expected_output)
def test_check_keys():
"""Tests that the `check_keys` method works correctly: the input_keys should be empty."""
key = tensor_([1.0])
init = Init({key})
output_keys = init.check_keys(set())
assert output_keys == {key}
with raises(RequirementError):
init.check_keys({key})