diff --git a/deepmd/pd/entrypoints/main.py b/deepmd/pd/entrypoints/main.py index 3ef075f359..4e47dbfe77 100644 --- a/deepmd/pd/entrypoints/main.py +++ b/deepmd/pd/entrypoints/main.py @@ -22,6 +22,7 @@ ) from deepmd.common import ( expand_sys_str, + j_loader, ) from deepmd.loggers.loggers import ( set_log_handles, @@ -235,8 +236,7 @@ def train( log.info("Configuration path: %s", input_file) if LOCAL_RANK == 0: SummaryPrinter()() - with open(input_file) as fin: - config = json.load(fin) + config = j_loader(input_file) # ensure suffix, as in the command line help, we say "path prefix of checkpoint files" if init_model is not None and not init_model.endswith(".pd"): init_model += ".pd" diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 0e248583ec..630fb6d86f 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -25,6 +25,7 @@ ) from deepmd.common import ( expand_sys_str, + j_loader, ) from deepmd.env import ( GLOBAL_CONFIG, @@ -254,8 +255,7 @@ def train( env.CUSTOM_OP_USE_JIT = True if LOCAL_RANK == 0: SummaryPrinter()() - with open(input_file) as fin: - config = json.load(fin) + config = j_loader(input_file) # ensure suffix, as in the command line help, we say "path prefix of checkpoint files" if init_model is not None and not init_model.endswith(".pt"): init_model += ".pt" diff --git a/deepmd/tf/entrypoints/train.py b/deepmd/tf/entrypoints/train.py index b12e4fe1af..5bcca9a4e3 100755 --- a/deepmd/tf/entrypoints/train.py +++ b/deepmd/tf/entrypoints/train.py @@ -13,7 +13,7 @@ Optional, ) -from deepmd.tf.common import ( +from deepmd.common import ( j_loader, ) from deepmd.tf.env import ( diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index c57c896197..da239212b0 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -15,6 +15,7 @@ from deepmd.pt.entrypoints.main import ( get_trainer, ) +from deepmd.pt.entrypoints.main import train as train_entry from deepmd.pt.utils.finetune import ( get_finetune_rules, ) @@ -180,8 +181,29 @@ def setUp(self) -> None: self.config["training"]["numb_steps"] = 1 self.config["training"]["save_freq"] = 1 + def test_yaml_input(self) -> None: + import yaml + + yaml_file = Path("input.yaml") + with open(yaml_file, "w") as fp: + yaml.safe_dump(self.config, fp) + train_entry( + input_file=str(yaml_file), + init_model=None, + restart=None, + finetune=None, + init_frz_model=None, + model_branch="main", + skip_neighbor_stat=True, + output="out.json", + ) + self.assertTrue(Path("out.json").exists()) + def tearDown(self) -> None: DPTrainTest.tearDown(self) + for ff in ["out.json", "input.yaml"]: + if Path(ff).exists(): + os.remove(ff) class TestDOSModelSeA(unittest.TestCase, DPTrainTest):