diff --git a/source/tests/common/test_mixins.py b/source/tests/common/test_mixins.py new file mode 100644 index 0000000000..e311baf5cf --- /dev/null +++ b/source/tests/common/test_mixins.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Shared test mixins and utilities used by both pt and pt_expt tests. + +These are kept in ``tests.common`` so that importing them does NOT trigger +``tests.pt.__init__`` (which sets ``torch.set_default_device("cuda:9999999")`` +and pushes a ``DeviceContext`` onto the mode stack, breaking pt_expt tests +on CPU-only machines). +""" + +import numpy as np + + +class TestCaseSingleFrameWithNlist: + """Mixin providing a small 2-frame, 2-type test system with neighbor list.""" + + def setUp(self) -> None: + # nloc == 3, nall == 4 + self.nloc = 3 + self.nall = 4 + self.nf, self.nt = 2, 2 + self.coord_ext = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, -2, 0], + ], + dtype=np.float64, + ).reshape([1, self.nall, 3]) + self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall]) + self.mapping = np.array([0, 1, 2, 0], dtype=int).reshape([1, self.nall]) + # sel = [5, 2] + self.sel = [5, 2] + self.sel_mix = [7] + self.natoms = [3, 3, 2, 1] + self.nlist = np.array( + [ + [1, 3, -1, -1, -1, 2, -1], + [0, -1, -1, -1, -1, 2, -1], + [0, 1, -1, -1, -1, -1, -1], + ], + dtype=int, + ).reshape([1, self.nloc, sum(self.sel)]) + self.rcut = 2.2 + self.rcut_smth = 0.4 + # permutations + self.perm = np.array([2, 0, 1, 3], dtype=np.int32) + inv_perm = np.array([1, 2, 0, 3], dtype=np.int32) + # permute the coord and atype + self.coord_ext = np.concatenate( + [self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0 + ).reshape(self.nf, self.nall * 3) + self.atype_ext = np.concatenate( + [self.atype_ext, self.atype_ext[:, self.perm]], axis=0 + ) + self.mapping = np.concatenate( + [self.mapping, self.mapping[:, self.perm]], axis=0 + ) + + # permute the nlist + nlist1 = self.nlist[:, self.perm[: self.nloc], :] + mask = nlist1 == -1 + nlist1 = inv_perm[nlist1] + nlist1 = np.where(mask, -1, nlist1) + self.nlist = np.concatenate([self.nlist, nlist1], axis=0) + self.atol = 1e-12 + + +def get_tols(prec): + """Return (rtol, atol) for a given precision string.""" + if prec in ["single", "float32"]: + rtol, atol = 0.0, 1e-4 + elif prec in ["double", "float64"]: + rtol, atol = 0.0, 1e-12 + else: + raise ValueError(f"unknown prec {prec}") + return rtol, atol diff --git a/source/tests/pt/model/test_env_mat.py b/source/tests/pt/model/test_env_mat.py index 391f50da03..0c99311ca3 100644 --- a/source/tests/pt/model/test_env_mat.py +++ b/source/tests/pt/model/test_env_mat.py @@ -21,58 +21,9 @@ dtype = env.GLOBAL_PT_FLOAT_PRECISION -class TestCaseSingleFrameWithNlist: - def setUp(self) -> None: - # nloc == 3, nall == 4 - self.nloc = 3 - self.nall = 4 - self.nf, self.nt = 2, 2 - self.coord_ext = np.array( - [ - [0, 0, 0], - [0, 1, 0], - [0, 0, 1], - [0, -2, 0], - ], - dtype=np.float64, - ).reshape([1, self.nall, 3]) - self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall]) - self.mapping = np.array([0, 1, 2, 0], dtype=int).reshape([1, self.nall]) - # sel = [5, 2] - self.sel = [5, 2] - self.sel_mix = [7] - self.natoms = [3, 3, 2, 1] - self.nlist = np.array( - [ - [1, 3, -1, -1, -1, 2, -1], - [0, -1, -1, -1, -1, 2, -1], - [0, 1, -1, -1, -1, -1, -1], - ], - dtype=int, - ).reshape([1, self.nloc, sum(self.sel)]) - self.rcut = 2.2 - self.rcut_smth = 0.4 - # permutations - self.perm = np.array([2, 0, 1, 3], dtype=np.int32) - inv_perm = np.array([1, 2, 0, 3], dtype=np.int32) - # permute the coord and atype - self.coord_ext = np.concatenate( - [self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0 - ).reshape(self.nf, self.nall * 3) - self.atype_ext = np.concatenate( - [self.atype_ext, self.atype_ext[:, self.perm]], axis=0 - ) - self.mapping = np.concatenate( - [self.mapping, self.mapping[:, self.perm]], axis=0 - ) - - # permute the nlist - nlist1 = self.nlist[:, self.perm[: self.nloc], :] - mask = nlist1 == -1 - nlist1 = inv_perm[nlist1] - nlist1 = np.where(mask, -1, nlist1) - self.nlist = np.concatenate([self.nlist, nlist1], axis=0) - self.atol = 1e-12 +from ...common.test_mixins import ( + TestCaseSingleFrameWithNlist, +) class TestCaseSingleFrameWithNlistWithVirtual: diff --git a/source/tests/pt/model/test_mlp.py b/source/tests/pt/model/test_mlp.py index e6dde660cb..5f067cb0e6 100644 --- a/source/tests/pt/model/test_mlp.py +++ b/source/tests/pt/model/test_mlp.py @@ -24,17 +24,9 @@ PRECISION_DICT, ) - -def get_tols(prec): - if prec in ["single", "float32"]: - rtol, atol = 0.0, 1e-4 - elif prec in ["double", "float64"]: - rtol, atol = 0.0, 1e-12 - # elif prec in ["half", "float16"]: - # rtol, atol=1e-2, 0 - else: - raise ValueError(f"unknown prec {prec}") - return rtol, atol +from ...common.test_mixins import ( + get_tols, +) class TestMLPLayer(unittest.TestCase): diff --git a/source/tests/pt_expt/conftest.py b/source/tests/pt_expt/conftest.py index afda179a82..f2a9b07a6a 100644 --- a/source/tests/pt_expt/conftest.py +++ b/source/tests/pt_expt/conftest.py @@ -1,28 +1,14 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Conftest for pt_expt tests. -Clears any leaked ``torch.utils._device.DeviceContext`` modes that may -have been left on the torch function mode stack by ``make_fx`` or other -tracing utilities during test collection. A stale ``DeviceContext`` -silently reroutes ``torch.tensor(...)`` calls (without an explicit -``device=``) to a fake CUDA device, causing spurious "no NVIDIA driver" -errors on CPU-only machines. - -The leak is triggered when pytest collects descriptor test modules that -import ``make_fx``. A ``DeviceContext(cuda:127)`` ends up on the -``torch.overrides`` function mode stack and is never popped. - -Our own code (``display_if_exist`` in ``deepmd/dpmodel/loss/loss.py``) -is already fixed to pass ``device=`` explicitly. However, PyTorch's -``Adam._init_group`` (``torch/optim/adam.py``) contains:: - - torch.tensor(0.0, dtype=_get_scalar_dtype()) # no device= - -on the ``capturable=False, fused=False`` path (the default). This is -a PyTorch bug — the ``capturable=True`` branch correctly uses -``device=p.device`` but the default branch omits it. We cannot fix -PyTorch internals, so this fixture works around the issue by popping -leaked ``DeviceContext`` modes before each test. +Safety net: pops any leaked ``torch.utils._device.DeviceContext`` modes +from the torch function mode stack before each test. + +The primary leak source was ``source/tests/pt/__init__.py`` which calls +``torch.set_default_device("cuda:9999999")``; pt_expt tests previously +imported shared mixins from ``tests.pt.model``, triggering that init. +This was fixed by moving the shared mixins to ``tests.common.test_mixins`` +so pt_expt tests no longer import from the ``tests.pt`` package. """ import pytest @@ -50,20 +36,13 @@ def _pop_device_contexts() -> list: @pytest.fixture(autouse=True, scope="session") def _clear_leaked_device_context_session(): - """Pop any stale DeviceContext once at session start. - - This runs before any setUpClass, preventing CUDA init errors - in tests that call trainer.run() during class setup. - """ + """Pop any stale DeviceContext once at session start.""" _pop_device_contexts() yield @pytest.fixture(autouse=True) def _clear_leaked_device_context(): - """Pop any stale ``DeviceContext`` before each test, restore after.""" - popped = _pop_device_contexts() + """Pop any stale ``DeviceContext`` before each test (safety net).""" + _pop_device_contexts() yield - # Restore in reverse order so the stack is back to its original state. - for ctx in reversed(popped): - ctx.__enter__() diff --git a/source/tests/pt_expt/descriptor/test_dpa1.py b/source/tests/pt_expt/descriptor/test_dpa1.py index 9827ca6679..0524c6c98a 100644 --- a/source/tests/pt_expt/descriptor/test_dpa1.py +++ b/source/tests/pt_expt/descriptor/test_dpa1.py @@ -18,10 +18,8 @@ PRECISION_DICT, ) -from ...pt.model.test_env_mat import ( +from ...common.test_mixins import ( TestCaseSingleFrameWithNlist, -) -from ...pt.model.test_mlp import ( get_tols, ) from ...seed import ( diff --git a/source/tests/pt_expt/descriptor/test_dpa2.py b/source/tests/pt_expt/descriptor/test_dpa2.py index 47ea4c0811..a3794052f4 100644 --- a/source/tests/pt_expt/descriptor/test_dpa2.py +++ b/source/tests/pt_expt/descriptor/test_dpa2.py @@ -22,10 +22,8 @@ PRECISION_DICT, ) -from ...pt.model.test_env_mat import ( +from ...common.test_mixins import ( TestCaseSingleFrameWithNlist, -) -from ...pt.model.test_mlp import ( get_tols, ) from ...seed import ( diff --git a/source/tests/pt_expt/descriptor/test_dpa3.py b/source/tests/pt_expt/descriptor/test_dpa3.py index 4aeec0dbad..7cdcd6ced7 100644 --- a/source/tests/pt_expt/descriptor/test_dpa3.py +++ b/source/tests/pt_expt/descriptor/test_dpa3.py @@ -21,10 +21,8 @@ PRECISION_DICT, ) -from ...pt.model.test_env_mat import ( +from ...common.test_mixins import ( TestCaseSingleFrameWithNlist, -) -from ...pt.model.test_mlp import ( get_tols, ) from ...seed import ( diff --git a/source/tests/pt_expt/descriptor/test_hybrid.py b/source/tests/pt_expt/descriptor/test_hybrid.py index 41a87273d5..3185733ddf 100644 --- a/source/tests/pt_expt/descriptor/test_hybrid.py +++ b/source/tests/pt_expt/descriptor/test_hybrid.py @@ -24,10 +24,8 @@ PRECISION_DICT, ) -from ...pt.model.test_env_mat import ( +from ...common.test_mixins import ( TestCaseSingleFrameWithNlist, -) -from ...pt.model.test_mlp import ( get_tols, ) from ...seed import ( diff --git a/source/tests/pt_expt/descriptor/test_se_atten_v2.py b/source/tests/pt_expt/descriptor/test_se_atten_v2.py index 5d71095af4..01ead9e179 100644 --- a/source/tests/pt_expt/descriptor/test_se_atten_v2.py +++ b/source/tests/pt_expt/descriptor/test_se_atten_v2.py @@ -18,10 +18,8 @@ PRECISION_DICT, ) -from ...pt.model.test_env_mat import ( +from ...common.test_mixins import ( TestCaseSingleFrameWithNlist, -) -from ...pt.model.test_mlp import ( get_tols, ) from ...seed import ( diff --git a/source/tests/pt_expt/descriptor/test_se_e2_a.py b/source/tests/pt_expt/descriptor/test_se_e2_a.py index 841ef776ff..0efb90b1bb 100644 --- a/source/tests/pt_expt/descriptor/test_se_e2_a.py +++ b/source/tests/pt_expt/descriptor/test_se_e2_a.py @@ -21,10 +21,8 @@ PairExcludeMask, ) -from ...pt.model.test_env_mat import ( +from ...common.test_mixins import ( TestCaseSingleFrameWithNlist, -) -from ...pt.model.test_mlp import ( get_tols, ) from ...seed import ( diff --git a/source/tests/pt_expt/descriptor/test_se_r.py b/source/tests/pt_expt/descriptor/test_se_r.py index c9aeaffac4..0bac2251b4 100644 --- a/source/tests/pt_expt/descriptor/test_se_r.py +++ b/source/tests/pt_expt/descriptor/test_se_r.py @@ -18,10 +18,8 @@ PRECISION_DICT, ) -from ...pt.model.test_env_mat import ( +from ...common.test_mixins import ( TestCaseSingleFrameWithNlist, -) -from ...pt.model.test_mlp import ( get_tols, ) from ...seed import ( diff --git a/source/tests/pt_expt/descriptor/test_se_t.py b/source/tests/pt_expt/descriptor/test_se_t.py index 80a32d03a0..3da6ef02f3 100644 --- a/source/tests/pt_expt/descriptor/test_se_t.py +++ b/source/tests/pt_expt/descriptor/test_se_t.py @@ -18,10 +18,8 @@ PRECISION_DICT, ) -from ...pt.model.test_env_mat import ( +from ...common.test_mixins import ( TestCaseSingleFrameWithNlist, -) -from ...pt.model.test_mlp import ( get_tols, ) from ...seed import ( diff --git a/source/tests/pt_expt/descriptor/test_se_t_tebd.py b/source/tests/pt_expt/descriptor/test_se_t_tebd.py index 0f50e6000e..bb4b1dc80d 100644 --- a/source/tests/pt_expt/descriptor/test_se_t_tebd.py +++ b/source/tests/pt_expt/descriptor/test_se_t_tebd.py @@ -18,10 +18,8 @@ PRECISION_DICT, ) -from ...pt.model.test_env_mat import ( +from ...common.test_mixins import ( TestCaseSingleFrameWithNlist, -) -from ...pt.model.test_mlp import ( get_tols, ) from ...seed import ( diff --git a/source/tests/pt_expt/fitting/test_dipole_fitting.py b/source/tests/pt_expt/fitting/test_dipole_fitting.py index f5ac7ba177..959b23d6ea 100644 --- a/source/tests/pt_expt/fitting/test_dipole_fitting.py +++ b/source/tests/pt_expt/fitting/test_dipole_fitting.py @@ -17,7 +17,7 @@ env, ) -from ...pt.model.test_env_mat import ( +from ...common.test_mixins import ( TestCaseSingleFrameWithNlist, ) from ...seed import ( diff --git a/source/tests/pt_expt/fitting/test_dos_fitting.py b/source/tests/pt_expt/fitting/test_dos_fitting.py index 3fe06a8618..340088e672 100644 --- a/source/tests/pt_expt/fitting/test_dos_fitting.py +++ b/source/tests/pt_expt/fitting/test_dos_fitting.py @@ -17,7 +17,7 @@ env, ) -from ...pt.model.test_env_mat import ( +from ...common.test_mixins import ( TestCaseSingleFrameWithNlist, ) from ...seed import ( diff --git a/source/tests/pt_expt/fitting/test_ener_fitting.py b/source/tests/pt_expt/fitting/test_ener_fitting.py index 63ae82ab9a..fe55bd628a 100644 --- a/source/tests/pt_expt/fitting/test_ener_fitting.py +++ b/source/tests/pt_expt/fitting/test_ener_fitting.py @@ -14,7 +14,7 @@ env, ) -from ...pt.model.test_env_mat import ( +from ...common.test_mixins import ( TestCaseSingleFrameWithNlist, ) from ...seed import ( diff --git a/source/tests/pt_expt/fitting/test_invar_fitting.py b/source/tests/pt_expt/fitting/test_invar_fitting.py index 30cbe84401..cf54ea500b 100644 --- a/source/tests/pt_expt/fitting/test_invar_fitting.py +++ b/source/tests/pt_expt/fitting/test_invar_fitting.py @@ -15,7 +15,7 @@ env, ) -from ...pt.model.test_env_mat import ( +from ...common.test_mixins import ( TestCaseSingleFrameWithNlist, ) from ...seed import ( diff --git a/source/tests/pt_expt/fitting/test_polar_fitting.py b/source/tests/pt_expt/fitting/test_polar_fitting.py index 1c150f7154..e11b4455e5 100644 --- a/source/tests/pt_expt/fitting/test_polar_fitting.py +++ b/source/tests/pt_expt/fitting/test_polar_fitting.py @@ -17,7 +17,7 @@ env, ) -from ...pt.model.test_env_mat import ( +from ...common.test_mixins import ( TestCaseSingleFrameWithNlist, ) from ...seed import ( diff --git a/source/tests/pt_expt/fitting/test_property_fitting.py b/source/tests/pt_expt/fitting/test_property_fitting.py index ca3dbc11af..19177be849 100644 --- a/source/tests/pt_expt/fitting/test_property_fitting.py +++ b/source/tests/pt_expt/fitting/test_property_fitting.py @@ -17,7 +17,7 @@ env, ) -from ...pt.model.test_env_mat import ( +from ...common.test_mixins import ( TestCaseSingleFrameWithNlist, ) from ...seed import ( diff --git a/source/tests/pt_expt/loss/test_ener.py b/source/tests/pt_expt/loss/test_ener.py index 37d7d4c703..9bc81f119f 100644 --- a/source/tests/pt_expt/loss/test_ener.py +++ b/source/tests/pt_expt/loss/test_ener.py @@ -23,7 +23,7 @@ PRECISION_DICT, ) -from ...pt.model.test_mlp import ( +from ...common.test_mixins import ( get_tols, ) from ...seed import ( diff --git a/source/tests/pt_expt/test_change_bias.py b/source/tests/pt_expt/test_change_bias.py index 16974cc653..50d114af28 100644 --- a/source/tests/pt_expt/test_change_bias.py +++ b/source/tests/pt_expt/test_change_bias.py @@ -123,12 +123,6 @@ class TestChangeBias(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - from .conftest import ( - _pop_device_contexts, - ) - - _pop_device_contexts() - data_dir = os.path.join(EXAMPLE_DIR, "data") if not os.path.isdir(data_dir): raise unittest.SkipTest(f"Example data not found: {data_dir}") diff --git a/source/tests/pt_expt/utils/test_exclusion_mask.py b/source/tests/pt_expt/utils/test_exclusion_mask.py index 6f836913af..cc0671c117 100644 --- a/source/tests/pt_expt/utils/test_exclusion_mask.py +++ b/source/tests/pt_expt/utils/test_exclusion_mask.py @@ -12,7 +12,7 @@ PairExcludeMask, ) -from ...pt.model.test_env_mat import ( +from ...common.test_mixins import ( TestCaseSingleFrameWithNlist, )