11from pytest import mark , raises
2+ from unit .autojac ._asserts import assert_grad_close , assert_jac_close
23from utils .dict_assertions import assert_tensor_dicts_are_close
34from utils .tensors import ones_ , tensor_ , zeros_
45
@@ -11,25 +12,18 @@ def test_single_grad_accumulation():
1112 once.
1213 """
1314
14- key1 = zeros_ ([], requires_grad = True )
15- key2 = zeros_ ([1 ], requires_grad = True )
16- key3 = zeros_ ([2 , 3 ], requires_grad = True )
17- value1 = ones_ ([])
18- value2 = ones_ ([1 ])
19- value3 = ones_ ([2 , 3 ])
20- input = {key1 : value1 , key2 : value2 , key3 : value3 }
15+ shapes = [[], [1 ], [2 , 3 ]]
16+ keys = [zeros_ (shape , requires_grad = True ) for shape in shapes ]
17+ values = [ones_ (shape ) for shape in shapes ]
18+ input = dict (zip (keys , values ))
2119
2220 accumulate = AccumulateGrad ()
2321
2422 output = accumulate (input )
25- expected_output = {}
23+ assert_tensor_dicts_are_close ( output , {})
2624
27- assert_tensor_dicts_are_close (output , expected_output )
28-
29- grads = {key1 : key1 .grad , key2 : key2 .grad , key3 : key3 .grad }
30- expected_grads = {key1 : value1 , key2 : value2 , key3 : value3 }
31-
32- assert_tensor_dicts_are_close (grads , expected_grads )
25+ for key , value in zip (keys , values ):
26+ assert_grad_close (key , value )
3327
3428
3529@mark .parametrize ("iterations" , [1 , 2 , 4 , 10 , 13 ])
@@ -39,28 +33,18 @@ def test_multiple_grad_accumulations(iterations: int):
3933 `iterations` times.
4034 """
4135
42- key1 = zeros_ ([], requires_grad = True )
43- key2 = zeros_ ([1 ], requires_grad = True )
44- key3 = zeros_ ([2 , 3 ], requires_grad = True )
45- value1 = ones_ ([])
46- value2 = ones_ ([1 ])
47- value3 = ones_ ([2 , 3 ])
48-
36+ shapes = [[], [1 ], [2 , 3 ]]
37+ keys = [zeros_ (shape , requires_grad = True ) for shape in shapes ]
38+ values = [ones_ (shape ) for shape in shapes ]
4939 accumulate = AccumulateGrad ()
5040
5141 for i in range (iterations ):
5242 # Clone values to ensure that we accumulate values that are not ever used afterwards
53- input = {key1 : value1 .clone (), key2 : value2 . clone (), key3 : value3 . clone ( )}
43+ input = {key : value .clone () for key , value in zip ( keys , values )}
5444 accumulate (input )
5545
56- grads = {key1 : key1 .grad , key2 : key2 .grad , key3 : key3 .grad }
57- expected_grads = {
58- key1 : iterations * value1 ,
59- key2 : iterations * value2 ,
60- key3 : iterations * value3 ,
61- }
62-
63- assert_tensor_dicts_are_close (grads , expected_grads )
46+ for key , value in zip (keys , values ):
47+ assert_grad_close (key , iterations * value )
6448
6549
6650def test_accumulate_grad_fails_when_no_requires_grad ():
@@ -111,25 +95,18 @@ def test_single_jac_accumulation():
11195 once.
11296 """
11397
114- key1 = zeros_ ([], requires_grad = True )
115- key2 = zeros_ ([1 ], requires_grad = True )
116- key3 = zeros_ ([2 , 3 ], requires_grad = True )
117- value1 = ones_ ([4 ])
118- value2 = ones_ ([4 , 1 ])
119- value3 = ones_ ([4 , 2 , 3 ])
120- input = {key1 : value1 , key2 : value2 , key3 : value3 }
98+ shapes = [[], [1 ], [2 , 3 ]]
99+ keys = [zeros_ (shape , requires_grad = True ) for shape in shapes ]
100+ values = [ones_ ([4 ] + shape ) for shape in shapes ]
101+ input = dict (zip (keys , values ))
121102
122103 accumulate = AccumulateJac ()
123104
124105 output = accumulate (input )
125- expected_output = {}
106+ assert_tensor_dicts_are_close ( output , {})
126107
127- assert_tensor_dicts_are_close (output , expected_output )
128-
129- jacs = {key1 : key1 .jac , key2 : key2 .jac , key3 : key3 .jac }
130- expected_jacs = {key1 : value1 , key2 : value2 , key3 : value3 }
131-
132- assert_tensor_dicts_are_close (jacs , expected_jacs )
108+ for key , value in zip (keys , values ):
109+ assert_jac_close (key , value )
133110
134111
135112@mark .parametrize ("iterations" , [1 , 2 , 4 , 10 , 13 ])
@@ -139,28 +116,19 @@ def test_multiple_jac_accumulations(iterations: int):
139116 `iterations` times.
140117 """
141118
142- key1 = zeros_ ([], requires_grad = True )
143- key2 = zeros_ ([1 ], requires_grad = True )
144- key3 = zeros_ ([2 , 3 ], requires_grad = True )
145- value1 = ones_ ([4 ])
146- value2 = ones_ ([4 , 1 ])
147- value3 = ones_ ([4 , 2 , 3 ])
119+ shapes = [[], [1 ], [2 , 3 ]]
120+ keys = [zeros_ (shape , requires_grad = True ) for shape in shapes ]
121+ values = [ones_ ([4 ] + shape ) for shape in shapes ]
148122
149123 accumulate = AccumulateJac ()
150124
151125 for i in range (iterations ):
152126 # Clone values to ensure that we accumulate values that are not ever used afterwards
153- input = {key1 : value1 .clone (), key2 : value2 . clone (), key3 : value3 . clone ( )}
127+ input = {key : value .clone () for key , value in zip ( keys , values )}
154128 accumulate (input )
155129
156- jacs = {key1 : key1 .jac , key2 : key2 .jac , key3 : key3 .jac }
157- expected_jacs = {
158- key1 : iterations * value1 ,
159- key2 : iterations * value2 ,
160- key3 : iterations * value3 ,
161- }
162-
163- assert_tensor_dicts_are_close (jacs , expected_jacs )
130+ for key , value in zip (keys , values ):
131+ assert_jac_close (key , iterations * value )
164132
165133
166134def test_accumulate_jac_fails_when_no_requires_grad ():
0 commit comments