Skip to content

Commit 4f24e39

Browse files
committed
Refactor accumulate tests to use loops and assert helpers
1 parent 1876351 commit 4f24e39

1 file changed

Lines changed: 27 additions & 59 deletions

File tree

tests/unit/autojac/_transform/test_accumulate.py

Lines changed: 27 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pytest import mark, raises
2+
from unit.autojac._asserts import assert_grad_close, assert_jac_close
23
from utils.dict_assertions import assert_tensor_dicts_are_close
34
from utils.tensors import ones_, tensor_, zeros_
45

@@ -11,25 +12,18 @@ def test_single_grad_accumulation():
1112
once.
1213
"""
1314

14-
key1 = zeros_([], requires_grad=True)
15-
key2 = zeros_([1], requires_grad=True)
16-
key3 = zeros_([2, 3], requires_grad=True)
17-
value1 = ones_([])
18-
value2 = ones_([1])
19-
value3 = ones_([2, 3])
20-
input = {key1: value1, key2: value2, key3: value3}
15+
shapes = [[], [1], [2, 3]]
16+
keys = [zeros_(shape, requires_grad=True) for shape in shapes]
17+
values = [ones_(shape) for shape in shapes]
18+
input = dict(zip(keys, values))
2119

2220
accumulate = AccumulateGrad()
2321

2422
output = accumulate(input)
25-
expected_output = {}
23+
assert_tensor_dicts_are_close(output, {})
2624

27-
assert_tensor_dicts_are_close(output, expected_output)
28-
29-
grads = {key1: key1.grad, key2: key2.grad, key3: key3.grad}
30-
expected_grads = {key1: value1, key2: value2, key3: value3}
31-
32-
assert_tensor_dicts_are_close(grads, expected_grads)
25+
for key, value in zip(keys, values):
26+
assert_grad_close(key, value)
3327

3428

3529
@mark.parametrize("iterations", [1, 2, 4, 10, 13])
@@ -39,28 +33,18 @@ def test_multiple_grad_accumulations(iterations: int):
3933
`iterations` times.
4034
"""
4135

42-
key1 = zeros_([], requires_grad=True)
43-
key2 = zeros_([1], requires_grad=True)
44-
key3 = zeros_([2, 3], requires_grad=True)
45-
value1 = ones_([])
46-
value2 = ones_([1])
47-
value3 = ones_([2, 3])
48-
36+
shapes = [[], [1], [2, 3]]
37+
keys = [zeros_(shape, requires_grad=True) for shape in shapes]
38+
values = [ones_(shape) for shape in shapes]
4939
accumulate = AccumulateGrad()
5040

5141
for i in range(iterations):
5242
# Clone values to ensure that we accumulate values that are not ever used afterwards
53-
input = {key1: value1.clone(), key2: value2.clone(), key3: value3.clone()}
43+
input = {key: value.clone() for key, value in zip(keys, values)}
5444
accumulate(input)
5545

56-
grads = {key1: key1.grad, key2: key2.grad, key3: key3.grad}
57-
expected_grads = {
58-
key1: iterations * value1,
59-
key2: iterations * value2,
60-
key3: iterations * value3,
61-
}
62-
63-
assert_tensor_dicts_are_close(grads, expected_grads)
46+
for key, value in zip(keys, values):
47+
assert_grad_close(key, iterations * value)
6448

6549

6650
def test_accumulate_grad_fails_when_no_requires_grad():
@@ -111,25 +95,18 @@ def test_single_jac_accumulation():
11195
once.
11296
"""
11397

114-
key1 = zeros_([], requires_grad=True)
115-
key2 = zeros_([1], requires_grad=True)
116-
key3 = zeros_([2, 3], requires_grad=True)
117-
value1 = ones_([4])
118-
value2 = ones_([4, 1])
119-
value3 = ones_([4, 2, 3])
120-
input = {key1: value1, key2: value2, key3: value3}
98+
shapes = [[], [1], [2, 3]]
99+
keys = [zeros_(shape, requires_grad=True) for shape in shapes]
100+
values = [ones_([4] + shape) for shape in shapes]
101+
input = dict(zip(keys, values))
121102

122103
accumulate = AccumulateJac()
123104

124105
output = accumulate(input)
125-
expected_output = {}
106+
assert_tensor_dicts_are_close(output, {})
126107

127-
assert_tensor_dicts_are_close(output, expected_output)
128-
129-
jacs = {key1: key1.jac, key2: key2.jac, key3: key3.jac}
130-
expected_jacs = {key1: value1, key2: value2, key3: value3}
131-
132-
assert_tensor_dicts_are_close(jacs, expected_jacs)
108+
for key, value in zip(keys, values):
109+
assert_jac_close(key, value)
133110

134111

135112
@mark.parametrize("iterations", [1, 2, 4, 10, 13])
@@ -139,28 +116,19 @@ def test_multiple_jac_accumulations(iterations: int):
139116
`iterations` times.
140117
"""
141118

142-
key1 = zeros_([], requires_grad=True)
143-
key2 = zeros_([1], requires_grad=True)
144-
key3 = zeros_([2, 3], requires_grad=True)
145-
value1 = ones_([4])
146-
value2 = ones_([4, 1])
147-
value3 = ones_([4, 2, 3])
119+
shapes = [[], [1], [2, 3]]
120+
keys = [zeros_(shape, requires_grad=True) for shape in shapes]
121+
values = [ones_([4] + shape) for shape in shapes]
148122

149123
accumulate = AccumulateJac()
150124

151125
for i in range(iterations):
152126
# Clone values to ensure that we accumulate values that are not ever used afterwards
153-
input = {key1: value1.clone(), key2: value2.clone(), key3: value3.clone()}
127+
input = {key: value.clone() for key, value in zip(keys, values)}
154128
accumulate(input)
155129

156-
jacs = {key1: key1.jac, key2: key2.jac, key3: key3.jac}
157-
expected_jacs = {
158-
key1: iterations * value1,
159-
key2: iterations * value2,
160-
key3: iterations * value3,
161-
}
162-
163-
assert_tensor_dicts_are_close(jacs, expected_jacs)
130+
for key, value in zip(keys, values):
131+
assert_jac_close(key, iterations * value)
164132

165133

166134
def test_accumulate_jac_fails_when_no_requires_grad():

0 commit comments

Comments
 (0)