Skip to content

Commit 2edfd4f

Browse files
committed
fix ut
1 parent 478cba6 commit 2edfd4f

1 file changed

Lines changed: 107 additions & 60 deletions

File tree

source/tests/pt/model/test_sezm_export.py

Lines changed: 107 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,26 @@ def _clear_default_device() -> Iterator[None]:
169169
torch.set_default_device(saved)
170170

171171

172+
class _ClearDefaultDeviceTestCase(unittest.TestCase):
173+
"""Run a test class while the pt default-device sentinel is disabled."""
174+
175+
@classmethod
176+
def setUpClass(cls) -> None:
177+
super().setUpClass()
178+
cls._default_device_ctx = _clear_default_device()
179+
cls._default_device_ctx.__enter__()
180+
181+
@classmethod
182+
def tearDownClass(cls) -> None:
183+
try:
184+
super().tearDownClass()
185+
finally:
186+
ctx = getattr(cls, "_default_device_ctx", None)
187+
if ctx is not None:
188+
ctx.__exit__(None, None, None)
189+
delattr(cls, "_default_device_ctx")
190+
191+
172192
def _eager_forward(
173193
model: torch.nn.Module,
174194
sample_inputs: tuple,
@@ -189,7 +209,7 @@ def _eager_forward(
189209
)
190210

191211

192-
class TestSeZMExportPipeline(unittest.TestCase):
212+
class TestSeZMExportPipeline(_ClearDefaultDeviceTestCase):
193213
"""Bitwise trace / export / ``.pte`` round-trip parity (``rtol=1e-10``).
194214
195215
The ExportedProgram is a pure FX graph (no Inductor codegen), so
@@ -201,23 +221,28 @@ class TestSeZMExportPipeline(unittest.TestCase):
201221

202222
@classmethod
203223
def setUpClass(cls) -> None:
204-
with _clear_default_device():
224+
super().setUpClass()
225+
try:
205226
cls.model = _build_tiny_sezm_model()
206227
cls.sample_inputs = _make_sample(cls.model, nloc=7, start=2)
207228
cls.traced, cls.loaded, cls._pte_tmp = cls._build_pipeline(
208229
cls.model, cls.sample_inputs
209230
)
231+
except Exception:
232+
super().tearDownClass()
233+
raise
210234

211235
@classmethod
212236
def tearDownClass(cls) -> None:
213-
cls._pte_tmp.close()
214-
215-
def setUp(self) -> None:
216-
self._device_ctx = _clear_default_device()
217-
self._device_ctx.__enter__()
218-
219-
def tearDown(self) -> None:
220-
self._device_ctx.__exit__(None, None, None)
237+
try:
238+
for attr in ("loaded", "traced", "model", "sample_inputs"):
239+
if hasattr(cls, attr):
240+
delattr(cls, attr)
241+
if hasattr(cls, "_pte_tmp"):
242+
cls._pte_tmp.close()
243+
delattr(cls, "_pte_tmp")
244+
finally:
245+
super().tearDownClass()
221246

222247
@staticmethod
223248
def _build_pipeline(
@@ -302,7 +327,7 @@ def test_loaded_pte_matches_eager_different_shape(self) -> None:
302327
)
303328

304329

305-
class _FrozenPt2Fixture:
330+
class _FrozenPt2Fixture(_ClearDefaultDeviceTestCase):
306331
"""Shared setUp/tearDown: freeze a tiny SeZM checkpoint to ``.pt2`` once.
307332
308333
AOTInductor compilation costs a few seconds; classes that share this
@@ -314,29 +339,39 @@ class _FrozenPt2Fixture:
314339
ckpt_path: Path
315340
out_path: Path
316341

342+
@classmethod
343+
def _cleanup_frozen_fixture(cls) -> None:
344+
if hasattr(cls, "_tmpdir"):
345+
cls._tmpdir.cleanup()
346+
delattr(cls, "_tmpdir")
347+
for attr in ("params", "ckpt_path", "out_path"):
348+
if hasattr(cls, attr):
349+
delattr(cls, attr)
350+
317351
@classmethod
318352
def setUpClass(cls) -> None:
353+
super().setUpClass()
319354
cls._tmpdir = tempfile.TemporaryDirectory()
320-
tmp_root = Path(cls._tmpdir.name)
321-
cls.params = _tiny_sezm_model_params()
322-
with _clear_default_device():
355+
try:
356+
tmp_root = Path(cls._tmpdir.name)
357+
cls.params = _tiny_sezm_model_params()
323358
cls.ckpt_path = _write_tiny_sezm_checkpoint(tmp_root, cls.params)
324359
cls.out_path = tmp_root / "frozen_sezm.pt2"
325360
freeze_sezm_to_pt2(str(cls.ckpt_path), str(cls.out_path), device=_CPU)
361+
except Exception:
362+
cls._cleanup_frozen_fixture()
363+
super().tearDownClass()
364+
raise
326365

327366
@classmethod
328367
def tearDownClass(cls) -> None:
329-
cls._tmpdir.cleanup()
330-
331-
def setUp(self) -> None:
332-
self._device_ctx = _clear_default_device()
333-
self._device_ctx.__enter__()
368+
try:
369+
cls._cleanup_frozen_fixture()
370+
finally:
371+
super().tearDownClass()
334372

335-
def tearDown(self) -> None:
336-
self._device_ctx.__exit__(None, None, None)
337373

338-
339-
class TestSeZMExportArchive(_FrozenPt2Fixture, unittest.TestCase):
374+
class TestSeZMExportArchive(_FrozenPt2Fixture):
340375
"""AOTI ``.pt2`` archive structure + load-and-run smoke.
341376
342377
Numerical parity of the compiled ``.pt2`` is covered by the
@@ -427,7 +462,7 @@ def test_aoti_load_and_run_returns_finite_outputs(self) -> None:
427462
self.assertTrue(torch.isfinite(out_map[key]).all().item())
428463

429464

430-
class TestSeZMViaDeepPot(_FrozenPt2Fixture, unittest.TestCase):
465+
class TestSeZMViaDeepPot(_FrozenPt2Fixture):
431466
"""Integration through the standard :class:`deepmd.infer.DeepPot` entry.
432467
433468
Locks in the contract that makes ``dp test -m frozen.pt2`` and the
@@ -448,48 +483,60 @@ class TestSeZMViaDeepPot(_FrozenPt2Fixture, unittest.TestCase):
448483

449484
@classmethod
450485
def setUpClass(cls) -> None:
451-
# The ``.pt2`` archive is compiled on CPU by the fixture; AOTI
452-
# packages are device-locked, so ``pt_expt.DeepEval``'s input
453-
# preparation must also place tensors on CPU — otherwise
454-
# ``_pt2_runner(...)`` segfaults on dtype/device mismatch.
455-
# ``_prepare_inputs`` does a function-local
456-
# ``from deepmd.pt_expt.utils.env import DEVICE``, so patching
457-
# the module attribute is enough (no rebinding required).
458-
import deepmd.pt_expt.utils.env as _pt_expt_env
459-
460-
cls._orig_pt_expt_device = _pt_expt_env.DEVICE
461-
_pt_expt_env.DEVICE = _CPU
462-
463486
super().setUpClass()
487+
try:
488+
# The ``.pt2`` archive is compiled on CPU by the fixture; AOTI
489+
# packages are device-locked, so ``pt_expt.DeepEval``'s input
490+
# preparation must also place tensors on CPU — otherwise
491+
# ``_pt2_runner(...)`` segfaults on dtype/device mismatch.
492+
# ``_prepare_inputs`` does a function-local
493+
# ``from deepmd.pt_expt.utils.env import DEVICE``, so patching
494+
# the module attribute is enough (no rebinding required).
495+
import deepmd.pt_expt.utils.env as _pt_expt_env
496+
497+
cls._orig_pt_expt_device = _pt_expt_env.DEVICE
498+
_pt_expt_env.DEVICE = _CPU
499+
500+
# Late import: building the deepmd Backend registry is cheap, but
501+
# doing it at collection time conflicts with the conftest
502+
# default-device sentinel used elsewhere in this package.
503+
from deepmd.infer import (
504+
DeepPot,
505+
)
464506

465-
# Late import: building the deepmd Backend registry is cheap, but
466-
# doing it at collection time conflicts with the conftest
467-
# default-device sentinel used elsewhere in this package.
468-
from deepmd.infer import (
469-
DeepPot,
470-
)
471-
472-
cls.dp = DeepPot(str(cls.out_path))
473-
474-
# A deterministic bulk sample; coord is centred in a cubic box
475-
# well inside the periodic image, and the atype distribution
476-
# exercises both type-0 and type-1 slots of sel=[2, 2].
477-
rng = np.random.default_rng(2026)
478-
cls.natoms = 5
479-
cls.atype = np.array([0, 1, 0, 1, 0], dtype=np.int32)
480-
box_edge = cls.params["descriptor"]["rcut"] * 3.0
481-
cls.coord = (
482-
rng.random((1, cls.natoms, 3), dtype=np.float64) * box_edge * 0.4
483-
+ box_edge * 0.3
484-
)
485-
cls.cell = (np.eye(3, dtype=np.float64) * box_edge).reshape(1, 9)
507+
cls.dp = DeepPot(str(cls.out_path))
508+
509+
# A deterministic bulk sample; coord is centred in a cubic box
510+
# well inside the periodic image, and the atype distribution
511+
# exercises both type-0 and type-1 slots of sel=[2, 2].
512+
rng = np.random.default_rng(2026)
513+
cls.natoms = 5
514+
cls.atype = np.array([0, 1, 0, 1, 0], dtype=np.int32)
515+
box_edge = cls.params["descriptor"]["rcut"] * 3.0
516+
cls.coord = (
517+
rng.random((1, cls.natoms, 3), dtype=np.float64) * box_edge * 0.4
518+
+ box_edge * 0.3
519+
)
520+
cls.cell = (np.eye(3, dtype=np.float64) * box_edge).reshape(1, 9)
521+
except Exception:
522+
super().tearDownClass()
523+
raise
486524

487525
@classmethod
488526
def tearDownClass(cls) -> None:
489527
import deepmd.pt_expt.utils.env as _pt_expt_env
490528

491-
_pt_expt_env.DEVICE = cls._orig_pt_expt_device
492-
super().tearDownClass()
529+
try:
530+
if hasattr(cls, "dp"):
531+
delattr(cls, "dp")
532+
if hasattr(cls, "_orig_pt_expt_device"):
533+
_pt_expt_env.DEVICE = cls._orig_pt_expt_device
534+
delattr(cls, "_orig_pt_expt_device")
535+
for attr in ("natoms", "atype", "coord", "cell"):
536+
if hasattr(cls, attr):
537+
delattr(cls, attr)
538+
finally:
539+
super().tearDownClass()
493540

494541
def _eager_energy_force_virial(self) -> tuple[np.ndarray, ...]:
495542
"""Run the eager SeZMModel forward and return arrays shaped like DeepPot."""
@@ -581,7 +628,7 @@ def test_deeppot_eval_atomic_matches_eager(self) -> None:
581628
)
582629

583630

584-
class TestSeZMFreezeGuards(unittest.TestCase):
631+
class TestSeZMFreezeGuards(_ClearDefaultDeviceTestCase):
585632
"""Error paths: detector rejections and CLI-level ``NotImplementedError``s."""
586633

587634
def test_metadata_records_ntypes_when_type_map_is_empty(self) -> None:

0 commit comments

Comments
 (0)