Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion source/tests/infer/fparam_aparam-testcase.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ model_def_script:
"set_davg_zero": False,
"trainable": True,
"type": "se_e2_a",
"type_one_side": False,
"type_one_side": True,
},
"fitting_net":
{
Expand Down
6 changes: 3 additions & 3 deletions source/tests/infer/fparam_aparam.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ model:
embeddings:
"@class": NetworkCollection
"@version": 1
ndim: 2
ndim: 1
network_type: embedding_network
networks:
- "@class": EmbeddingNetwork
Expand Down Expand Up @@ -916,7 +916,7 @@ model:
type: se_e2_a
type_map: &id001
- O
type_one_side: false
type_one_side: true
fitting:
"@class": Fitting
"@variables":
Expand Down Expand Up @@ -2012,7 +2012,7 @@ model_def_script:
set_davg_zero: false
trainable: true
type: se_e2_a
type_one_side: false
type_one_side: true
fitting_net:
activation_function: tanh
atom_ener: *id004
Expand Down
6 changes: 3 additions & 3 deletions source/tests/infer/fparam_aparam_default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ model:
embeddings:
"@class": NetworkCollection
"@version": 1
ndim: 2
ndim: 1
network_type: embedding_network
networks:
- "@class": EmbeddingNetwork
Expand Down Expand Up @@ -916,7 +916,7 @@ model:
type: se_e2_a
type_map: &id001
- O
type_one_side: false
type_one_side: true
fitting:
"@class": Fitting
"@variables":
Expand Down Expand Up @@ -2021,7 +2021,7 @@ model_def_script:
set_davg_zero: false
trainable: true
type: se_e2_a
type_one_side: false
type_one_side: true
fitting_net:
activation_function: tanh
atom_ener: *id004
Expand Down
26 changes: 16 additions & 10 deletions source/tests/infer/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)

from ..consistent.common import (
INSTALLED_PT_EXPT,
parameterized,
)
from .case import (
Expand All @@ -28,14 +29,24 @@
"se_e2_r",
"fparam_aparam",
), # key
(".pb", ".pth"), # model extension
(".pb", ".pth", ".pte", ".pt2"), # model extension
)
class TestDeepPot(unittest.TestCase):
# moved from tests/tf/test_deeppot_a.py

@classmethod
def setUpClass(cls) -> None:
key, extension = cls.param
if extension in (".pte", ".pt2") and not INSTALLED_PT_EXPT:
raise unittest.SkipTest("pt_expt backend not installed")
if key in ("se_e2_a", "se_e2_r") and extension in (".pte", ".pt2"):
raise unittest.SkipTest(
"type_one_side=False is not supported for pt_expt export"
)
if key == "se_e2_r" and extension == ".pth":
raise unittest.SkipTest(
"se_e2_r type_one_side is not supported for PyTorch models"
)
cls.case = get_cases()[key]
cls.model_name = cls.case.get_model(extension)
cls.dp = DeepEval(cls.model_name)
Expand All @@ -44,13 +55,6 @@ def setUpClass(cls) -> None:
def tearDownClass(cls) -> None:
cls.dp = None

def setUp(self) -> None:
key, extension = self.param
if key == "se_e2_r" and extension == ".pth":
self.skipTest(
reason="se_e2_r type_one_side is not supported for PyTorch models"
)

def test_attrs(self) -> None:
assert isinstance(self.dp, DeepPot)
self.assertEqual(self.dp.get_ntypes(), self.case.ntypes)
Expand Down Expand Up @@ -153,6 +157,8 @@ def test_1frame_atm(self) -> None:

def test_descriptor(self) -> None:
_, extension = self.param
if extension in (".pte", ".pt2"):
self.skipTest("eval_descriptor not supported for pt_expt models")
for ii, result in enumerate(self.case.results):
if result.descriptor is None:
continue
Expand All @@ -166,8 +172,8 @@ def test_descriptor(self) -> None:

def test_fitting_last_layer(self) -> None:
_, extension = self.param
if extension == ".pb":
self.skipTest("fitting_last_layer not supported for TensorFlow models")
if extension in (".pb", ".pte", ".pt2"):
self.skipTest("fitting_last_layer not supported for this backend")
for ii, result in enumerate(self.case.results):
if result.fit_ll is None:
continue
Expand Down
Loading