Skip to content

Commit 064d3a3

Browse files
authored
test: Fix automatic device (#391)
* Remove torch.set_default_device(DEVICE) from conftest.py * Add curried versions of functions requiring a device in _utils.py * Use curried functions in tests/unit * Fix device of nn.modules in tests * Add explanation about how to test on the right device in CONTRIBUTING.md
1 parent 54d9914 commit 064d3a3

34 files changed

+511
-458
lines changed

CONTRIBUTING.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,30 @@ should create it.
9999
We ask contributors to implement the unit tests necessary to check the correctness of their
100100
implementations. Besides, whenever usage examples are provided, we require the example's code to be
101101
tested in `tests/doc`. We require a very high code coverage for newly introduced sources (~95-100%).
102+
To ensure that the tensors generated during the tests are on the right device, you have to use the
103+
partial functions defined in `tests/unit/_utils.py` to instantiate tensors. For instance, instead of
104+
```python
105+
import torch
106+
a = torch.ones(3, 4)
107+
```
108+
use
109+
```python
110+
from unit._utils import ones_
111+
a = ones_(3, 4)
112+
```
113+
114+
This will automatically call `torch.ones` with `device=unit.conftest.DEVICE`.
115+
If the function you need does not exist yet as a partial function in `_utils.py`, add it.
116+
Lastly, when you create a model or a random generator, you have to move them manually to the right
117+
device (the `DEVICE` defined in `unit.conftest`):
118+
```python
119+
import torch
120+
from torch.nn import Linear
121+
from unit.conftest import DEVICE
122+
123+
model = Linear(3, 4).to(device=DEVICE)
124+
rng = torch.Generator(device=DEVICE)
125+
```
102126
103127
### Coding
104128

tests/unit/_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,22 @@
11
from contextlib import AbstractContextManager
2+
from functools import partial
23
from typing import TypeAlias
34

5+
import torch
6+
from unit.conftest import DEVICE
7+
48
ExceptionContext: TypeAlias = AbstractContextManager[Exception | None]
9+
10+
# Curried calls to torch functions that require a device so that we automatically fix the device
11+
# for code written in the tests, while not affecting code written in src (what
12+
# torch.set_default_device or what a too large `with torch.device(DEVICE)` context would have done).
13+
14+
empty_ = partial(torch.empty, device=DEVICE)
15+
eye_ = partial(torch.eye, device=DEVICE)
16+
ones_ = partial(torch.ones, device=DEVICE)
17+
rand_ = partial(torch.rand, device=DEVICE)
18+
randint_ = partial(torch.randint, device=DEVICE)
19+
randn_ = partial(torch.randn, device=DEVICE)
20+
randperm_ = partial(torch.randperm, device=DEVICE)
21+
tensor_ = partial(torch.tensor, device=DEVICE)
22+
zeros_ = partial(torch.zeros, device=DEVICE)

tests/unit/aggregation/_asserts.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pytest import raises
33
from torch import Tensor
44
from torch.testing import assert_close
5+
from unit._utils import rand_, randperm_
56

67
from torchjd.aggregation import Aggregator
78
from torchjd.aggregation._utils.non_differentiable import NonDifferentiableError
@@ -44,7 +45,7 @@ def assert_permutation_invariant(
4445
"""
4546

4647
def permute_randomly(matrix_: Tensor) -> Tensor:
47-
row_permutation = torch.randperm(matrix_.size(dim=0))
48+
row_permutation = randperm_(matrix_.size(dim=0))
4849
return matrix_[row_permutation]
4950

5051
vector = aggregator(matrix)
@@ -66,10 +67,10 @@ def assert_linear_under_scaling(
6667
"""Tests empirically that a given `Aggregator` satisfies the linear under scaling property."""
6768

6869
for _ in range(n_runs):
69-
c1 = torch.rand(matrix.shape[0], dtype=matrix.dtype)
70-
c2 = torch.rand(matrix.shape[0], dtype=matrix.dtype)
71-
alpha = torch.rand([], dtype=matrix.dtype)
72-
beta = torch.rand([], dtype=matrix.dtype)
70+
c1 = rand_(matrix.shape[0], dtype=matrix.dtype)
71+
c2 = rand_(matrix.shape[0], dtype=matrix.dtype)
72+
alpha = rand_([], dtype=matrix.dtype)
73+
beta = rand_([], dtype=matrix.dtype)
7374

7475
x1 = aggregator(torch.diag(c1) @ matrix)
7576
x2 = aggregator(torch.diag(c2) @ matrix)

tests/unit/aggregation/_inputs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from unit._utils import zeros_
23
from unit.conftest import DEVICE
34

45
from ._matrix_samplers import NonWeakSampler, NormalSampler, StrictlyWeakSampler, StrongSampler
@@ -31,7 +32,7 @@
3132
_rng = torch.Generator(device=DEVICE).manual_seed(0)
3233

3334
matrices = [NormalSampler(m, n, r)(_rng) for m, n, r in _normal_dims]
34-
zero_matrices = [torch.zeros([m, n]) for m, n, _ in _zero_dims]
35+
zero_matrices = [zeros_([m, n]) for m, n, _ in _zero_dims]
3536
strong_matrices = [StrongSampler(m, n, r)(_rng) for m, n, r in _stationarity_dims]
3637
strictly_weak_matrices = [StrictlyWeakSampler(m, n, r)(_rng) for m, n, r in _stationarity_dims]
3738
non_weak_matrices = [NonWeakSampler(m, n, r)(_rng) for m, n, r in _stationarity_dims]

tests/unit/aggregation/_matrix_samplers.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44
from torch import Tensor
55
from torch.nn.functional import normalize
6+
from unit._utils import randint_, randn_, randperm_, zeros_
67

78

89
class MatrixSampler(ABC):
@@ -43,7 +44,7 @@ class NormalSampler(MatrixSampler):
4344
def __call__(self, rng: torch.Generator | None = None) -> Tensor:
4445
U = _sample_orthonormal_matrix(self.m, dtype=self.dtype, rng=rng)
4546
Vt = _sample_orthonormal_matrix(self.n, dtype=self.dtype, rng=rng)
46-
S = torch.diag(torch.abs(torch.randn([self.rank], dtype=self.dtype, generator=rng)))
47+
S = torch.diag(torch.abs(randn_([self.rank], dtype=self.dtype, generator=rng)))
4748
A = U[:, : self.rank] @ S @ Vt[: self.rank, :]
4849
return A
4950

@@ -65,11 +66,11 @@ def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None:
6566
assert 0 < rank <= min(m - 1, n)
6667

6768
def __call__(self, rng: torch.Generator | None = None) -> Tensor:
68-
v = torch.abs(torch.randn([self.m], dtype=self.dtype, generator=rng))
69+
v = torch.abs(randn_([self.m], dtype=self.dtype, generator=rng))
6970
U1 = normalize(v, dim=0).unsqueeze(1)
7071
U2 = _sample_semi_orthonormal_complement(U1, rng=rng)
7172
Vt = _sample_orthonormal_matrix(self.n, dtype=self.dtype, rng=rng)
72-
S = torch.diag(torch.abs(torch.randn([self.rank], dtype=self.dtype, generator=rng)))
73+
S = torch.diag(torch.abs(randn_([self.rank], dtype=self.dtype, generator=rng)))
7374
A = U2[:, : self.rank] @ S @ Vt[: self.rank, :]
7475
return A
7576

@@ -101,18 +102,18 @@ def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None:
101102
assert 0 < rank <= min(m - 1, n)
102103

103104
def __call__(self, rng: torch.Generator | None = None) -> Tensor:
104-
u = torch.abs(torch.randn([self.m], dtype=self.dtype, generator=rng))
105-
split_index = torch.randint(1, self.m, [], generator=rng).item()
106-
shuffled_range = torch.randperm(self.m, generator=rng)
107-
v = torch.zeros(self.m, dtype=self.dtype)
105+
u = torch.abs(randn_([self.m], dtype=self.dtype, generator=rng))
106+
split_index = randint_(1, self.m, [], generator=rng).item()
107+
shuffled_range = randperm_(self.m, generator=rng)
108+
v = zeros_(self.m, dtype=self.dtype)
108109
v[shuffled_range[:split_index]] = normalize(u[shuffled_range[:split_index]], dim=0)
109-
v_prime = torch.zeros(self.m, dtype=self.dtype)
110+
v_prime = zeros_(self.m, dtype=self.dtype)
110111
v_prime[shuffled_range[split_index:]] = normalize(u[shuffled_range[split_index:]], dim=0)
111112
U1 = torch.stack([v, v_prime]).T
112113
U2 = _sample_semi_orthonormal_complement(U1, rng=rng)
113114
U = torch.hstack([U1, U2])
114115
Vt = _sample_orthonormal_matrix(self.n, dtype=self.dtype, rng=rng)
115-
S = torch.diag(torch.abs(torch.randn([self.rank], dtype=self.dtype, generator=rng)))
116+
S = torch.diag(torch.abs(randn_([self.rank], dtype=self.dtype, generator=rng)))
116117
A = U[:, 1 : self.rank + 1] @ S @ Vt[: self.rank, :]
117118
return A
118119

@@ -133,12 +134,12 @@ def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None:
133134
assert 0 < rank
134135

135136
def __call__(self, rng: torch.Generator | None = None) -> Tensor:
136-
u = torch.abs(torch.randn([self.m], dtype=self.dtype, generator=rng))
137+
u = torch.abs(randn_([self.m], dtype=self.dtype, generator=rng))
137138
U1 = normalize(u, dim=0).unsqueeze(1)
138139
U2 = _sample_semi_orthonormal_complement(U1, rng=rng)
139140
U = torch.hstack([U1, U2])
140141
Vt = _sample_orthonormal_matrix(self.n, dtype=self.dtype, rng=rng)
141-
S = torch.diag(torch.abs(torch.randn([self.rank], dtype=self.dtype, generator=rng)))
142+
S = torch.diag(torch.abs(randn_([self.rank], dtype=self.dtype, generator=rng)))
142143
A = U[:, : self.rank] @ S @ Vt[: self.rank, :]
143144
return A
144145

@@ -148,7 +149,7 @@ def _sample_orthonormal_matrix(
148149
) -> Tensor:
149150
"""Uniformly samples a random orthonormal matrix of shape [dim, dim]."""
150151

151-
return _sample_semi_orthonormal_complement(torch.zeros([dim, 0], dtype=dtype), rng=rng)
152+
return _sample_semi_orthonormal_complement(zeros_([dim, 0], dtype=dtype), rng=rng)
152153

153154

154155
def _sample_semi_orthonormal_complement(Q: Tensor, rng: torch.Generator | None = None) -> Tensor:
@@ -161,7 +162,7 @@ def _sample_semi_orthonormal_complement(Q: Tensor, rng: torch.Generator | None =
161162

162163
dtype = Q.dtype
163164
m, k = Q.shape
164-
A = torch.randn([m, m - k], dtype=dtype, generator=rng)
165+
A = randn_([m, m - k], dtype=dtype, generator=rng)
165166

166167
# project A onto the orthogonal complement of Q
167168
A_proj = A - Q @ (Q.T @ A)

tests/unit/aggregation/_utils/test_dual_cone.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33
from pytest import mark, raises
44
from torch.testing import assert_close
5+
from unit._utils import rand_, randn_
56

67
from torchjd.aggregation._utils.dual_cone import _project_weight_vector, project_weights
78

@@ -29,9 +30,9 @@ def test_solution_weights(shape: tuple[int, int]):
2930
[1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_.
3031
"""
3132

32-
J = torch.randn(shape)
33+
J = randn_(shape)
3334
G = J @ J.T
34-
u = torch.rand(shape[0])
35+
u = rand_(shape[0])
3536

3637
w = project_weights(u, G, "quadprog")
3738
dual_gap = w - u
@@ -58,9 +59,9 @@ def test_scale_invariant(shape: tuple[int, int], scaling: float):
5859
Tests that `_project_weights` is invariant under scaling.
5960
"""
6061

61-
J = torch.randn(shape)
62+
J = randn_(shape)
6263
G = J @ J.T
63-
u = torch.rand(shape[0])
64+
u = rand_(shape[0])
6465

6566
w = project_weights(u, G, "quadprog")
6667
w_scaled = project_weights(u, scaling * G, "quadprog")
@@ -75,8 +76,8 @@ def test_tensorization_shape(shape: tuple[int, ...]):
7576
reshaped as matrix and to reshape the result back to the original tensor's shape.
7677
"""
7778

78-
matrix = torch.randn([shape[-1], shape[-1]])
79-
U_tensor = torch.randn(shape)
79+
matrix = randn_([shape[-1], shape[-1]])
80+
U_tensor = randn_(shape)
8081
U_matrix = U_tensor.reshape([-1, shape[-1]])
8182

8283
G = matrix @ matrix.T

tests/unit/aggregation/_utils/test_pref_vector.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from contextlib import nullcontext as does_not_raise
22

3-
import torch
43
from pytest import mark, raises
54
from torch import Tensor
6-
from unit._utils import ExceptionContext
5+
from unit._utils import ExceptionContext, ones_
76

87
from torchjd.aggregation._mean import _MeanWeighting
98
from torchjd.aggregation._utils.pref_vector import pref_vector_to_weighting
@@ -13,12 +12,12 @@
1312
["pref_vector", "expectation"],
1413
[
1514
(None, does_not_raise()),
16-
(torch.ones([]), raises(ValueError)),
17-
(torch.ones([0]), does_not_raise()),
18-
(torch.ones([1]), does_not_raise()),
19-
(torch.ones([5]), does_not_raise()),
20-
(torch.ones([1, 1]), raises(ValueError)),
21-
(torch.ones([1, 1, 1]), raises(ValueError)),
15+
(ones_([]), raises(ValueError)),
16+
(ones_([0]), does_not_raise()),
17+
(ones_([1]), does_not_raise()),
18+
(ones_([5]), does_not_raise()),
19+
(ones_([1, 1]), raises(ValueError)),
20+
(ones_([1, 1, 1]), raises(ValueError)),
2221
],
2322
)
2423
def test_pref_vector_to_weighting_check(pref_vector: Tensor | None, expectation: ExceptionContext):

tests/unit/aggregation/test_aggregator_bases.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from collections.abc import Sequence
22
from contextlib import nullcontext as does_not_raise
33

4-
import torch
54
from pytest import mark, raises
6-
from unit._utils import ExceptionContext
5+
from unit._utils import ExceptionContext, randn_
76

87
from torchjd.aggregation import Aggregator
98

@@ -20,4 +19,4 @@
2019
)
2120
def test_check_is_matrix(shape: Sequence[int], expectation: ExceptionContext):
2221
with expectation:
23-
Aggregator._check_is_matrix(torch.randn(shape))
22+
Aggregator._check_is_matrix(randn_(shape))

tests/unit/aggregation/test_cagrad.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from contextlib import nullcontext as does_not_raise
22

3-
import torch
43
from pytest import mark, raises
54
from torch import Tensor
6-
from unit._utils import ExceptionContext
5+
from unit._utils import ExceptionContext, ones_
76

87
from torchjd.aggregation import CAGrad
98

@@ -12,7 +11,7 @@
1211

1312
scaled_pairs = [(CAGrad(c=0.5), matrix) for matrix in scaled_matrices]
1413
typical_pairs = [(CAGrad(c=0.5), matrix) for matrix in typical_matrices]
15-
requires_grad_pairs = [(CAGrad(c=0.5), torch.ones(3, 5, requires_grad=True))]
14+
requires_grad_pairs = [(CAGrad(c=0.5), ones_(3, 5, requires_grad=True))]
1615
non_conflicting_pairs_1 = [(CAGrad(c=1.0), matrix) for matrix in typical_matrices]
1716
non_conflicting_pairs_2 = [(CAGrad(c=2.0), matrix) for matrix in typical_matrices]
1817

tests/unit/aggregation/test_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from pytest import mark
33
from torch import Tensor
4+
from unit._utils import ones_
45

56
from torchjd.aggregation import ConFIG
67

@@ -15,7 +16,7 @@
1516
scaled_pairs = [(ConFIG(), matrix) for matrix in scaled_matrices]
1617
typical_pairs = [(ConFIG(), matrix) for matrix in typical_matrices]
1718
non_strong_pairs = [(ConFIG(), matrix) for matrix in non_strong_matrices]
18-
requires_grad_pairs = [(ConFIG(), torch.ones(3, 5, requires_grad=True))]
19+
requires_grad_pairs = [(ConFIG(), ones_(3, 5, requires_grad=True))]
1920

2021

2122
@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)

0 commit comments

Comments
 (0)