Skip to content

Commit 936d1fa

Browse files
authored
test(autojac): Fix wrong device in a test (#495)
1 parent 85bf67e commit 936d1fa

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

tests/unit/autojac/test_mtl_backward.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pytest import mark, raises
33
from torch.autograd import grad
44
from torch.testing import assert_close
5-
from utils.tensors import rand_, randn_, tensor_
5+
from utils.tensors import arange_, rand_, randn_, tensor_
66

77
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad
88
from torchjd.autojac import mtl_backward
@@ -345,7 +345,7 @@ def test_various_feature_lists(shapes: list[tuple[int]]):
345345
"""Tests that mtl_backward works correctly with various kinds of feature lists."""
346346

347347
p0 = tensor_([1.0, 2.0], requires_grad=True)
348-
p1 = torch.arange(len(shapes), dtype=torch.float32, requires_grad=True)
348+
p1 = arange_(len(shapes), dtype=torch.float32, requires_grad=True)
349349
p2 = tensor_(5.0, requires_grad=True)
350350

351351
features = [rand_(shape) @ p0 for shape in shapes]

tests/utils/tensors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# for code written in the tests, while not affecting code written in src (what
1212
# torch.set_default_device or what a too large `with torch.device(DEVICE)` context would have done).
1313

14+
arange_ = partial(torch.arange, device=DEVICE)
1415
empty_ = partial(torch.empty, device=DEVICE)
1516
eye_ = partial(torch.eye, device=DEVICE)
1617
ones_ = partial(torch.ones, device=DEVICE)

0 commit comments

Comments
 (0)