-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathtest_accumulate.py
More file actions
104 lines (74 loc) · 2.76 KB
/
test_accumulate.py
File metadata and controls
104 lines (74 loc) · 2.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from pytest import mark, raises
from tests.utils.dict_assertions import assert_tensor_dicts_are_close
from tests.utils.tensors import ones_, tensor_, zeros_
from torchjd.autojac._transform import Accumulate
def test_single_accumulation():
    """
    Checks a single call of the Accumulate transform: the returned dict must be empty and the
    .grad field of every key tensor must be close to the value that was mapped to it.
    """

    shapes = ([], [1], [2, 3])
    params = [zeros_(shape, requires_grad=True) for shape in shapes]
    gradients = [ones_(shape) for shape in shapes]
    gradient_by_param = dict(zip(params, gradients))

    transform = Accumulate()
    result = transform(gradient_by_param)
    assert_tensor_dicts_are_close(result, {})

    # The .grad fields are read only after the transform has been applied.
    accumulated = {param: param.grad for param in params}
    expected = dict(zip(params, gradients))
    assert_tensor_dicts_are_close(accumulated, expected)
@mark.parametrize("iterations", [1, 2, 4, 10, 13])
def test_multiple_accumulation(iterations: int):
    """
    Checks that applying the same Accumulate transform `iterations` times to the same input
    leaves every key's .grad field close to `iterations` times the value mapped to that key.
    """

    shapes = ([], [1], [2, 3])
    params = [zeros_(shape, requires_grad=True) for shape in shapes]
    gradients = [ones_(shape) for shape in shapes]
    gradient_by_param = dict(zip(params, gradients))

    transform = Accumulate()
    for _ in range(iterations):
        transform(gradient_by_param)

    accumulated = {param: param.grad for param in params}
    expected = {param: iterations * gradient for param, gradient in zip(params, gradients)}
    assert_tensor_dicts_are_close(accumulated, expected)
def test_no_requires_grad_fails():
    """
    Checks that applying the Accumulate transform to a key tensor created with
    requires_grad=False raises a ValueError.
    """

    frozen_key = zeros_([1], requires_grad=False)
    gradient_by_key = {frozen_key: ones_([1])}
    transform = Accumulate()

    # Only the application of the transform is expected to raise, not its construction.
    with raises(ValueError):
        transform(gradient_by_key)
def test_no_leaf_and_no_retains_grad_fails():
    """
    Checks that applying the Accumulate transform raises a ValueError when the key is a tensor
    obtained from an operation on another tensor (so not a leaf) and retaining its grad was
    never requested.
    """

    leaf = tensor_([1.0], requires_grad=True)
    non_leaf_key = leaf * 2  # Result of an operation on `leaf`, hence not a leaf itself.
    gradient_by_key = {non_leaf_key: ones_([1])}
    transform = Accumulate()

    # Only the application of the transform is expected to raise, not its construction.
    with raises(ValueError):
        transform(gradient_by_key)
def test_check_keys():
    """Checks that `check_keys` maps a single-key input set to an empty output set."""

    tensor_key = tensor_([1.0], requires_grad=True)
    transform = Accumulate()

    assert transform.check_keys({tensor_key}) == set()