Skip to content

Commit 8ed91c7

Browse files
authored
test: Add dtype setting (#496)
* Rename device.py to settings.py
* Add DTYPE setting through environment variable PYTEST_TORCH_DTYPE
* Make all specifications of float dtype in tests use DTYPE
* Fix test tolerances
1 parent 936d1fa commit 8ed91c7

File tree

13 files changed

+45
-24
lines changed

13 files changed

+45
-24
lines changed

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from contextlib import nullcontext
33

44
import torch
5-
from device import DEVICE
65
from pytest import RaisesExc, fixture, mark
6+
from settings import DEVICE
77
from torch import Tensor
88
from utils.architectures import ModuleFactory
99

tests/device.py renamed to tests/settings.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,19 @@
1414
raise ValueError('Requested device "cuda:0" but cuda is not available.')
1515

1616
DEVICE = torch.device(_device_str)
17+
18+
19+
_POSSIBLE_TEST_DTYPES = {"float32", "float64"}
20+
21+
try:
22+
_dtype_str = os.environ["PYTEST_TORCH_DTYPE"]
23+
except KeyError:
24+
_dtype_str = "float32" # Default to float32 if environment variable not set
25+
26+
if _dtype_str not in _POSSIBLE_TEST_DTYPES:
27+
raise ValueError(
28+
f"Invalid value of environment variable PYTEST_TORCH_DTYPE: {_dtype_str}.\n"
29+
f"Possible values: {_POSSIBLE_TEST_DTYPES}."
30+
)
31+
32+
DTYPE = getattr(torch, _dtype_str) # "float32" => torch.float32

tests/speed/autogram/grad_vs_jac_vs_gram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import gc
22

33
import torch
4-
from device import DEVICE
4+
from settings import DEVICE
55
from utils.architectures import (
66
AlexNet,
77
Cifar10Model,

tests/unit/aggregation/_inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from device import DEVICE
2+
from settings import DEVICE
33
from utils.tensors import zeros_
44

55
from ._matrix_samplers import NonWeakSampler, NormalSampler, StrictlyWeakSampler, StrongSampler

tests/unit/aggregation/_matrix_samplers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from abc import ABC, abstractmethod
22

33
import torch
4+
from settings import DTYPE
45
from torch import Tensor
56
from torch.nn.functional import normalize
67
from utils.tensors import randint_, randn_, randperm_, zeros_
@@ -9,7 +10,7 @@
910
class MatrixSampler(ABC):
1011
"""Abstract base class for sampling matrices of a given shape, rank and dtype."""
1112

12-
def __init__(self, m: int, n: int, rank: int, dtype: torch.dtype = torch.float32):
13+
def __init__(self, m: int, n: int, rank: int, dtype: torch.dtype = DTYPE):
1314
self._check_params(m, n, rank, dtype)
1415
self.m = m
1516
self.n = n

tests/unit/aggregation/test_dualproj.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_expected_structure(aggregator: DualProj, matrix: Tensor):
2727

2828
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
2929
def test_non_conflicting(aggregator: DualProj, matrix: Tensor):
30-
assert_non_conflicting(aggregator, matrix, atol=5e-05, rtol=5e-05)
30+
assert_non_conflicting(aggregator, matrix, atol=1e-04, rtol=1e-04)
3131

3232

3333
@mark.parametrize(["aggregator", "matrix"], typical_pairs)

tests/unit/aggregation/test_upgrad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_expected_structure(aggregator: UPGrad, matrix: Tensor):
2828

2929
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
3030
def test_non_conflicting(aggregator: UPGrad, matrix: Tensor):
31-
assert_non_conflicting(aggregator, matrix, atol=3e-04, rtol=3e-04)
31+
assert_non_conflicting(aggregator, matrix, atol=4e-04, rtol=4e-04)
3232

3333

3434
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
@@ -38,7 +38,7 @@ def test_permutation_invariant(aggregator: UPGrad, matrix: Tensor):
3838

3939
@mark.parametrize(["aggregator", "matrix"], typical_pairs)
4040
def test_linear_under_scaling(aggregator: UPGrad, matrix: Tensor):
41-
assert_linear_under_scaling(aggregator, matrix, n_runs=5, atol=3e-02, rtol=3e-02)
41+
assert_linear_under_scaling(aggregator, matrix, n_runs=5, atol=4e-02, rtol=4e-02)
4242

4343

4444
@mark.parametrize(["aggregator", "matrix"], non_strong_pairs)

tests/unit/autojac/_transform/test_aggregate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import math
22

33
import torch
4-
from device import DEVICE
54
from pytest import mark, raises
5+
from settings import DEVICE
66
from utils.dict_assertions import assert_tensor_dicts_are_close
77
from utils.tensors import rand_, tensor_, zeros_
88

tests/unit/autojac/test_mtl_backward.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
from pytest import mark, raises
3+
from settings import DTYPE
34
from torch.autograd import grad
45
from torch.testing import assert_close
56
from utils.tensors import arange_, rand_, randn_, tensor_
@@ -345,7 +346,7 @@ def test_various_feature_lists(shapes: list[tuple[int]]):
345346
"""Tests that mtl_backward works correctly with various kinds of feature lists."""
346347

347348
p0 = tensor_([1.0, 2.0], requires_grad=True)
348-
p1 = arange_(len(shapes), dtype=torch.float32, requires_grad=True)
349+
p1 = arange_(len(shapes), dtype=DTYPE, requires_grad=True)
349350
p2 = tensor_(5.0, requires_grad=True)
350351

351352
features = [rand_(shape) @ p0 for shape in shapes]

tests/unit/autojac/test_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from device import DEVICE
21
from pytest import mark, raises
2+
from settings import DEVICE, DTYPE
33
from torch.nn import Linear, MSELoss, ReLU, Sequential
44
from utils.tensors import randn_, tensor_
55

@@ -85,7 +85,7 @@ def test_get_leaf_tensors_model():
8585
x = randn_(16, 10)
8686
y = randn_(16, 1)
8787

88-
model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1)).to(device=DEVICE)
88+
model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1)).to(device=DEVICE, dtype=DTYPE)
8989
loss_fn = MSELoss(reduction="none")
9090

9191
y_hat = model(x)
@@ -104,8 +104,8 @@ def test_get_leaf_tensors_model_excluded_2():
104104
x = randn_(16, 10)
105105
z = randn_(16, 1)
106106

107-
model1 = Sequential(Linear(10, 5), ReLU()).to(device=DEVICE)
108-
model2 = Linear(5, 1).to(device=DEVICE)
107+
model1 = Sequential(Linear(10, 5), ReLU()).to(device=DEVICE, dtype=DTYPE)
108+
model2 = Linear(5, 1).to(device=DEVICE, dtype=DTYPE)
109109
loss_fn = MSELoss(reduction="none")
110110

111111
y = model1(x)

0 commit comments

Comments (0)