Skip to content

Commit c2f8356

Browse files
committed
test: ensure tf training accepts yaml
1 parent 22f394a commit c2f8356

2 files changed

Lines changed: 66 additions & 1 deletion

File tree

deepmd/tf/entrypoints/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
Optional,
1414
)
1515

16-
from deepmd.tf.common import (
16+
from deepmd.common import (
1717
j_loader,
1818
)
1919
from deepmd.tf.env import (

source/tests/tf/test_training.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""Tests for TensorFlow training entrypoint."""
3+
4+
from pathlib import (
5+
Path,
6+
)
7+
from typing import (
8+
Any,
9+
)
10+
11+
import yaml
12+
13+
from deepmd.tf.entrypoints.train import ( # type: ignore
14+
train,
15+
)
16+
17+
from .common import (
18+
del_data,
19+
gen_data_type_specific,
20+
j_loader,
21+
)
22+
23+
24+
class TestYamlInput:
25+
"""Ensure training entrypoint accepts YAML config."""
26+
27+
def setup_method(self) -> None:
28+
gen_data_type_specific()
29+
config: dict[str, Any] = j_loader("water_se_atten.json")
30+
config["systems"] = ["system"]
31+
config["stop_batch"] = 1
32+
config["save_freq"] = 1
33+
yaml_file = Path("input.yaml")
34+
with open(yaml_file, "w") as fp:
35+
yaml.safe_dump(config, fp)
36+
self.yaml_file = yaml_file
37+
38+
def teardown_method(self) -> None:
39+
del_data()
40+
for ff in [
41+
"out.json",
42+
"input.yaml",
43+
"lcurve.out",
44+
"model.ckpt.data-00000-of-00001",
45+
"model.ckpt.index",
46+
"model.ckpt.meta",
47+
]:
48+
Path(ff).unlink(missing_ok=True)
49+
50+
def test_yaml_input(self) -> None:
51+
train(
52+
INPUT=str(self.yaml_file),
53+
init_model=None,
54+
restart=None,
55+
output="out.json",
56+
init_frz_model=None,
57+
mpi_log="master",
58+
log_level=0,
59+
log_path=None,
60+
skip_neighbor_stat=True,
61+
)
62+
assert Path("out.json").exists()
63+
64+
65+
__all__ = ["TestYamlInput"]

0 commit comments

Comments
 (0)