Skip to content

Commit 84bd552

Browse files
committed
Move asserts to tests/utils and use them in doc tests
1 parent cff6d8e commit 84bd552

7 files changed

Lines changed: 13 additions & 15 deletions

File tree

tests/doc/test_backward.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
obtained `.jac` field.
44
"""
55

6-
from torch.testing import assert_close
6+
from utils.asserts import assert_jac_close
77

88

99
def test_backward():
@@ -18,4 +18,4 @@ def test_backward():
1818

1919
backward([y1, y2])
2020

21-
assert_close(param.jac, torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04)
21+
assert_jac_close(param, torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04)

tests/doc/test_jac_to_grad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
the obtained `.grad` field.
44
"""
55

6-
from torch.testing import assert_close
6+
from utils.asserts import assert_grad_close
77

88

99
def test_jac_to_grad():
@@ -19,4 +19,4 @@ def test_jac_to_grad():
1919
backward([y1, y2]) # param now has a .jac field
2020
jac_to_grad([param], aggregator=UPGrad()) # param now has a .grad field
2121

22-
assert_close(param.grad, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04)
22+
assert_grad_close(param, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04)

tests/unit/autojac/_transform/test_accumulate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pytest import mark, raises
2-
from unit.autojac._asserts import assert_grad_close, assert_jac_close
2+
from utils.asserts import assert_grad_close, assert_jac_close
33
from utils.dict_assertions import assert_tensor_dicts_are_close
44
from utils.tensors import ones_, tensor_, zeros_
55

tests/unit/autojac/test_backward.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import torch
22
from pytest import mark, raises
3+
from utils.asserts import assert_has_jac, assert_has_no_jac, assert_jac_close
34
from utils.tensors import randn_, tensor_
45

56
from torchjd.autojac import backward
67
from torchjd.autojac._backward import _create_transform
78
from torchjd.autojac._transform import OrderedSet
89

9-
from ._asserts import assert_has_jac, assert_has_no_jac, assert_jac_close
10-
1110

1211
def test_check_create_transform():
1312
"""Tests that _create_transform creates a valid Transform."""

tests/unit/autojac/test_jac_to_grad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pytest import mark, raises
2-
from unit.autojac._asserts import assert_grad_close, assert_has_jac, assert_has_no_jac
2+
from utils.asserts import assert_grad_close, assert_has_jac, assert_has_no_jac
33
from utils.tensors import tensor_
44

55
from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad

tests/unit/autojac/test_mtl_backward.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,19 @@
22
from pytest import mark, raises
33
from settings import DTYPE
44
from torch.autograd import grad
5-
from utils.tensors import arange_, rand_, randn_, tensor_
6-
7-
from torchjd.autojac import mtl_backward
8-
from torchjd.autojac._mtl_backward import _create_transform
9-
from torchjd.autojac._transform import OrderedSet
10-
11-
from ._asserts import (
5+
from utils.asserts import (
126
assert_grad_close,
137
assert_has_grad,
148
assert_has_jac,
159
assert_has_no_grad,
1610
assert_has_no_jac,
1711
assert_jac_close,
1812
)
13+
from utils.tensors import arange_, rand_, randn_, tensor_
14+
15+
from torchjd.autojac import mtl_backward
16+
from torchjd.autojac._mtl_backward import _create_transform
17+
from torchjd.autojac._transform import OrderedSet
1918

2019

2120
def test_check_create_transform():

0 commit comments

Comments
 (0)