Skip to content

Commit a7e9fed

Browse files
wanghan-iapcmHan Wang
andauthored
fix(tests): move shared test mixins to tests.common to fix DeviceContext leak (#5344)
## Summary - Move `TestCaseSingleFrameWithNlist` and `get_tols` from `tests.pt.model` to `tests.common.test_mixins` - Update all pt_expt test imports to use `tests.common.test_mixins` directly - Simplify `tests/pt_expt/conftest.py` and remove manual `_pop_device_contexts()` workarounds ## Root cause `tests/pt/__init__.py` calls `torch.set_default_device("cuda:9999999")` to enforce explicit device usage in pt tests. This pushes a `DeviceContext` onto the torch mode stack. pt_expt descriptor/fitting tests imported `TestCaseSingleFrameWithNlist` from `tests.pt.model.test_env_mat`, which triggered `tests/pt/__init__.py` — leaking the `DeviceContext` into pt_expt tests. The leaked `DeviceContext` caused `torch.zeros()` calls (without explicit `device=`) inside AOTInductor's lowering pass and PyTorch's Adam optimizer to target `cuda:9999999`, crashing on CPU-only CI machines with `AssertionError: Torch not compiled with CUDA enabled`. ## Fix The shared mixins (`TestCaseSingleFrameWithNlist`, `get_tols`) are pure numpy with no torch dependency. Moving them to `tests/common/test_mixins.py` lets pt_expt tests import them without touching the `tests.pt` package. The pt tests re-export from the common location for backward compatibility. ## Test plan - [x] `pytest source/tests/pt_expt/descriptor/test_se_e2_a.py` — passes, no DeviceContext leak - [x] `pytest source/tests/pt/model/test_env_mat.py` — passes (backward compat via re-export) - [x] `pytest source/tests/pt/model/test_mlp.py` — passes (backward compat via re-export) - [x] Broader pt_expt tests (descriptor, fitting, loss, utils) all pass - [x] Verified: `import source.tests.pt_expt.descriptor.test_se_e2_a` no longer creates DeviceContext 🤖 Generated with [Claude Code](https://claude.com/claude-code) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Tests** * Consolidated shared testing utilities into a centralized module referenced by many test suites for consistent setup and tolerance handling * Simplified and reduced device-context cleanup behavior in test fixtures * Updated numerous test imports to use the new shared testing utilities location <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 8fea5ab commit a7e9fed

22 files changed

+111
-136
lines changed

source/tests/common/test_mixins.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""Shared test mixins and utilities used by both pt and pt_expt tests.
3+
4+
These are kept in ``tests.common`` so that importing them does NOT trigger
5+
``tests.pt.__init__`` (which sets ``torch.set_default_device("cuda:9999999")``
6+
and pushes a ``DeviceContext`` onto the mode stack, breaking pt_expt tests
7+
on CPU-only machines).
8+
"""
9+
10+
import numpy as np
11+
12+
13+
class TestCaseSingleFrameWithNlist:
14+
"""Mixin providing a small 2-frame, 2-type test system with neighbor list."""
15+
16+
def setUp(self) -> None:
17+
# nloc == 3, nall == 4
18+
self.nloc = 3
19+
self.nall = 4
20+
self.nf, self.nt = 2, 2
21+
self.coord_ext = np.array(
22+
[
23+
[0, 0, 0],
24+
[0, 1, 0],
25+
[0, 0, 1],
26+
[0, -2, 0],
27+
],
28+
dtype=np.float64,
29+
).reshape([1, self.nall, 3])
30+
self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall])
31+
self.mapping = np.array([0, 1, 2, 0], dtype=int).reshape([1, self.nall])
32+
# sel = [5, 2]
33+
self.sel = [5, 2]
34+
self.sel_mix = [7]
35+
self.natoms = [3, 3, 2, 1]
36+
self.nlist = np.array(
37+
[
38+
[1, 3, -1, -1, -1, 2, -1],
39+
[0, -1, -1, -1, -1, 2, -1],
40+
[0, 1, -1, -1, -1, -1, -1],
41+
],
42+
dtype=int,
43+
).reshape([1, self.nloc, sum(self.sel)])
44+
self.rcut = 2.2
45+
self.rcut_smth = 0.4
46+
# permutations
47+
self.perm = np.array([2, 0, 1, 3], dtype=np.int32)
48+
inv_perm = np.array([1, 2, 0, 3], dtype=np.int32)
49+
# permute the coord and atype
50+
self.coord_ext = np.concatenate(
51+
[self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0
52+
).reshape(self.nf, self.nall * 3)
53+
self.atype_ext = np.concatenate(
54+
[self.atype_ext, self.atype_ext[:, self.perm]], axis=0
55+
)
56+
self.mapping = np.concatenate(
57+
[self.mapping, self.mapping[:, self.perm]], axis=0
58+
)
59+
60+
# permute the nlist
61+
nlist1 = self.nlist[:, self.perm[: self.nloc], :]
62+
mask = nlist1 == -1
63+
nlist1 = inv_perm[nlist1]
64+
nlist1 = np.where(mask, -1, nlist1)
65+
self.nlist = np.concatenate([self.nlist, nlist1], axis=0)
66+
self.atol = 1e-12
67+
68+
69+
def get_tols(prec):
70+
"""Return (rtol, atol) for a given precision string."""
71+
if prec in ["single", "float32"]:
72+
rtol, atol = 0.0, 1e-4
73+
elif prec in ["double", "float64"]:
74+
rtol, atol = 0.0, 1e-12
75+
else:
76+
raise ValueError(f"unknown prec {prec}")
77+
return rtol, atol

source/tests/pt/model/test_env_mat.py

Lines changed: 3 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -21,58 +21,9 @@
2121
dtype = env.GLOBAL_PT_FLOAT_PRECISION
2222

2323

24-
class TestCaseSingleFrameWithNlist:
25-
def setUp(self) -> None:
26-
# nloc == 3, nall == 4
27-
self.nloc = 3
28-
self.nall = 4
29-
self.nf, self.nt = 2, 2
30-
self.coord_ext = np.array(
31-
[
32-
[0, 0, 0],
33-
[0, 1, 0],
34-
[0, 0, 1],
35-
[0, -2, 0],
36-
],
37-
dtype=np.float64,
38-
).reshape([1, self.nall, 3])
39-
self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall])
40-
self.mapping = np.array([0, 1, 2, 0], dtype=int).reshape([1, self.nall])
41-
# sel = [5, 2]
42-
self.sel = [5, 2]
43-
self.sel_mix = [7]
44-
self.natoms = [3, 3, 2, 1]
45-
self.nlist = np.array(
46-
[
47-
[1, 3, -1, -1, -1, 2, -1],
48-
[0, -1, -1, -1, -1, 2, -1],
49-
[0, 1, -1, -1, -1, -1, -1],
50-
],
51-
dtype=int,
52-
).reshape([1, self.nloc, sum(self.sel)])
53-
self.rcut = 2.2
54-
self.rcut_smth = 0.4
55-
# permutations
56-
self.perm = np.array([2, 0, 1, 3], dtype=np.int32)
57-
inv_perm = np.array([1, 2, 0, 3], dtype=np.int32)
58-
# permute the coord and atype
59-
self.coord_ext = np.concatenate(
60-
[self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0
61-
).reshape(self.nf, self.nall * 3)
62-
self.atype_ext = np.concatenate(
63-
[self.atype_ext, self.atype_ext[:, self.perm]], axis=0
64-
)
65-
self.mapping = np.concatenate(
66-
[self.mapping, self.mapping[:, self.perm]], axis=0
67-
)
68-
69-
# permute the nlist
70-
nlist1 = self.nlist[:, self.perm[: self.nloc], :]
71-
mask = nlist1 == -1
72-
nlist1 = inv_perm[nlist1]
73-
nlist1 = np.where(mask, -1, nlist1)
74-
self.nlist = np.concatenate([self.nlist, nlist1], axis=0)
75-
self.atol = 1e-12
24+
from ...common.test_mixins import (
25+
TestCaseSingleFrameWithNlist,
26+
)
7627

7728

7829
class TestCaseSingleFrameWithNlistWithVirtual:

source/tests/pt/model/test_mlp.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,9 @@
2424
PRECISION_DICT,
2525
)
2626

27-
28-
def get_tols(prec):
29-
if prec in ["single", "float32"]:
30-
rtol, atol = 0.0, 1e-4
31-
elif prec in ["double", "float64"]:
32-
rtol, atol = 0.0, 1e-12
33-
# elif prec in ["half", "float16"]:
34-
# rtol, atol=1e-2, 0
35-
else:
36-
raise ValueError(f"unknown prec {prec}")
37-
return rtol, atol
27+
from ...common.test_mixins import (
28+
get_tols,
29+
)
3830

3931

4032
class TestMLPLayer(unittest.TestCase):

source/tests/pt_expt/conftest.py

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,14 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
"""Conftest for pt_expt tests.
33
4-
Clears any leaked ``torch.utils._device.DeviceContext`` modes that may
5-
have been left on the torch function mode stack by ``make_fx`` or other
6-
tracing utilities during test collection. A stale ``DeviceContext``
7-
silently reroutes ``torch.tensor(...)`` calls (without an explicit
8-
``device=``) to a fake CUDA device, causing spurious "no NVIDIA driver"
9-
errors on CPU-only machines.
10-
11-
The leak is triggered when pytest collects descriptor test modules that
12-
import ``make_fx``. A ``DeviceContext(cuda:127)`` ends up on the
13-
``torch.overrides`` function mode stack and is never popped.
14-
15-
Our own code (``display_if_exist`` in ``deepmd/dpmodel/loss/loss.py``)
16-
is already fixed to pass ``device=`` explicitly. However, PyTorch's
17-
``Adam._init_group`` (``torch/optim/adam.py``) contains::
18-
19-
torch.tensor(0.0, dtype=_get_scalar_dtype()) # no device=
20-
21-
on the ``capturable=False, fused=False`` path (the default). This is
22-
a PyTorch bug — the ``capturable=True`` branch correctly uses
23-
``device=p.device`` but the default branch omits it. We cannot fix
24-
PyTorch internals, so this fixture works around the issue by popping
25-
leaked ``DeviceContext`` modes before each test.
4+
Safety net: pops any leaked ``torch.utils._device.DeviceContext`` modes
5+
from the torch function mode stack before each test.
6+
7+
The primary leak source was ``source/tests/pt/__init__.py`` which calls
8+
``torch.set_default_device("cuda:9999999")``; pt_expt tests previously
9+
imported shared mixins from ``tests.pt.model``, triggering that init.
10+
This was fixed by moving the shared mixins to ``tests.common.test_mixins``
11+
so pt_expt tests no longer import from the ``tests.pt`` package.
2612
"""
2713

2814
import pytest
@@ -50,20 +36,13 @@ def _pop_device_contexts() -> list:
5036

5137
@pytest.fixture(autouse=True, scope="session")
5238
def _clear_leaked_device_context_session():
53-
"""Pop any stale DeviceContext once at session start.
54-
55-
This runs before any setUpClass, preventing CUDA init errors
56-
in tests that call trainer.run() during class setup.
57-
"""
39+
"""Pop any stale DeviceContext once at session start."""
5840
_pop_device_contexts()
5941
yield
6042

6143

6244
@pytest.fixture(autouse=True)
6345
def _clear_leaked_device_context():
64-
"""Pop any stale ``DeviceContext`` before each test, restore after."""
65-
popped = _pop_device_contexts()
46+
"""Pop any stale ``DeviceContext`` before each test (safety net)."""
47+
_pop_device_contexts()
6648
yield
67-
# Restore in reverse order so the stack is back to its original state.
68-
for ctx in reversed(popped):
69-
ctx.__enter__()

source/tests/pt_expt/descriptor/test_dpa1.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,8 @@
1818
PRECISION_DICT,
1919
)
2020

21-
from ...pt.model.test_env_mat import (
21+
from ...common.test_mixins import (
2222
TestCaseSingleFrameWithNlist,
23-
)
24-
from ...pt.model.test_mlp import (
2523
get_tols,
2624
)
2725
from ...seed import (

source/tests/pt_expt/descriptor/test_dpa2.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@
2222
PRECISION_DICT,
2323
)
2424

25-
from ...pt.model.test_env_mat import (
25+
from ...common.test_mixins import (
2626
TestCaseSingleFrameWithNlist,
27-
)
28-
from ...pt.model.test_mlp import (
2927
get_tols,
3028
)
3129
from ...seed import (

source/tests/pt_expt/descriptor/test_dpa3.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,8 @@
2121
PRECISION_DICT,
2222
)
2323

24-
from ...pt.model.test_env_mat import (
24+
from ...common.test_mixins import (
2525
TestCaseSingleFrameWithNlist,
26-
)
27-
from ...pt.model.test_mlp import (
2826
get_tols,
2927
)
3028
from ...seed import (

source/tests/pt_expt/descriptor/test_hybrid.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,8 @@
2424
PRECISION_DICT,
2525
)
2626

27-
from ...pt.model.test_env_mat import (
27+
from ...common.test_mixins import (
2828
TestCaseSingleFrameWithNlist,
29-
)
30-
from ...pt.model.test_mlp import (
3129
get_tols,
3230
)
3331
from ...seed import (

source/tests/pt_expt/descriptor/test_se_atten_v2.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,8 @@
1818
PRECISION_DICT,
1919
)
2020

21-
from ...pt.model.test_env_mat import (
21+
from ...common.test_mixins import (
2222
TestCaseSingleFrameWithNlist,
23-
)
24-
from ...pt.model.test_mlp import (
2523
get_tols,
2624
)
2725
from ...seed import (

source/tests/pt_expt/descriptor/test_se_e2_a.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,8 @@
2121
PairExcludeMask,
2222
)
2323

24-
from ...pt.model.test_env_mat import (
24+
from ...common.test_mixins import (
2525
TestCaseSingleFrameWithNlist,
26-
)
27-
from ...pt.model.test_mlp import (
2826
get_tols,
2927
)
3028
from ...seed import (

0 commit comments

Comments
 (0)