Skip to content

Commit 1e75042

Browse files
feat: add yaml input file support (#4894)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Training entrypoints now accept YAML configuration files in addition to JSON, offering more flexibility when launching training. * Unified configuration loading across frameworks for consistent behavior (PyTorch, Paddle, TensorFlow). * Backward compatible: existing JSON-based workflows continue to work unchanged. * **Tests** * Added coverage to verify YAML input produces the expected training output. * Improved test cleanup to remove generated artifacts after execution. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1525a79 commit 1e75042

4 files changed

Lines changed: 27 additions & 5 deletions

File tree

deepmd/pd/entrypoints/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from deepmd.common import (
2424
expand_sys_str,
25+
j_loader,
2526
)
2627
from deepmd.loggers.loggers import (
2728
set_log_handles,
@@ -235,8 +236,7 @@ def train(
235236
log.info("Configuration path: %s", input_file)
236237
if LOCAL_RANK == 0:
237238
SummaryPrinter()()
238-
with open(input_file) as fin:
239-
config = json.load(fin)
239+
config = j_loader(input_file)
240240
# ensure suffix, as in the command line help, we say "path prefix of checkpoint files"
241241
if init_model is not None and not init_model.endswith(".pd"):
242242
init_model += ".pd"

deepmd/pt/entrypoints/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
from deepmd.common import (
2727
expand_sys_str,
28+
j_loader,
2829
)
2930
from deepmd.env import (
3031
GLOBAL_CONFIG,
@@ -254,8 +255,7 @@ def train(
254255
env.CUSTOM_OP_USE_JIT = True
255256
if LOCAL_RANK == 0:
256257
SummaryPrinter()()
257-
with open(input_file) as fin:
258-
config = json.load(fin)
258+
config = j_loader(input_file)
259259
# ensure suffix, as in the command line help, we say "path prefix of checkpoint files"
260260
if init_model is not None and not init_model.endswith(".pt"):
261261
init_model += ".pt"

deepmd/tf/entrypoints/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
Optional,
1414
)
1515

16-
from deepmd.tf.common import (
16+
from deepmd.common import (
1717
j_loader,
1818
)
1919
from deepmd.tf.env import (

source/tests/pt/test_training.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from deepmd.pt.entrypoints.main import (
1616
get_trainer,
1717
)
18+
from deepmd.pt.entrypoints.main import train as train_entry
1819
from deepmd.pt.utils.finetune import (
1920
get_finetune_rules,
2021
)
@@ -180,8 +181,29 @@ def setUp(self) -> None:
180181
self.config["training"]["numb_steps"] = 1
181182
self.config["training"]["save_freq"] = 1
182183

184+
def test_yaml_input(self) -> None:
185+
import yaml
186+
187+
yaml_file = Path("input.yaml")
188+
with open(yaml_file, "w") as fp:
189+
yaml.safe_dump(self.config, fp)
190+
train_entry(
191+
input_file=str(yaml_file),
192+
init_model=None,
193+
restart=None,
194+
finetune=None,
195+
init_frz_model=None,
196+
model_branch="main",
197+
skip_neighbor_stat=True,
198+
output="out.json",
199+
)
200+
self.assertTrue(Path("out.json").exists())
201+
183202
def tearDown(self) -> None:
184203
DPTrainTest.tearDown(self)
204+
for ff in ["out.json", "input.yaml"]:
205+
if Path(ff).exists():
206+
os.remove(ff)
185207

186208

187209
class TestDOSModelSeA(unittest.TestCase, DPTrainTest):

0 commit comments

Comments
 (0)