File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 2222)
2323from deepmd .common import (
2424 expand_sys_str ,
25+ j_loader ,
2526)
2627from deepmd .loggers .loggers import (
2728 set_log_handles ,
@@ -235,8 +236,7 @@ def train(
235236 log .info ("Configuration path: %s" , input_file )
236237 if LOCAL_RANK == 0 :
237238 SummaryPrinter ()()
238- with open (input_file ) as fin :
239- config = json .load (fin )
239+ config = j_loader (input_file )
240240 # ensure suffix, as in the command line help, we say "path prefix of checkpoint files"
241241 if init_model is not None and not init_model .endswith (".pd" ):
242242 init_model += ".pd"
Original file line number Diff line number Diff line change 2525)
2626from deepmd .common import (
2727 expand_sys_str ,
28+ j_loader ,
2829)
2930from deepmd .env import (
3031 GLOBAL_CONFIG ,
@@ -254,8 +255,7 @@ def train(
254255 env .CUSTOM_OP_USE_JIT = True
255256 if LOCAL_RANK == 0 :
256257 SummaryPrinter ()()
257- with open (input_file ) as fin :
258- config = json .load (fin )
258+ config = j_loader (input_file )
259259 # ensure suffix, as in the command line help, we say "path prefix of checkpoint files"
260260 if init_model is not None and not init_model .endswith (".pt" ):
261261 init_model += ".pt"
Original file line number Diff line number Diff line change 1414
1515from deepmd .pt .entrypoints .main import (
1616 get_trainer ,
17+ train as train_entry ,
1718)
1819from deepmd .pt .utils .finetune import (
1920 get_finetune_rules ,
@@ -180,8 +181,29 @@ def setUp(self) -> None:
180181 self .config ["training" ]["numb_steps" ] = 1
181182 self .config ["training" ]["save_freq" ] = 1
182183
184+ def test_yaml_input (self ) -> None :
185+ import yaml
186+
187+ yaml_file = Path ("input.yaml" )
188+ with open (yaml_file , "w" ) as fp :
189+ yaml .safe_dump (self .config , fp )
190+ train_entry (
191+ input_file = str (yaml_file ),
192+ init_model = None ,
193+ restart = None ,
194+ finetune = None ,
195+ init_frz_model = None ,
196+ model_branch = "main" ,
197+ skip_neighbor_stat = True ,
198+ output = "out.json" ,
199+ )
200+ self .assertTrue (Path ("out.json" ).exists ())
201+
183202 def tearDown (self ) -> None :
184203 DPTrainTest .tearDown (self )
204+ for ff in ["out.json" , "input.yaml" ]:
205+ if Path (ff ).exists ():
206+ os .remove (ff )
185207
186208
187209class TestDOSModelSeA (unittest .TestCase , DPTrainTest ):
You can’t perform that action at this time.
0 commit comments