Skip to content

Commit 03ce0c7

Browse files
refine paddle unitests
1 parent 26013cb commit 03ce0c7

6 files changed

Lines changed: 21 additions & 18 deletions

File tree

source/tests/pd/model/test_descriptor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
Path,
1414
)
1515

16+
from deepmd.common import (
17+
expand_sys_str,
18+
)
1619
from deepmd.pd.model.descriptor import (
1720
prod_env_mat,
1821
)
@@ -31,9 +34,6 @@
3134
from deepmd.pd.utils.nlist import (
3235
extend_input_and_build_neighbor_list,
3336
)
34-
from deepmd.tf.common import (
35-
expand_sys_str,
36-
)
3737
from deepmd.tf.env import (
3838
op_module,
3939
)

source/tests/pd/model/test_embedding_net.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
Path,
1919
)
2020

21+
from deepmd.common import (
22+
expand_sys_str,
23+
)
2124
from deepmd.pd.model.descriptor import (
2225
DescrptSeA,
2326
)
@@ -34,9 +37,6 @@
3437
from deepmd.pd.utils.nlist import (
3538
extend_input_and_build_neighbor_list,
3639
)
37-
from deepmd.tf.common import (
38-
expand_sys_str,
39-
)
4040
from deepmd.tf.descriptor import DescrptSeA as DescrptSeA_tf
4141

4242
from ..test_finetune import (

source/tests/pd/model/test_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
Path,
2020
)
2121

22+
from deepmd.common import (
23+
expand_sys_str,
24+
)
2225
from deepmd.dpmodel.utils.learning_rate import LearningRateExp as MyLRExp
2326
from deepmd.pd.loss import (
2427
EnergyStdLoss,
@@ -32,9 +35,6 @@
3235
from deepmd.pd.utils.env import (
3336
DEVICE,
3437
)
35-
from deepmd.tf.common import (
36-
expand_sys_str,
37-
)
3838
from deepmd.tf.descriptor import DescrptSeA as DescrptSeA_tf
3939
from deepmd.tf.fit import (
4040
EnerFitting,

source/tests/pd/model/test_saveload_dpa1.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
DataLoader,
1414
)
1515

16+
from deepmd.common import (
17+
expand_sys_str,
18+
)
1619
from deepmd.pd.loss import (
1720
EnergyStdLoss,
1821
)
@@ -32,9 +35,6 @@
3235
from deepmd.pd.utils.stat import (
3336
make_stat_input,
3437
)
35-
from deepmd.tf.common import (
36-
expand_sys_str,
37-
)
3838

3939

4040
def get_dataset(config):

source/tests/pd/model/test_saveload_se_e2_a.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
DataLoader,
1414
)
1515

16+
from deepmd.common import (
17+
expand_sys_str,
18+
)
1619
from deepmd.pd.loss import (
1720
EnergyStdLoss,
1821
)
@@ -32,9 +35,6 @@
3235
from deepmd.pd.utils.stat import (
3336
make_stat_input,
3437
)
35-
from deepmd.tf.common import (
36-
expand_sys_str,
37-
)
3838

3939

4040
def get_dataset(config):

source/tests/pd/test_training.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,14 @@ def test_trainable(self) -> None:
136136
def tearDown(self) -> None:
137137
for f in os.listdir("."):
138138
if f.startswith("model") and f.endswith(".pd"):
139-
os.remove(f)
139+
if os.path.exists(f):
140+
os.remove(f)
140141
if f in ["lcurve.out"]:
141-
os.remove(f)
142+
if os.path.exists(f):
143+
os.remove(f)
142144
if f in ["stat_files"]:
143-
shutil.rmtree(f)
145+
if os.path.exists(f):
146+
shutil.rmtree(f)
144147

145148

146149
class TestEnergyModelSeA(unittest.TestCase, DPTrainTest):

0 commit comments

Comments
 (0)