Skip to content

Commit 382e9d3

Browse files
committed
Fix tests
- Remove test_aggregate.py
- Update test_accumulate.py and test_interactions.py to test AccumulateGrad instead of Accumulate
- Fix tests in test_backward.py and test_mtl_backward.py to match the new interface: check the .jac field instead of the .grad field
- Use _asserts.py for helper functions common to backward.py and mtl_backward.py
1 parent 9e855e2 commit 382e9d3

6 files changed

Lines changed: 187 additions & 317 deletions

File tree

tests/unit/autojac/_asserts.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from typing import cast
2+
3+
import torch
4+
from torch.testing import assert_close
5+
6+
from torchjd.utils._tensor_with_jac import TensorWithJac
7+
8+
9+
def assert_has_jac(t: torch.Tensor) -> None:
10+
assert hasattr(t, "jac")
11+
t_ = cast(TensorWithJac, t)
12+
assert t_.jac is not None and t_.jac.shape[1:] == t_.shape
13+
14+
15+
def assert_has_no_jac(t: torch.Tensor) -> None:
16+
assert not hasattr(t, "jac")
17+
18+
19+
def assert_jac_close(t: torch.Tensor, expected_jac: torch.Tensor) -> None:
20+
assert hasattr(t, "jac")
21+
t_ = cast(TensorWithJac, t)
22+
assert_close(t_.jac, expected_jac)
23+
24+
25+
def assert_has_grad(t: torch.Tensor) -> None:
26+
assert (t.grad is not None) and (t.shape == t.grad.shape)
27+
28+
29+
def assert_has_no_grad(t: torch.Tensor) -> None:
30+
assert t.grad is None
31+
32+
33+
def assert_grad_close(t: torch.Tensor, expected_grad: torch.Tensor) -> None:
34+
assert t.grad is not None
35+
assert_close(t.grad, expected_grad)

tests/unit/autojac/_transform/test_accumulate.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
from utils.dict_assertions import assert_tensor_dicts_are_close
33
from utils.tensors import ones_, tensor_, zeros_
44

5-
from torchjd.autojac._transform import Accumulate
5+
from torchjd.autojac._transform import AccumulateGrad
66

77

88
def test_single_accumulation():
99
"""
10-
Tests that the Accumulate transform correctly accumulates gradients in .grad fields when run
10+
Tests that the AccumulateGrad transform correctly accumulates gradients in .grad fields when run
1111
once.
1212
"""
1313

@@ -19,7 +19,7 @@ def test_single_accumulation():
1919
value3 = ones_([2, 3])
2020
input = {key1: value1, key2: value2, key3: value3}
2121

22-
accumulate = Accumulate()
22+
accumulate = AccumulateGrad()
2323

2424
output = accumulate(input)
2525
expected_output = {}
@@ -35,7 +35,7 @@ def test_single_accumulation():
3535
@mark.parametrize("iterations", [1, 2, 4, 10, 13])
3636
def test_multiple_accumulation(iterations: int):
3737
"""
38-
Tests that the Accumulate transform correctly accumulates gradients in .grad fields when run
38+
Tests that the AccumulateGrad transform correctly accumulates gradients in .grad fields when run
3939
`iterations` times.
4040
"""
4141

@@ -46,7 +46,7 @@ def test_multiple_accumulation(iterations: int):
4646
value2 = ones_([1])
4747
value3 = ones_([2, 3])
4848

49-
accumulate = Accumulate()
49+
accumulate = AccumulateGrad()
5050

5151
for i in range(iterations):
5252
# Clone values to ensure that we accumulate values that are not ever used afterwards
@@ -65,31 +65,31 @@ def test_multiple_accumulation(iterations: int):
6565

6666
def test_no_requires_grad_fails():
6767
"""
68-
Tests that the Accumulate transform raises an error when it tries to populate a .grad of a
68+
Tests that the AccumulateGrad transform raises an error when it tries to populate a .grad of a
6969
tensor that does not require grad.
7070
"""
7171

7272
key = zeros_([1], requires_grad=False)
7373
value = ones_([1])
7474
input = {key: value}
7575

76-
accumulate = Accumulate()
76+
accumulate = AccumulateGrad()
7777

7878
with raises(ValueError):
7979
accumulate(input)
8080

8181

8282
def test_no_leaf_and_no_retains_grad_fails():
8383
"""
84-
Tests that the Accumulate transform raises an error when it tries to populate a .grad of a
84+
Tests that the AccumulateGrad transform raises an error when it tries to populate a .grad of a
8585
tensor that is not a leaf and that does not retain grad.
8686
"""
8787

8888
key = tensor_([1.0], requires_grad=True) * 2
8989
value = ones_([1])
9090
input = {key: value}
9191

92-
accumulate = Accumulate()
92+
accumulate = AccumulateGrad()
9393

9494
with raises(ValueError):
9595
accumulate(input)
@@ -99,7 +99,7 @@ def test_check_keys():
9999
"""Tests that the `check_keys` method works correctly."""
100100

101101
key = tensor_([1.0], requires_grad=True)
102-
accumulate = Accumulate()
102+
accumulate = AccumulateGrad()
103103

104104
output_keys = accumulate.check_keys({key})
105105
assert output_keys == set()

tests/unit/autojac/_transform/test_aggregate.py

Lines changed: 0 additions & 155 deletions
This file was deleted.

tests/unit/autojac/_transform/test_interactions.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from utils.tensors import tensor_, zeros_
66

77
from torchjd.autojac._transform import (
8-
Accumulate,
8+
AccumulateGrad,
99
Conjunction,
1010
Diagonalize,
1111
Grad,
@@ -186,18 +186,18 @@ def test_conjunction_is_associative():
186186

187187
def test_conjunction_accumulate_select():
188188
"""
189-
Tests that it is possible to conjunct an Accumulate and a Select in this order.
190-
It is not trivial since the type of the TensorDict returned by the first transform (Accumulate)
191-
is EmptyDict, which is not the type that the conjunction should return (Gradients), but a
192-
subclass of it.
189+
Tests that it is possible to conjunct an AccumulateGrad and a Select in this order.
190+
It is not trivial since the type of the TensorDict returned by the first transform
191+
(AccumulateGrad) is EmptyDict, which is not the type that the conjunction should return
192+
(Gradients), but a subclass of it.
193193
"""
194194

195195
key = tensor_([1.0, 2.0, 3.0], requires_grad=True)
196196
value = torch.ones_like(key)
197197
input = {key: value}
198198

199199
select = Select(set())
200-
accumulate = Accumulate()
200+
accumulate = AccumulateGrad()
201201
conjunction = accumulate | select
202202

203203
output = conjunction(input)

0 commit comments

Comments
 (0)