File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1313 Optional ,
1414)
1515
16- from deepmd .tf . common import (
16+ from deepmd .common import (
1717 j_loader ,
1818)
1919from deepmd .tf .env import (
Original file line number Diff line number Diff line change 1+ # SPDX-License-Identifier: LGPL-3.0-or-later
2+ """Tests for TensorFlow training entrypoint."""
3+
4+ from pathlib import (
5+ Path ,
6+ )
7+ from typing import (
8+ Any ,
9+ )
10+
11+ import yaml
12+
13+ from deepmd .tf .entrypoints .train import ( # type: ignore
14+ train ,
15+ )
16+
17+ from .common import (
18+ del_data ,
19+ gen_data_type_specific ,
20+ j_loader ,
21+ )
22+
23+
24+ class TestYamlInput :
25+ """Ensure training entrypoint accepts YAML config."""
26+
27+ def setup_method (self ) -> None :
28+ gen_data_type_specific ()
29+ config : dict [str , Any ] = j_loader ("water_se_atten.json" )
30+ config ["systems" ] = ["system" ]
31+ config ["stop_batch" ] = 1
32+ config ["save_freq" ] = 1
33+ yaml_file = Path ("input.yaml" )
34+ with open (yaml_file , "w" ) as fp :
35+ yaml .safe_dump (config , fp )
36+ self .yaml_file = yaml_file
37+
38+ def teardown_method (self ) -> None :
39+ del_data ()
40+ for ff in [
41+ "out.json" ,
42+ "input.yaml" ,
43+ "lcurve.out" ,
44+ "model.ckpt.data-00000-of-00001" ,
45+ "model.ckpt.index" ,
46+ "model.ckpt.meta" ,
47+ ]:
48+ Path (ff ).unlink (missing_ok = True )
49+
50+ def test_yaml_input (self ) -> None :
51+ train (
52+ INPUT = str (self .yaml_file ),
53+ init_model = None ,
54+ restart = None ,
55+ output = "out.json" ,
56+ init_frz_model = None ,
57+ mpi_log = "master" ,
58+ log_level = 0 ,
59+ log_path = None ,
60+ skip_neighbor_stat = True ,
61+ )
62+ assert Path ("out.json" ).exists ()
63+
64+
65+ __all__ = ["TestYamlInput" ]
You can’t perform that action at this time.
0 commit comments