@@ -169,6 +169,26 @@ def _clear_default_device() -> Iterator[None]:
169169 torch .set_default_device (saved )
170170
171171
class _ClearDefaultDeviceTestCase(unittest.TestCase):
    """Run a test class while the pt default-device sentinel is disabled.

    ``setUpClass`` enters the context manager returned by
    ``_clear_default_device()`` and keeps it on the class as
    ``_default_device_ctx``; ``tearDownClass`` exits it again.  Subclasses
    that override either hook must call ``super()`` so the context is
    entered before, and exited after, their own class-level fixtures.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Enter the default-device override for the whole test class."""
        super().setUpClass()
        ctx = _clear_default_device()
        ctx.__enter__()
        # Publish the handle only once the context is live, so that
        # ``tearDownClass`` can never exit a context that was not entered.
        cls._default_device_ctx = ctx

    @classmethod
    def tearDownClass(cls) -> None:
        """Exit the default-device override; safe to call more than once."""
        try:
            super().tearDownClass()
        finally:
            ctx = getattr(cls, "_default_device_ctx", None)
            if ctx is not None:
                try:
                    ctx.__exit__(None, None, None)
                finally:
                    # Drop the handle even when ``__exit__`` raises, so a
                    # repeated teardown does not exit the same context twice.
                    delattr(cls, "_default_device_ctx")
190+
191+
172192def _eager_forward (
173193 model : torch .nn .Module ,
174194 sample_inputs : tuple ,
@@ -189,7 +209,7 @@ def _eager_forward(
189209 )
190210
191211
192- class TestSeZMExportPipeline (unittest . TestCase ):
212+ class TestSeZMExportPipeline (_ClearDefaultDeviceTestCase ):
193213 """Bitwise trace / export / ``.pte`` round-trip parity (``rtol=1e-10``).
194214
195215 The ExportedProgram is a pure FX graph (no Inductor codegen), so
@@ -201,23 +221,28 @@ class TestSeZMExportPipeline(unittest.TestCase):
201221
@classmethod
def setUpClass(cls) -> None:
    """Build the shared fixtures once for the whole class.

    Stores on the class: ``model`` (from ``_build_tiny_sezm_model``),
    ``sample_inputs`` (from ``_make_sample`` with ``nloc=7, start=2``) and
    the ``traced`` / ``loaded`` / ``_pte_tmp`` triple returned by
    ``_build_pipeline``.

    Raises:
        Exception: whatever the builders raise, re-raised after cleanup.
    """
    super().setUpClass()
    try:
        cls.model = _build_tiny_sezm_model()
        cls.sample_inputs = _make_sample(cls.model, nloc=7, start=2)
        cls.traced, cls.loaded, cls._pte_tmp = cls._build_pipeline(
            cls.model, cls.sample_inputs
        )
    except Exception:
        # unittest does not run tearDownClass after a failed setUpClass.
        # The class teardown is attribute-guarded, so it can release any
        # partially built fixtures and then unwind the parent context.
        cls.tearDownClass()
        raise
210234
@classmethod
def tearDownClass(cls) -> None:
    """Release the class-level fixtures, then run the parent teardown.

    Every step is guarded by ``hasattr`` so the method also works when
    only part of the fixtures was created.
    """
    try:
        # Drop the references to the pipeline objects first.
        present = [
            name
            for name in ("loaded", "traced", "model", "sample_inputs")
            if hasattr(cls, name)
        ]
        for name in present:
            delattr(cls, name)

        # Then close and forget the ``_pte_tmp`` handle.
        if hasattr(cls, "_pte_tmp"):
            pte_tmp = cls._pte_tmp
            pte_tmp.close()
            del cls._pte_tmp
    finally:
        # The parent hook must run even if the cleanup above raised.
        super().tearDownClass()
221246
222247 @staticmethod
223248 def _build_pipeline (
@@ -302,7 +327,7 @@ def test_loaded_pte_matches_eager_different_shape(self) -> None:
302327 )
303328
304329
305- class _FrozenPt2Fixture :
330+ class _FrozenPt2Fixture ( _ClearDefaultDeviceTestCase ) :
306331 """Shared setUp/tearDown: freeze a tiny SeZM checkpoint to ``.pt2`` once.
307332
308333 AOTInductor compilation costs a few seconds; classes that share this
@@ -314,29 +339,39 @@ class _FrozenPt2Fixture:
314339 ckpt_path : Path
315340 out_path : Path
316341
@classmethod
def _cleanup_frozen_fixture(cls) -> None:
    """Remove the temporary directory and the fixture attributes, if present.

    Attributes that were never set are skipped, so the method can be
    called after a partial setup or more than once.
    """
    try:
        tmpdir = cls._tmpdir
    except AttributeError:
        pass
    else:
        tmpdir.cleanup()
        del cls._tmpdir

    leftovers = [
        name for name in ("params", "ckpt_path", "out_path") if hasattr(cls, name)
    ]
    for name in leftovers:
        delattr(cls, name)
350+
@classmethod
def setUpClass(cls) -> None:
    """Write a tiny SeZM checkpoint and freeze it to ``.pt2`` once per class.

    Stores on the class: ``_tmpdir`` (the owning ``TemporaryDirectory``),
    ``params``, ``ckpt_path`` and ``out_path`` (``frozen_sezm.pt2`` inside
    the temporary directory).

    Raises:
        Exception: whatever the setup steps raise, re-raised after cleanup.
    """
    super().setUpClass()
    try:
        # Created inside the ``try`` so that a failure here still unwinds
        # the parent's class-level context.
        cls._tmpdir = tempfile.TemporaryDirectory()
        tmp_root = Path(cls._tmpdir.name)
        cls.params = _tiny_sezm_model_params()
        cls.ckpt_path = _write_tiny_sezm_checkpoint(tmp_root, cls.params)
        cls.out_path = tmp_root / "frozen_sezm.pt2"
        freeze_sezm_to_pt2(str(cls.ckpt_path), str(cls.out_path), device=_CPU)
    except Exception:
        # unittest does not run tearDownClass after a failed setUpClass,
        # so clean up by hand; the parent teardown runs even if the
        # fixture cleanup itself raises.
        try:
            cls._cleanup_frozen_fixture()
        finally:
            super().tearDownClass()
        raise
326365
@classmethod
def tearDownClass(cls) -> None:
    """Discard the frozen-model fixture, then run the parent teardown."""
    try:
        cls._cleanup_frozen_fixture()
    finally:
        # Runs unconditionally so the parent hook is never skipped, even
        # when the fixture cleanup raises.
        super().tearDownClass()
334372
335- def tearDown (self ) -> None :
336- self ._device_ctx .__exit__ (None , None , None )
337373
338-
339- class TestSeZMExportArchive (_FrozenPt2Fixture , unittest .TestCase ):
374+ class TestSeZMExportArchive (_FrozenPt2Fixture ):
340375 """AOTI ``.pt2`` archive structure + load-and-run smoke.
341376
342377 Numerical parity of the compiled ``.pt2`` is covered by the
@@ -427,7 +462,7 @@ def test_aoti_load_and_run_returns_finite_outputs(self) -> None:
427462 self .assertTrue (torch .isfinite (out_map [key ]).all ().item ())
428463
429464
430- class TestSeZMViaDeepPot (_FrozenPt2Fixture , unittest . TestCase ):
465+ class TestSeZMViaDeepPot (_FrozenPt2Fixture ):
431466 """Integration through the standard :class:`deepmd.infer.DeepPot` entry.
432467
433468 Locks in the contract that makes ``dp test -m frozen.pt2`` and the
@@ -448,48 +483,60 @@ class TestSeZMViaDeepPot(_FrozenPt2Fixture, unittest.TestCase):
448483
@classmethod
def setUpClass(cls) -> None:
    """Load the frozen ``.pt2`` through ``DeepPot`` and build one sample.

    After the parent fixture has produced ``cls.out_path``, this hook
    patches ``deepmd.pt_expt.utils.env.DEVICE`` to ``_CPU`` (keeping the
    previous value in ``_orig_pt_expt_device``), creates ``cls.dp`` and
    stores the sample arrays ``natoms``, ``atype``, ``coord`` and ``cell``.

    Raises:
        Exception: whatever the setup steps raise, re-raised after the
            ``DEVICE`` patch and all class attributes have been rolled back.
    """
    super().setUpClass()
    try:
        # The ``.pt2`` archive is compiled on CPU by the fixture; AOTI
        # packages are device-locked, so ``pt_expt.DeepEval``'s input
        # preparation must also place tensors on CPU — otherwise
        # ``_pt2_runner(...)`` segfaults on dtype/device mismatch.
        # ``_prepare_inputs`` does a function-local
        # ``from deepmd.pt_expt.utils.env import DEVICE``, so patching
        # the module attribute is enough (no rebinding required).
        import deepmd.pt_expt.utils.env as _pt_expt_env

        cls._orig_pt_expt_device = _pt_expt_env.DEVICE
        _pt_expt_env.DEVICE = _CPU

        # Late import: building the deepmd Backend registry is cheap, but
        # doing it at collection time conflicts with the conftest
        # default-device sentinel used elsewhere in this package.
        from deepmd.infer import (
            DeepPot,
        )

        cls.dp = DeepPot(str(cls.out_path))

        # A deterministic bulk sample; coord is centred in a cubic box
        # well inside the periodic image, and the atype distribution
        # exercises both type-0 and type-1 slots of sel=[2, 2].
        rng = np.random.default_rng(2026)
        cls.natoms = 5
        cls.atype = np.array([0, 1, 0, 1, 0], dtype=np.int32)
        box_edge = cls.params["descriptor"]["rcut"] * 3.0
        cls.coord = (
            rng.random((1, cls.natoms, 3), dtype=np.float64) * box_edge * 0.4
            + box_edge * 0.3
        )
        cls.cell = (np.eye(3, dtype=np.float64) * box_edge).reshape(1, 9)
    except Exception:
        # unittest does not run tearDownClass after a failed setUpClass,
        # so undo this class's own work here before unwinding the parent.
        # Without this the process-wide ``DEVICE`` override would leak
        # into every test that runs afterwards.
        try:
            if hasattr(cls, "dp"):
                delattr(cls, "dp")
            if hasattr(cls, "_orig_pt_expt_device"):
                # The attribute only exists once the import above succeeded,
                # so ``_pt_expt_env`` is guaranteed to be bound here.
                _pt_expt_env.DEVICE = cls._orig_pt_expt_device
                delattr(cls, "_orig_pt_expt_device")
            for attr in ("natoms", "atype", "coord", "cell"):
                if hasattr(cls, attr):
                    delattr(cls, attr)
        finally:
            super().tearDownClass()
        raise
486524
@classmethod
def tearDownClass(cls) -> None:
    """Undo this class's fixtures, then run the parent teardown.

    Drops ``dp``, restores ``deepmd.pt_expt.utils.env.DEVICE`` from
    ``_orig_pt_expt_device`` and removes the sample arrays.  Every step is
    guarded by ``hasattr`` so the method also works after a partial setup.
    """
    try:
        if hasattr(cls, "dp"):
            delattr(cls, "dp")
        if hasattr(cls, "_orig_pt_expt_device"):
            # Imported only where it is needed and inside the ``try``, so
            # an import failure cannot skip the parent teardown below.
            import deepmd.pt_expt.utils.env as _pt_expt_env

            _pt_expt_env.DEVICE = cls._orig_pt_expt_device
            delattr(cls, "_orig_pt_expt_device")
        for attr in ("natoms", "atype", "coord", "cell"):
            if hasattr(cls, attr):
                delattr(cls, attr)
    finally:
        super().tearDownClass()
493540
494541 def _eager_energy_force_virial (self ) -> tuple [np .ndarray , ...]:
495542 """Run the eager SeZMModel forward and return arrays shaped like DeepPot."""
@@ -581,7 +628,7 @@ def test_deeppot_eval_atomic_matches_eager(self) -> None:
581628 )
582629
583630
584- class TestSeZMFreezeGuards (unittest . TestCase ):
631+ class TestSeZMFreezeGuards (_ClearDefaultDeviceTestCase ):
585632 """Error paths: detector rejections and CLI-level ``NotImplementedError``s."""
586633
587634 def test_metadata_records_ntypes_when_type_map_is_empty (self ) -> None :
0 commit comments