Skip to content

Commit e5b0f21

Browse files
wanghan-iapcmHan Wang
andauthored
test(pt-expt): add .pte and .pt2 tests for dp convert-backend (#5384)
## Summary - Add `.pte` and `.pt2` to parameterized extensions in `test_models.py` to verify `convert-backend` works for the pt_expt exportable formats - Switch `fparam_aparam` model (1 atom type) from `type_one_side=False` to `True` (with `ndim` 2→1), which is equivalent for single-type models but enables pt_expt export - Skip `se_e2_a`/`se_e2_r` for `.pte`/`.pt2` — `type_one_side=False` with multiple types triggers `GuardOnDataDependentSymNode` in `make_fx` (data-dependent indexing in `NetworkCollection(ndim=2)`) - Skip `test_descriptor` and `test_fitting_last_layer` for `.pte`/`.pt2` (not implemented in pt_expt `DeepEval`) ## Test plan - [x] `python -m pytest source/tests/infer/test_models.py -v` — 55 passed, 67 skipped - [x] `python -m pytest source/tests/infer/test_models.py -v -k "pte or pt2"` — 12 passed (fparam_aparam), rest skipped - [x] Existing `.pb`/`.pth` tests unaffected <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Tests** * Updated test case configurations with adjusted model descriptor and embedding parameters * Extended test suite parameterization to support additional model file formats (.pte, .pt2) * Reorganized model compatibility skip conditions for improved test structure and maintainability * Added special handling when loading certain model formats to ensure stable device behavior during tests <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent a9bc4a2 commit e5b0f21

4 files changed

Lines changed: 36 additions & 18 deletions

File tree

source/tests/infer/fparam_aparam-testcase.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ model_def_script:
2626
"set_davg_zero": False,
2727
"trainable": True,
2828
"type": "se_e2_a",
29-
"type_one_side": False,
29+
"type_one_side": True,
3030
},
3131
"fitting_net":
3232
{

source/tests/infer/fparam_aparam.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ model:
526526
embeddings:
527527
"@class": NetworkCollection
528528
"@version": 1
529-
ndim: 2
529+
ndim: 1
530530
network_type: embedding_network
531531
networks:
532532
- "@class": EmbeddingNetwork
@@ -916,7 +916,7 @@ model:
916916
type: se_e2_a
917917
type_map: &id001
918918
- O
919-
type_one_side: false
919+
type_one_side: true
920920
fitting:
921921
"@class": Fitting
922922
"@variables":
@@ -2012,7 +2012,7 @@ model_def_script:
20122012
set_davg_zero: false
20132013
trainable: true
20142014
type: se_e2_a
2015-
type_one_side: false
2015+
type_one_side: true
20162016
fitting_net:
20172017
activation_function: tanh
20182018
atom_ener: *id004

source/tests/infer/fparam_aparam_default.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ model:
526526
embeddings:
527527
"@class": NetworkCollection
528528
"@version": 1
529-
ndim: 2
529+
ndim: 1
530530
network_type: embedding_network
531531
networks:
532532
- "@class": EmbeddingNetwork
@@ -916,7 +916,7 @@ model:
916916
type: se_e2_a
917917
type_map: &id001
918918
- O
919-
type_one_side: false
919+
type_one_side: true
920920
fitting:
921921
"@class": Fitting
922922
"@variables":
@@ -2021,7 +2021,7 @@ model_def_script:
20212021
set_davg_zero: false
20222022
trainable: true
20232023
type: se_e2_a
2024-
type_one_side: false
2024+
type_one_side: true
20252025
fitting_net:
20262026
activation_function: tanh
20272027
atom_ener: *id004

source/tests/infer/test_models.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414

1515
from ..consistent.common import (
16+
INSTALLED_PT_EXPT,
1617
parameterized,
1718
)
1819
from .case import (
@@ -28,29 +29,44 @@
2829
"se_e2_r",
2930
"fparam_aparam",
3031
), # key
31-
(".pb", ".pth"), # model extension
32+
(".pb", ".pth", ".pte", ".pt2"), # model extension
3233
)
3334
class TestDeepPot(unittest.TestCase):
3435
# moved from tests/tf/test_deeppot_a.py
3536

3637
@classmethod
3738
def setUpClass(cls) -> None:
3839
key, extension = cls.param
40+
if extension in (".pte", ".pt2") and not INSTALLED_PT_EXPT:
41+
raise unittest.SkipTest("pt_expt backend not installed")
42+
if key in ("se_e2_a", "se_e2_r") and extension in (".pte", ".pt2"):
43+
raise unittest.SkipTest(
44+
"type_one_side=False is not supported for pt_expt export"
45+
)
46+
if key == "se_e2_r" and extension == ".pth":
47+
raise unittest.SkipTest(
48+
"se_e2_r type_one_side is not supported for PyTorch models"
49+
)
3950
cls.case = get_cases()[key]
40-
cls.model_name = cls.case.get_model(extension)
51+
if extension == ".pt2":
52+
import torch
53+
54+
# Clear default device: tests/pt/__init__.py may set a fake
55+
# device for CPU fallback, which poisons AOTInductor compilation.
56+
saved_device = torch.get_default_device()
57+
torch.set_default_device(None)
58+
try:
59+
cls.model_name = cls.case.get_model(extension)
60+
finally:
61+
torch.set_default_device(saved_device)
62+
else:
63+
cls.model_name = cls.case.get_model(extension)
4164
cls.dp = DeepEval(cls.model_name)
4265

4366
@classmethod
4467
def tearDownClass(cls) -> None:
4568
cls.dp = None
4669

47-
def setUp(self) -> None:
48-
key, extension = self.param
49-
if key == "se_e2_r" and extension == ".pth":
50-
self.skipTest(
51-
reason="se_e2_r type_one_side is not supported for PyTorch models"
52-
)
53-
5470
def test_attrs(self) -> None:
5571
assert isinstance(self.dp, DeepPot)
5672
self.assertEqual(self.dp.get_ntypes(), self.case.ntypes)
@@ -153,6 +169,8 @@ def test_1frame_atm(self) -> None:
153169

154170
def test_descriptor(self) -> None:
155171
_, extension = self.param
172+
if extension in (".pte", ".pt2"):
173+
self.skipTest("eval_descriptor not supported for pt_expt models")
156174
for ii, result in enumerate(self.case.results):
157175
if result.descriptor is None:
158176
continue
@@ -166,8 +184,8 @@ def test_descriptor(self) -> None:
166184

167185
def test_fitting_last_layer(self) -> None:
168186
_, extension = self.param
169-
if extension == ".pb":
170-
self.skipTest("fitting_last_layer not supported for TensorFlow models")
187+
if extension in (".pb", ".pte", ".pt2"):
188+
self.skipTest("fitting_last_layer not supported for this backend")
171189
for ii, result in enumerate(self.case.results):
172190
if result.fit_ll is None:
173191
continue

0 commit comments

Comments
 (0)