Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions source/tests/common/test_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Shared test mixins and utilities used by both pt and pt_expt tests.

These are kept in ``tests.common`` so that importing them does NOT trigger
``tests.pt.__init__`` (which sets ``torch.set_default_device("cuda:9999999")``
and pushes a ``DeviceContext`` onto the mode stack, breaking pt_expt tests
on CPU-only machines).
"""

import numpy as np


class TestCaseSingleFrameWithNlist:
"""Mixin providing a small 2-frame, 2-type test system with neighbor list."""

def setUp(self) -> None:
# nloc == 3, nall == 4
self.nloc = 3
self.nall = 4
self.nf, self.nt = 2, 2
self.coord_ext = np.array(
[
[0, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, -2, 0],
],
dtype=np.float64,
).reshape([1, self.nall, 3])
self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall])
self.mapping = np.array([0, 1, 2, 0], dtype=int).reshape([1, self.nall])
# sel = [5, 2]
self.sel = [5, 2]
self.sel_mix = [7]
self.natoms = [3, 3, 2, 1]
self.nlist = np.array(
[
[1, 3, -1, -1, -1, 2, -1],
[0, -1, -1, -1, -1, 2, -1],
[0, 1, -1, -1, -1, -1, -1],
],
dtype=int,
).reshape([1, self.nloc, sum(self.sel)])
self.rcut = 2.2
self.rcut_smth = 0.4
# permutations
self.perm = np.array([2, 0, 1, 3], dtype=np.int32)
inv_perm = np.array([1, 2, 0, 3], dtype=np.int32)
# permute the coord and atype
self.coord_ext = np.concatenate(
[self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0
).reshape(self.nf, self.nall * 3)
self.atype_ext = np.concatenate(
[self.atype_ext, self.atype_ext[:, self.perm]], axis=0
)
self.mapping = np.concatenate(
[self.mapping, self.mapping[:, self.perm]], axis=0
)

# permute the nlist
nlist1 = self.nlist[:, self.perm[: self.nloc], :]
mask = nlist1 == -1
nlist1 = inv_perm[nlist1]
nlist1 = np.where(mask, -1, nlist1)
self.nlist = np.concatenate([self.nlist, nlist1], axis=0)
self.atol = 1e-12


def get_tols(prec):
"""Return (rtol, atol) for a given precision string."""
if prec in ["single", "float32"]:
rtol, atol = 0.0, 1e-4
elif prec in ["double", "float64"]:
rtol, atol = 0.0, 1e-12
else:
raise ValueError(f"unknown prec {prec}")
return rtol, atol
55 changes: 3 additions & 52 deletions source/tests/pt/model/test_env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,58 +21,9 @@
dtype = env.GLOBAL_PT_FLOAT_PRECISION


class TestCaseSingleFrameWithNlist:
def setUp(self) -> None:
# nloc == 3, nall == 4
self.nloc = 3
self.nall = 4
self.nf, self.nt = 2, 2
self.coord_ext = np.array(
[
[0, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, -2, 0],
],
dtype=np.float64,
).reshape([1, self.nall, 3])
self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall])
self.mapping = np.array([0, 1, 2, 0], dtype=int).reshape([1, self.nall])
# sel = [5, 2]
self.sel = [5, 2]
self.sel_mix = [7]
self.natoms = [3, 3, 2, 1]
self.nlist = np.array(
[
[1, 3, -1, -1, -1, 2, -1],
[0, -1, -1, -1, -1, 2, -1],
[0, 1, -1, -1, -1, -1, -1],
],
dtype=int,
).reshape([1, self.nloc, sum(self.sel)])
self.rcut = 2.2
self.rcut_smth = 0.4
# permutations
self.perm = np.array([2, 0, 1, 3], dtype=np.int32)
inv_perm = np.array([1, 2, 0, 3], dtype=np.int32)
# permute the coord and atype
self.coord_ext = np.concatenate(
[self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0
).reshape(self.nf, self.nall * 3)
self.atype_ext = np.concatenate(
[self.atype_ext, self.atype_ext[:, self.perm]], axis=0
)
self.mapping = np.concatenate(
[self.mapping, self.mapping[:, self.perm]], axis=0
)

# permute the nlist
nlist1 = self.nlist[:, self.perm[: self.nloc], :]
mask = nlist1 == -1
nlist1 = inv_perm[nlist1]
nlist1 = np.where(mask, -1, nlist1)
self.nlist = np.concatenate([self.nlist, nlist1], axis=0)
self.atol = 1e-12
from ...common.test_mixins import (
TestCaseSingleFrameWithNlist,
)


class TestCaseSingleFrameWithNlistWithVirtual:
Expand Down
14 changes: 3 additions & 11 deletions source/tests/pt/model/test_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,9 @@
PRECISION_DICT,
)


def get_tols(prec):
if prec in ["single", "float32"]:
rtol, atol = 0.0, 1e-4
elif prec in ["double", "float64"]:
rtol, atol = 0.0, 1e-12
# elif prec in ["half", "float16"]:
# rtol, atol=1e-2, 0
else:
raise ValueError(f"unknown prec {prec}")
return rtol, atol
from ...common.test_mixins import (
get_tols,
)


class TestMLPLayer(unittest.TestCase):
Expand Down
43 changes: 11 additions & 32 deletions source/tests/pt_expt/conftest.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,14 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Conftest for pt_expt tests.

Clears any leaked ``torch.utils._device.DeviceContext`` modes that may
have been left on the torch function mode stack by ``make_fx`` or other
tracing utilities during test collection. A stale ``DeviceContext``
silently reroutes ``torch.tensor(...)`` calls (without an explicit
``device=``) to a fake CUDA device, causing spurious "no NVIDIA driver"
errors on CPU-only machines.

The leak is triggered when pytest collects descriptor test modules that
import ``make_fx``. A ``DeviceContext(cuda:127)`` ends up on the
``torch.overrides`` function mode stack and is never popped.

Our own code (``display_if_exist`` in ``deepmd/dpmodel/loss/loss.py``)
is already fixed to pass ``device=`` explicitly. However, PyTorch's
``Adam._init_group`` (``torch/optim/adam.py``) contains::

torch.tensor(0.0, dtype=_get_scalar_dtype()) # no device=

on the ``capturable=False, fused=False`` path (the default). This is
a PyTorch bug — the ``capturable=True`` branch correctly uses
``device=p.device`` but the default branch omits it. We cannot fix
PyTorch internals, so this fixture works around the issue by popping
leaked ``DeviceContext`` modes before each test.
Safety net: pops any leaked ``torch.utils._device.DeviceContext`` modes
from the torch function mode stack before each test.

The primary leak source was ``source/tests/pt/__init__.py`` which calls
``torch.set_default_device("cuda:9999999")``; pt_expt tests previously
imported shared mixins from ``tests.pt.model``, triggering that init.
This was fixed by moving the shared mixins to ``tests.common.test_mixins``
so pt_expt tests no longer import from the ``tests.pt`` package.
"""

import pytest
Expand Down Expand Up @@ -50,20 +36,13 @@ def _pop_device_contexts() -> list:

@pytest.fixture(autouse=True, scope="session")
def _clear_leaked_device_context_session():
"""Pop any stale DeviceContext once at session start.

This runs before any setUpClass, preventing CUDA init errors
in tests that call trainer.run() during class setup.
"""
"""Pop any stale DeviceContext once at session start."""
_pop_device_contexts()
yield


@pytest.fixture(autouse=True)
def _clear_leaked_device_context():
"""Pop any stale ``DeviceContext`` before each test, restore after."""
popped = _pop_device_contexts()
"""Pop any stale ``DeviceContext`` before each test (safety net)."""
_pop_device_contexts()
Comment thread
wanghan-iapcm marked this conversation as resolved.
yield
# Restore in reverse order so the stack is back to its original state.
for ctx in reversed(popped):
ctx.__enter__()
4 changes: 1 addition & 3 deletions source/tests/pt_expt/descriptor/test_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@
PRECISION_DICT,
)

from ...pt.model.test_env_mat import (
from ...common.test_mixins import (
TestCaseSingleFrameWithNlist,
)
from ...pt.model.test_mlp import (
get_tols,
)
from ...seed import (
Expand Down
4 changes: 1 addition & 3 deletions source/tests/pt_expt/descriptor/test_dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@
PRECISION_DICT,
)

from ...pt.model.test_env_mat import (
from ...common.test_mixins import (
TestCaseSingleFrameWithNlist,
)
from ...pt.model.test_mlp import (
get_tols,
)
from ...seed import (
Expand Down
4 changes: 1 addition & 3 deletions source/tests/pt_expt/descriptor/test_dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@
PRECISION_DICT,
)

from ...pt.model.test_env_mat import (
from ...common.test_mixins import (
TestCaseSingleFrameWithNlist,
)
from ...pt.model.test_mlp import (
get_tols,
)
from ...seed import (
Expand Down
4 changes: 1 addition & 3 deletions source/tests/pt_expt/descriptor/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@
PRECISION_DICT,
)

from ...pt.model.test_env_mat import (
from ...common.test_mixins import (
TestCaseSingleFrameWithNlist,
)
from ...pt.model.test_mlp import (
get_tols,
)
from ...seed import (
Expand Down
4 changes: 1 addition & 3 deletions source/tests/pt_expt/descriptor/test_se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@
PRECISION_DICT,
)

from ...pt.model.test_env_mat import (
from ...common.test_mixins import (
TestCaseSingleFrameWithNlist,
)
from ...pt.model.test_mlp import (
get_tols,
)
from ...seed import (
Expand Down
4 changes: 1 addition & 3 deletions source/tests/pt_expt/descriptor/test_se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@
PairExcludeMask,
)

from ...pt.model.test_env_mat import (
from ...common.test_mixins import (
TestCaseSingleFrameWithNlist,
)
from ...pt.model.test_mlp import (
get_tols,
)
from ...seed import (
Expand Down
4 changes: 1 addition & 3 deletions source/tests/pt_expt/descriptor/test_se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@
PRECISION_DICT,
)

from ...pt.model.test_env_mat import (
from ...common.test_mixins import (
TestCaseSingleFrameWithNlist,
)
from ...pt.model.test_mlp import (
get_tols,
)
from ...seed import (
Expand Down
4 changes: 1 addition & 3 deletions source/tests/pt_expt/descriptor/test_se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@
PRECISION_DICT,
)

from ...pt.model.test_env_mat import (
from ...common.test_mixins import (
TestCaseSingleFrameWithNlist,
)
from ...pt.model.test_mlp import (
get_tols,
)
from ...seed import (
Expand Down
4 changes: 1 addition & 3 deletions source/tests/pt_expt/descriptor/test_se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@
PRECISION_DICT,
)

from ...pt.model.test_env_mat import (
from ...common.test_mixins import (
TestCaseSingleFrameWithNlist,
)
from ...pt.model.test_mlp import (
get_tols,
)
from ...seed import (
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt_expt/fitting/test_dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
env,
)

from ...pt.model.test_env_mat import (
from ...common.test_mixins import (
TestCaseSingleFrameWithNlist,
)
from ...seed import (
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt_expt/fitting/test_dos_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
env,
)

from ...pt.model.test_env_mat import (
from ...common.test_mixins import (
TestCaseSingleFrameWithNlist,
)
from ...seed import (
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt_expt/fitting/test_ener_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
env,
)

from ...pt.model.test_env_mat import (
from ...common.test_mixins import (
TestCaseSingleFrameWithNlist,
)
from ...seed import (
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt_expt/fitting/test_invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
env,
)

from ...pt.model.test_env_mat import (
from ...common.test_mixins import (
TestCaseSingleFrameWithNlist,
)
from ...seed import (
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt_expt/fitting/test_polar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
env,
)

from ...pt.model.test_env_mat import (
from ...common.test_mixins import (
TestCaseSingleFrameWithNlist,
)
from ...seed import (
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt_expt/fitting/test_property_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
env,
)

from ...pt.model.test_env_mat import (
from ...common.test_mixins import (
TestCaseSingleFrameWithNlist,
)
from ...seed import (
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt_expt/loss/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
PRECISION_DICT,
)

from ...pt.model.test_mlp import (
from ...common.test_mixins import (
get_tols,
)
from ...seed import (
Expand Down
Loading
Loading