Skip to content

Commit 1876351

Browse files
ValerianReyclaude
andcommitted
Add unit tests for AccumulateJac
Rename existing AccumulateGrad tests to make naming explicit and add corresponding tests for AccumulateJac, including a shape mismatch test. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 7e95c37 commit 1876351

1 file changed

Lines changed: 123 additions & 7 deletions

File tree

tests/unit/autojac/_transform/test_accumulate.py

Lines changed: 123 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from utils.dict_assertions import assert_tensor_dicts_are_close
33
from utils.tensors import ones_, tensor_, zeros_
44

5-
from torchjd.autojac._transform import AccumulateGrad
5+
from torchjd.autojac._transform import AccumulateGrad, AccumulateJac
66

77

8-
def test_single_accumulation():
8+
def test_single_grad_accumulation():
99
"""
1010
Tests that the AccumulateGrad transform correctly accumulates gradients in .grad fields when run
1111
once.
@@ -33,7 +33,7 @@ def test_single_accumulation():
3333

3434

3535
@mark.parametrize("iterations", [1, 2, 4, 10, 13])
36-
def test_multiple_accumulation(iterations: int):
36+
def test_multiple_grad_accumulations(iterations: int):
3737
"""
3838
Tests that the AccumulateGrad transform correctly accumulates gradients in .grad fields when run
3939
`iterations` times.
@@ -63,7 +63,7 @@ def test_multiple_accumulation(iterations: int):
6363
assert_tensor_dicts_are_close(grads, expected_grads)
6464

6565

66-
def test_no_requires_grad_fails():
66+
def test_accumulate_grad_fails_when_no_requires_grad():
6767
"""
6868
Tests that the AccumulateGrad transform raises an error when it tries to populate a .grad of a
6969
tensor that does not require grad.
@@ -79,7 +79,7 @@ def test_no_requires_grad_fails():
7979
accumulate(input)
8080

8181

82-
def test_no_leaf_and_no_retains_grad_fails():
82+
def test_accumulate_grad_fails_when_no_leaf_and_no_retains_grad():
8383
"""
8484
Tests that the AccumulateGrad transform raises an error when it tries to populate a .grad of a
8585
tensor that is not a leaf and that does not retain grad.
@@ -95,11 +95,127 @@ def test_no_leaf_and_no_retains_grad_fails():
9595
accumulate(input)
9696

9797

98-
def test_check_keys():
99-
"""Tests that the `check_keys` method works correctly."""
98+
def test_accumulate_grad_check_keys():
99+
"""Tests that the `check_keys` method works correctly for AccumulateGrad."""
100100

101101
key = tensor_([1.0], requires_grad=True)
102102
accumulate = AccumulateGrad()
103103

104104
output_keys = accumulate.check_keys({key})
105105
assert output_keys == set()
106+
107+
108+
def test_single_jac_accumulation():
109+
"""
110+
Tests that the AccumulateJac transform correctly accumulates jacobians in .jac fields when run
111+
once.
112+
"""
113+
114+
key1 = zeros_([], requires_grad=True)
115+
key2 = zeros_([1], requires_grad=True)
116+
key3 = zeros_([2, 3], requires_grad=True)
117+
value1 = ones_([4])
118+
value2 = ones_([4, 1])
119+
value3 = ones_([4, 2, 3])
120+
input = {key1: value1, key2: value2, key3: value3}
121+
122+
accumulate = AccumulateJac()
123+
124+
output = accumulate(input)
125+
expected_output = {}
126+
127+
assert_tensor_dicts_are_close(output, expected_output)
128+
129+
jacs = {key1: key1.jac, key2: key2.jac, key3: key3.jac}
130+
expected_jacs = {key1: value1, key2: value2, key3: value3}
131+
132+
assert_tensor_dicts_are_close(jacs, expected_jacs)
133+
134+
135+
@mark.parametrize("iterations", [1, 2, 4, 10, 13])
136+
def test_multiple_jac_accumulations(iterations: int):
137+
"""
138+
Tests that the AccumulateJac transform correctly accumulates jacobians in .jac fields when run
139+
`iterations` times.
140+
"""
141+
142+
key1 = zeros_([], requires_grad=True)
143+
key2 = zeros_([1], requires_grad=True)
144+
key3 = zeros_([2, 3], requires_grad=True)
145+
value1 = ones_([4])
146+
value2 = ones_([4, 1])
147+
value3 = ones_([4, 2, 3])
148+
149+
accumulate = AccumulateJac()
150+
151+
for i in range(iterations):
152+
# Clone values to ensure that we accumulate values that are not ever used afterwards
153+
input = {key1: value1.clone(), key2: value2.clone(), key3: value3.clone()}
154+
accumulate(input)
155+
156+
jacs = {key1: key1.jac, key2: key2.jac, key3: key3.jac}
157+
expected_jacs = {
158+
key1: iterations * value1,
159+
key2: iterations * value2,
160+
key3: iterations * value3,
161+
}
162+
163+
assert_tensor_dicts_are_close(jacs, expected_jacs)
164+
165+
166+
def test_accumulate_jac_fails_when_no_requires_grad():
167+
"""
168+
Tests that the AccumulateJac transform raises an error when it tries to populate a .jac of a
169+
tensor that does not require grad.
170+
"""
171+
172+
key = zeros_([1], requires_grad=False)
173+
value = ones_([4, 1])
174+
input = {key: value}
175+
176+
accumulate = AccumulateJac()
177+
178+
with raises(ValueError):
179+
accumulate(input)
180+
181+
182+
def test_accumulate_jac_fails_when_no_leaf_and_no_retains_grad():
183+
"""
184+
Tests that the AccumulateJac transform raises an error when it tries to populate a .jac of a
185+
tensor that is not a leaf and that does not retain grad.
186+
"""
187+
188+
key = tensor_([1.0], requires_grad=True) * 2
189+
value = ones_([4, 1])
190+
input = {key: value}
191+
192+
accumulate = AccumulateJac()
193+
194+
with raises(ValueError):
195+
accumulate(input)
196+
197+
198+
def test_accumulate_jac_fails_when_shape_mismatch():
199+
"""
200+
Tests that the AccumulateJac transform raises an error when the jacobian shape does not match
201+
the parameter shape (ignoring the first dimension).
202+
"""
203+
204+
key = zeros_([2, 3], requires_grad=True)
205+
value = ones_([4, 3, 2]) # Wrong shape: should be [4, 2, 3], not [4, 3, 2]
206+
input = {key: value}
207+
208+
accumulate = AccumulateJac()
209+
210+
with raises(RuntimeError):
211+
accumulate(input)
212+
213+
214+
def test_accumulate_jac_check_keys():
215+
"""Tests that the `check_keys` method works correctly for AccumulateJac."""
216+
217+
key = tensor_([1.0], requires_grad=True)
218+
accumulate = AccumulateJac()
219+
220+
output_keys = accumulate.check_keys({key})
221+
assert output_keys == set()

0 commit comments

Comments
 (0)