Skip to content

Commit 5cf8c1c

Browse files
committed
Ad jac_to_grad tests
1 parent c353713 commit 5cf8c1c

1 file changed

Lines changed: 103 additions & 0 deletions

File tree

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from pytest import mark, raises
2+
from unit.autojac._asserts import assert_grad_close, assert_has_jac, assert_has_no_jac
3+
from utils.tensors import tensor_
4+
5+
from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad
6+
from torchjd.autojac._jac_to_grad import jac_to_grad
7+
8+
9+
@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()])
10+
def test_various_aggregators(aggregator: Aggregator):
11+
"""Tests that jac_to_grad works for various aggregators."""
12+
13+
t1 = tensor_(1.0, requires_grad=True)
14+
t2 = tensor_([2.0, 3.0], requires_grad=True)
15+
jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
16+
t1.__setattr__("jac", jac[:, 0])
17+
t2.__setattr__("jac", jac[:, 1:])
18+
expected_grad = aggregator(jac)
19+
g1 = expected_grad[0]
20+
g2 = expected_grad[1:]
21+
22+
jac_to_grad([t1, t2], aggregator)
23+
24+
assert_grad_close(t1, g1)
25+
assert_grad_close(t2, g2)
26+
27+
28+
def test_single_tensor():
29+
"""Tests that jac_to_grad works when a single tensor is provided."""
30+
31+
aggregator = UPGrad()
32+
t = tensor_([2.0, 3.0, 4.0], requires_grad=True)
33+
jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
34+
t.__setattr__("jac", jac)
35+
g = aggregator(jac)
36+
37+
jac_to_grad([t], aggregator)
38+
39+
assert_grad_close(t, g)
40+
41+
42+
def test_no_jac_field():
43+
"""Tests that jac_to_grad fails when a tensor does not have a jac field."""
44+
45+
aggregator = UPGrad()
46+
t1 = tensor_(1.0, requires_grad=True)
47+
t2 = tensor_([2.0, 3.0], requires_grad=True)
48+
jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
49+
t2.__setattr__("jac", jac[:, 1:])
50+
51+
with raises(ValueError):
52+
jac_to_grad([t1, t2], aggregator)
53+
54+
55+
def test_no_requires_grad():
56+
"""Tests that jac_to_grad fails when a tensor does not require grad."""
57+
58+
aggregator = UPGrad()
59+
t1 = tensor_(1.0, requires_grad=True)
60+
t2 = tensor_([2.0, 3.0], requires_grad=False)
61+
jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
62+
t1.__setattr__("jac", jac[:, 0])
63+
t2.__setattr__("jac", jac[:, 1:])
64+
65+
with raises(ValueError):
66+
jac_to_grad([t1, t2], aggregator)
67+
68+
69+
def test_row_mismatch():
70+
"""Tests that jac_to_grad fails when the number of rows of the .jac is not constant."""
71+
72+
aggregator = UPGrad()
73+
t1 = tensor_(1.0, requires_grad=True)
74+
t2 = tensor_([2.0, 3.0], requires_grad=True)
75+
t1.__setattr__("jac", tensor_([5.0, 6.0, 7.0])) # 3 rows
76+
t2.__setattr__("jac", tensor_([[1.0, 2.0], [3.0, 4.0]])) # 2 rows
77+
78+
with raises(ValueError):
79+
jac_to_grad([t1, t2], aggregator)
80+
81+
82+
def test_no_tensors():
83+
"""Tests that jac_to_grad correctly does nothing when an empty list of tensors is provided."""
84+
85+
jac_to_grad([], aggregator=UPGrad())
86+
87+
88+
@mark.parametrize("retain_jac", [True, False])
89+
def test_jacs_are_freed(retain_jac: bool):
90+
"""Tests that jac_to_grad frees the jac fields if an only if retain_jac is False."""
91+
92+
aggregator = UPGrad()
93+
t1 = tensor_(1.0, requires_grad=True)
94+
t2 = tensor_([2.0, 3.0], requires_grad=True)
95+
jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
96+
t1.__setattr__("jac", jac[:, 0])
97+
t2.__setattr__("jac", jac[:, 1:])
98+
99+
jac_to_grad([t1, t2], aggregator, retain_jac=retain_jac)
100+
101+
check = assert_has_jac if retain_jac else assert_has_no_jac
102+
check(t1)
103+
check(t2)

0 commit comments

Comments
 (0)