22from utils .dict_assertions import assert_tensor_dicts_are_close
33from utils .tensors import ones_ , tensor_ , zeros_
44
5- from torchjd .autojac ._transform import AccumulateGrad
5+ from torchjd .autojac ._transform import AccumulateGrad , AccumulateJac
66
77
8- def test_single_accumulation ():
8+ def test_single_grad_accumulation ():
99 """
1010 Tests that the AccumulateGrad transform correctly accumulates gradients in .grad fields when run
1111 once.
@@ -33,7 +33,7 @@ def test_single_accumulation():
3333
3434
3535@mark .parametrize ("iterations" , [1 , 2 , 4 , 10 , 13 ])
36- def test_multiple_accumulation (iterations : int ):
36+ def test_multiple_grad_accumulations (iterations : int ):
3737 """
3838 Tests that the AccumulateGrad transform correctly accumulates gradients in .grad fields when run
3939 `iterations` times.
@@ -63,7 +63,7 @@ def test_multiple_accumulation(iterations: int):
6363 assert_tensor_dicts_are_close (grads , expected_grads )
6464
6565
66- def test_no_requires_grad_fails ():
66+ def test_accumulate_grad_fails_when_no_requires_grad ():
6767 """
6868 Tests that the AccumulateGrad transform raises an error when it tries to populate a .grad of a
6969 tensor that does not require grad.
@@ -79,7 +79,7 @@ def test_no_requires_grad_fails():
7979 accumulate (input )
8080
8181
82- def test_no_leaf_and_no_retains_grad_fails ():
82+ def test_accumulate_grad_fails_when_no_leaf_and_no_retains_grad ():
8383 """
8484 Tests that the AccumulateGrad transform raises an error when it tries to populate a .grad of a
8585 tensor that is not a leaf and that does not retain grad.
@@ -95,11 +95,127 @@ def test_no_leaf_and_no_retains_grad_fails():
9595 accumulate (input )
9696
9797
98- def test_check_keys ():
99- """Tests that the `check_keys` method works correctly."""
98+ def test_accumulate_grad_check_keys ():
99+ """Tests that the `check_keys` method works correctly for AccumulateGrad ."""
100100
101101 key = tensor_ ([1.0 ], requires_grad = True )
102102 accumulate = AccumulateGrad ()
103103
104104 output_keys = accumulate .check_keys ({key })
105105 assert output_keys == set ()
106+
107+
108+ def test_single_jac_accumulation ():
109+ """
110+ Tests that the AccumulateJac transform correctly accumulates jacobians in .jac fields when run
111+ once.
112+ """
113+
114+ key1 = zeros_ ([], requires_grad = True )
115+ key2 = zeros_ ([1 ], requires_grad = True )
116+ key3 = zeros_ ([2 , 3 ], requires_grad = True )
117+ value1 = ones_ ([4 ])
118+ value2 = ones_ ([4 , 1 ])
119+ value3 = ones_ ([4 , 2 , 3 ])
120+ input = {key1 : value1 , key2 : value2 , key3 : value3 }
121+
122+ accumulate = AccumulateJac ()
123+
124+ output = accumulate (input )
125+ expected_output = {}
126+
127+ assert_tensor_dicts_are_close (output , expected_output )
128+
129+ jacs = {key1 : key1 .jac , key2 : key2 .jac , key3 : key3 .jac }
130+ expected_jacs = {key1 : value1 , key2 : value2 , key3 : value3 }
131+
132+ assert_tensor_dicts_are_close (jacs , expected_jacs )
133+
134+
135+ @mark .parametrize ("iterations" , [1 , 2 , 4 , 10 , 13 ])
136+ def test_multiple_jac_accumulations (iterations : int ):
137+ """
138+ Tests that the AccumulateJac transform correctly accumulates jacobians in .jac fields when run
139+ `iterations` times.
140+ """
141+
142+ key1 = zeros_ ([], requires_grad = True )
143+ key2 = zeros_ ([1 ], requires_grad = True )
144+ key3 = zeros_ ([2 , 3 ], requires_grad = True )
145+ value1 = ones_ ([4 ])
146+ value2 = ones_ ([4 , 1 ])
147+ value3 = ones_ ([4 , 2 , 3 ])
148+
149+ accumulate = AccumulateJac ()
150+
151+ for i in range (iterations ):
152+ # Clone values to ensure that we accumulate values that are not ever used afterwards
153+ input = {key1 : value1 .clone (), key2 : value2 .clone (), key3 : value3 .clone ()}
154+ accumulate (input )
155+
156+ jacs = {key1 : key1 .jac , key2 : key2 .jac , key3 : key3 .jac }
157+ expected_jacs = {
158+ key1 : iterations * value1 ,
159+ key2 : iterations * value2 ,
160+ key3 : iterations * value3 ,
161+ }
162+
163+ assert_tensor_dicts_are_close (jacs , expected_jacs )
164+
165+
166+ def test_accumulate_jac_fails_when_no_requires_grad ():
167+ """
168+ Tests that the AccumulateJac transform raises an error when it tries to populate a .jac of a
169+ tensor that does not require grad.
170+ """
171+
172+ key = zeros_ ([1 ], requires_grad = False )
173+ value = ones_ ([4 , 1 ])
174+ input = {key : value }
175+
176+ accumulate = AccumulateJac ()
177+
178+ with raises (ValueError ):
179+ accumulate (input )
180+
181+
182+ def test_accumulate_jac_fails_when_no_leaf_and_no_retains_grad ():
183+ """
184+ Tests that the AccumulateJac transform raises an error when it tries to populate a .jac of a
185+ tensor that is not a leaf and that does not retain grad.
186+ """
187+
188+ key = tensor_ ([1.0 ], requires_grad = True ) * 2
189+ value = ones_ ([4 , 1 ])
190+ input = {key : value }
191+
192+ accumulate = AccumulateJac ()
193+
194+ with raises (ValueError ):
195+ accumulate (input )
196+
197+
198+ def test_accumulate_jac_fails_when_shape_mismatch ():
199+ """
200+ Tests that the AccumulateJac transform raises an error when the jacobian shape does not match
201+ the parameter shape (ignoring the first dimension).
202+ """
203+
204+ key = zeros_ ([2 , 3 ], requires_grad = True )
205+ value = ones_ ([4 , 3 , 2 ]) # Wrong shape: should be [4, 2, 3], not [4, 3, 2]
206+ input = {key : value }
207+
208+ accumulate = AccumulateJac ()
209+
210+ with raises (RuntimeError ):
211+ accumulate (input )
212+
213+
214+ def test_accumulate_jac_check_keys ():
215+ """Tests that the `check_keys` method works correctly for AccumulateJac."""
216+
217+ key = tensor_ ([1.0 ], requires_grad = True )
218+ accumulate = AccumulateJac ()
219+
220+ output_keys = accumulate .check_keys ({key })
221+ assert output_keys == set ()
0 commit comments