From 34d4cb52c9c104050dccb46e8ac4d87392ba0ea2 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 7 Apr 2026 21:03:03 +0800 Subject: [PATCH 1/2] test: add .pte and .pt2 tests for dp convert-backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add pt_expt backend (.pte/.pt2) to the parameterized extensions in test_models.py to verify convert-backend works for the new exportable formats. The fparam_aparam model (1 atom type) is switched from type_one_side=False to type_one_side=True (with ndim 2→1), which is equivalent for single-type models but enables pt_expt export. Models with type_one_side=False and multiple types (se_e2_a, se_e2_r) are skipped for .pte/.pt2 as make_fx cannot trace data-dependent indexing in NetworkCollection(ndim=2). --- .../tests/infer/fparam_aparam-testcase.yaml | 2 +- source/tests/infer/fparam_aparam.yaml | 6 ++--- source/tests/infer/fparam_aparam_default.yaml | 6 ++--- source/tests/infer/test_models.py | 26 ++++++++++++------- 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/source/tests/infer/fparam_aparam-testcase.yaml b/source/tests/infer/fparam_aparam-testcase.yaml index 220b2df209..1f300e31dd 100644 --- a/source/tests/infer/fparam_aparam-testcase.yaml +++ b/source/tests/infer/fparam_aparam-testcase.yaml @@ -26,7 +26,7 @@ model_def_script: "set_davg_zero": False, "trainable": True, "type": "se_e2_a", - "type_one_side": False, + "type_one_side": True, }, "fitting_net": { diff --git a/source/tests/infer/fparam_aparam.yaml b/source/tests/infer/fparam_aparam.yaml index e0654e142f..3f22bbcf6c 100644 --- a/source/tests/infer/fparam_aparam.yaml +++ b/source/tests/infer/fparam_aparam.yaml @@ -526,7 +526,7 @@ model: embeddings: "@class": NetworkCollection "@version": 1 - ndim: 2 + ndim: 1 network_type: embedding_network networks: - "@class": EmbeddingNetwork @@ -916,7 +916,7 @@ model: type: se_e2_a type_map: &id001 - O - type_one_side: false + type_one_side: true fitting: "@class": Fitting "@variables": @@ -2012,7 +2012,7 @@ model_def_script: set_davg_zero: false trainable: true type: se_e2_a - type_one_side: false + type_one_side: true fitting_net: activation_function: tanh atom_ener: *id004 diff --git a/source/tests/infer/fparam_aparam_default.yaml b/source/tests/infer/fparam_aparam_default.yaml index 6d64bfc328..5798817325 100644 --- a/source/tests/infer/fparam_aparam_default.yaml +++ b/source/tests/infer/fparam_aparam_default.yaml @@ -526,7 +526,7 @@ model: embeddings: "@class": NetworkCollection "@version": 1 - ndim: 2 + ndim: 1 network_type: embedding_network networks: - "@class": EmbeddingNetwork @@ -916,7 +916,7 @@ model: type: se_e2_a type_map: &id001 - O - type_one_side: false + type_one_side: true fitting: "@class": Fitting "@variables": @@ -2021,7 +2021,7 @@ model_def_script: set_davg_zero: false trainable: true type: se_e2_a - type_one_side: false + type_one_side: true fitting_net: activation_function: tanh atom_ener: *id004 diff --git a/source/tests/infer/test_models.py b/source/tests/infer/test_models.py index 7f7b7cc21c..500622c664 100644 --- a/source/tests/infer/test_models.py +++ b/source/tests/infer/test_models.py @@ -13,6 +13,7 @@ ) from ..consistent.common import ( + INSTALLED_PT_EXPT, parameterized, ) from .case import ( @@ -28,7 +29,7 @@ "se_e2_r", "fparam_aparam", ), # key - (".pb", ".pth"), # model extension + (".pb", ".pth", ".pte", ".pt2"), # model extension ) class TestDeepPot(unittest.TestCase): # moved from tests/tf/test_deeppot_a.py @@ -36,6 +37,16 @@ class TestDeepPot(unittest.TestCase): @classmethod def setUpClass(cls) -> None: key, extension = cls.param + if extension in (".pte", ".pt2") and not INSTALLED_PT_EXPT: + raise unittest.SkipTest("pt_expt backend not installed") + if key in ("se_e2_a", "se_e2_r") and extension in (".pte", ".pt2"): + raise unittest.SkipTest( + "type_one_side=False is not supported for pt_expt export" + ) + if key == "se_e2_r" and extension == ".pth": + raise unittest.SkipTest( + "se_e2_r type_one_side is not supported for PyTorch models" + ) cls.case = get_cases()[key] cls.model_name = cls.case.get_model(extension) cls.dp = DeepEval(cls.model_name) @@ -44,13 +55,6 @@ def setUpClass(cls) -> None: def tearDownClass(cls) -> None: cls.dp = None - def setUp(self) -> None: - key, extension = self.param - if key == "se_e2_r" and extension == ".pth": - self.skipTest( - reason="se_e2_r type_one_side is not supported for PyTorch models" - ) - def test_attrs(self) -> None: assert isinstance(self.dp, DeepPot) self.assertEqual(self.dp.get_ntypes(), self.case.ntypes) @@ -153,6 +157,8 @@ def test_1frame_atm(self) -> None: def test_descriptor(self) -> None: _, extension = self.param + if extension in (".pte", ".pt2"): + self.skipTest("eval_descriptor not supported for pt_expt models") for ii, result in enumerate(self.case.results): if result.descriptor is None: continue @@ -166,8 +172,8 @@ def test_descriptor(self) -> None: def test_fitting_last_layer(self) -> None: _, extension = self.param - if extension == ".pb": - self.skipTest("fitting_last_layer not supported for TensorFlow models") + if extension in (".pb", ".pte", ".pt2"): + self.skipTest("fitting_last_layer not supported for this backend") for ii, result in enumerate(self.case.results): if result.fit_ll is None: continue From 281989d1589427a8dc2cddfbf90f5890f9ef7b19 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 8 Apr 2026 08:24:08 +0800 Subject: [PATCH 2/2] fix: reset default device before .pt2 AOTInductor compilation tests/pt/__init__.py may set a fake default device for CPU fallback, which poisons AOTInductor compilation. Temporarily clear the default device before converting to .pt2, matching the pattern used in test_change_bias.py. --- source/tests/infer/test_models.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/source/tests/infer/test_models.py b/source/tests/infer/test_models.py index 500622c664..44e3de30cb 100644 --- a/source/tests/infer/test_models.py +++ b/source/tests/infer/test_models.py @@ -48,7 +48,19 @@ def setUpClass(cls) -> None: "se_e2_r type_one_side is not supported for PyTorch models" ) cls.case = get_cases()[key] - cls.model_name = cls.case.get_model(extension) + if extension == ".pt2": + import torch + + # Clear default device: tests/pt/__init__.py may set a fake + # device for CPU fallback, which poisons AOTInductor compilation. + saved_device = torch.get_default_device() + torch.set_default_device(None) + try: + cls.model_name = cls.case.get_model(extension) + finally: + torch.set_default_device(saved_device) + else: + cls.model_name = cls.case.get_model(extension) cls.dp = DeepEval(cls.model_name) @classmethod