Skip to content

Commit 3ab3779

Browse files
wanghan-iapcmHan Wang
andauthored
feat(pt_expt): add dp freeze support and dp test tests for .pte models (#5302)
## Summary - Add `dp freeze` support for the pt_expt backend, enabling checkpoint `.pt` → exported `.pte` conversion - Add end-to-end tests for both `dp freeze` and `dp test` with `.pte` models ## Background The pt_expt backend can export models to `.pte` via `deserialize_to_file()`, and `dp test` can already load `.pte` models through the registered `DeepEval`. However, `dp freeze` was not wired up — calling `dp freeze -b pt-expt` hit `RuntimeError: Unsupported command 'freeze'`. ## Changes **`deepmd/pt_expt/entrypoints/main.py`** - Add `freeze()` function: loads `.pt` checkpoint → reconstructs model via `get_model` + `ModelWrapper` → serializes → exports to `.pte` via `deserialize_to_file` - Wire `freeze` command in `main()` dispatcher with checkpoint directory resolution and `.pte` default suffix **`source/tests/pt_expt/test_dp_freeze.py`** (new) - `test_freeze_pte` — verify `.pte` file is created from checkpoint - `test_freeze_main_dispatcher` — test `main()` CLI dispatcher with freeze command - `test_freeze_default_suffix` — verify non-`.pte` output suffix is corrected to `.pte` **`source/tests/pt_expt/test_dp_test.py`** (new) - `test_dp_test_system` — test `dp test` with `-s` system path, verify `.e.out`, `.f.out`, `.v.out` outputs - `test_dp_test_input_json` — test `dp test` with `--valid-data` JSON input ## Test plan - [x] `python -m pytest source/tests/pt_expt/test_dp_freeze.py -v` (3 passed) - [x] `python -m pytest source/tests/pt_expt/test_dp_test.py -v` (2 passed) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added a "freeze" CLI command to convert PyTorch checkpoints into portable .pte model files, with output filename normalization and sensible default naming; multi-task head usage now emits a clear unsupported message. * **Tests** * Added unit tests for the freeze command and CLI dispatch behavior. * Added integration tests validating end-to-end dp_test workflows using frozen models. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent b2805fb commit 3ab3779

3 files changed

Lines changed: 308 additions & 0 deletions

File tree

deepmd/pt_expt/entrypoints/main.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,64 @@ def train(
160160
trainer.run()
161161

162162

163+
def freeze(
164+
model: str,
165+
output: str = "frozen_model.pte",
166+
head: str | None = None,
167+
) -> None:
168+
"""Freeze a pt_expt checkpoint into a .pte exported model.
169+
170+
Parameters
171+
----------
172+
model : str
173+
Path to the checkpoint file (.pt).
174+
output : str
175+
Path for the output .pte file.
176+
head : str or None
177+
Head to freeze in multi-task mode (not yet supported).
178+
"""
179+
import torch
180+
181+
from deepmd.pt_expt.model.get_model import (
182+
get_model,
183+
)
184+
from deepmd.pt_expt.train.wrapper import (
185+
ModelWrapper,
186+
)
187+
from deepmd.pt_expt.utils.env import (
188+
DEVICE,
189+
)
190+
from deepmd.pt_expt.utils.serialization import (
191+
deserialize_to_file,
192+
)
193+
194+
state_dict = torch.load(model, map_location=DEVICE, weights_only=True)
195+
if "model" in state_dict:
196+
state_dict = state_dict["model"]
197+
198+
extra_state = state_dict.get("_extra_state")
199+
if not isinstance(extra_state, dict) or "model_params" not in extra_state:
200+
raise ValueError(
201+
f"Unsupported checkpoint format at '{model}': missing "
202+
"'_extra_state.model_params' in model state dict."
203+
)
204+
model_params = extra_state["model_params"]
205+
206+
if head is not None and "model_dict" in model_params:
207+
raise NotImplementedError(
208+
"Multi-task freeze is not yet supported for the pt_expt backend."
209+
)
210+
211+
m = get_model(model_params)
212+
wrapper = ModelWrapper(m)
213+
wrapper.load_state_dict(state_dict)
214+
m.eval()
215+
216+
model_dict = m.serialize()
217+
deserialize_to_file(output, {"model": model_dict})
218+
log.info("Saved frozen model to %s", output)
219+
220+
163221
def main(args: list[str] | argparse.Namespace | None = None) -> None:
164222
"""Entry point for the pt_expt backend CLI.
165223
@@ -195,6 +253,28 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None:
195253
skip_neighbor_stat=FLAGS.skip_neighbor_stat,
196254
output=FLAGS.output,
197255
)
256+
elif FLAGS.command == "freeze":
257+
if Path(FLAGS.checkpoint_folder).is_dir():
258+
checkpoint_path = Path(FLAGS.checkpoint_folder)
259+
# pt_expt training saves a symlink "model.ckpt.pt" → latest ckpt
260+
default_ckpt = checkpoint_path / "model.ckpt.pt"
261+
if default_ckpt.exists():
262+
FLAGS.model = str(default_ckpt)
263+
else:
264+
raise FileNotFoundError(
265+
f"Cannot find checkpoint in '{checkpoint_path}'. "
266+
"Expected 'model.ckpt.pt' (created by pt_expt training)."
267+
)
268+
else:
269+
model_path = Path(FLAGS.checkpoint_folder)
270+
if not model_path.exists():
271+
raise FileNotFoundError(
272+
f"Checkpoint path '{model_path}' does not exist."
273+
)
274+
FLAGS.model = str(model_path)
275+
if not FLAGS.output.endswith((".pte", ".pt2")):
276+
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pte"))
277+
freeze(model=FLAGS.model, output=FLAGS.output, head=FLAGS.head)
198278
else:
199279
raise RuntimeError(
200280
f"Unsupported command '{FLAGS.command}' for the pt_expt backend."
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import argparse
3+
import os
4+
import shutil
5+
import tempfile
6+
import unittest
7+
from copy import (
8+
deepcopy,
9+
)
10+
11+
import torch
12+
13+
from deepmd.pt_expt.entrypoints.main import (
14+
freeze,
15+
main,
16+
)
17+
from deepmd.pt_expt.model.get_model import (
18+
get_model,
19+
)
20+
from deepmd.pt_expt.train.wrapper import (
21+
ModelWrapper,
22+
)
23+
24+
model_se_e2_a = {
25+
"type_map": ["O", "H", "B"],
26+
"descriptor": {
27+
"type": "se_e2_a",
28+
"sel": [46, 92, 4],
29+
"rcut_smth": 0.50,
30+
"rcut": 4.00,
31+
"neuron": [25, 50, 100],
32+
"resnet_dt": False,
33+
"axis_neuron": 16,
34+
"seed": 1,
35+
},
36+
"fitting_net": {
37+
"neuron": [24, 24, 24],
38+
"resnet_dt": True,
39+
"seed": 1,
40+
},
41+
"data_stat_nbatch": 20,
42+
}
43+
44+
45+
class TestDPFreezePtExpt(unittest.TestCase):
46+
"""Test dp freeze for the pt_expt backend."""
47+
48+
@classmethod
49+
def setUpClass(cls) -> None:
50+
cls.tmpdir = tempfile.mkdtemp()
51+
52+
# Build a model and save a fake checkpoint
53+
model_params = deepcopy(model_se_e2_a)
54+
model = get_model(model_params)
55+
wrapper = ModelWrapper(model, model_params=model_params)
56+
state_dict = wrapper.state_dict()
57+
cls.ckpt_file = os.path.join(cls.tmpdir, "model.pt")
58+
torch.save({"model": state_dict}, cls.ckpt_file)
59+
60+
@classmethod
61+
def tearDownClass(cls) -> None:
62+
shutil.rmtree(cls.tmpdir)
63+
64+
def test_freeze_pte(self) -> None:
65+
"""Freeze to .pte and verify the file is created."""
66+
output = os.path.join(self.tmpdir, "frozen_model.pte")
67+
freeze(model=self.ckpt_file, output=output)
68+
self.assertTrue(os.path.exists(output))
69+
70+
def test_freeze_main_dispatcher(self) -> None:
71+
"""Test main() CLI dispatcher with freeze command."""
72+
output_file = os.path.join(self.tmpdir, "frozen_via_main.pte")
73+
flags = argparse.Namespace(
74+
command="freeze",
75+
checkpoint_folder=self.ckpt_file,
76+
output=output_file,
77+
head=None,
78+
log_level=2, # WARNING
79+
log_path=None,
80+
)
81+
main(flags)
82+
self.assertTrue(os.path.exists(output_file))
83+
84+
def test_freeze_default_suffix(self) -> None:
85+
"""Test that main() defaults output suffix to .pte."""
86+
output_file = os.path.join(self.tmpdir, "frozen_default_suffix.pth")
87+
flags = argparse.Namespace(
88+
command="freeze",
89+
checkpoint_folder=self.ckpt_file,
90+
output=output_file,
91+
head=None,
92+
log_level=2, # WARNING
93+
log_path=None,
94+
)
95+
main(flags)
96+
expected = os.path.join(self.tmpdir, "frozen_default_suffix.pte")
97+
self.assertTrue(os.path.exists(expected))
98+
99+
100+
if __name__ == "__main__":
101+
unittest.main()
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import json
3+
import os
4+
import shutil
5+
import tempfile
6+
import unittest
7+
from copy import (
8+
deepcopy,
9+
)
10+
from pathlib import (
11+
Path,
12+
)
13+
14+
import torch
15+
16+
from deepmd.entrypoints.test import test as dp_test
17+
from deepmd.pt_expt.entrypoints.main import (
18+
freeze,
19+
)
20+
from deepmd.pt_expt.model.get_model import (
21+
get_model,
22+
)
23+
from deepmd.pt_expt.train.wrapper import (
24+
ModelWrapper,
25+
)
26+
27+
model_se_e2_a = {
28+
"type_map": ["O", "H", "B"],
29+
"descriptor": {
30+
"type": "se_e2_a",
31+
"sel": [46, 92, 4],
32+
"rcut_smth": 0.50,
33+
"rcut": 4.00,
34+
"neuron": [25, 50, 100],
35+
"resnet_dt": False,
36+
"axis_neuron": 16,
37+
"seed": 1,
38+
},
39+
"fitting_net": {
40+
"neuron": [24, 24, 24],
41+
"resnet_dt": True,
42+
"seed": 1,
43+
},
44+
"data_stat_nbatch": 20,
45+
}
46+
47+
48+
class TestDPTestPtExpt(unittest.TestCase):
49+
"""Test dp test for the pt_expt backend (.pte models)."""
50+
51+
@classmethod
52+
def setUpClass(cls) -> None:
53+
cls.data_file = str(
54+
Path(__file__).parents[1] / "pt" / "water" / "data" / "single"
55+
)
56+
cls.detail_file = os.path.join(
57+
tempfile.mkdtemp(), "test_dp_test_pt_expt_detail"
58+
)
59+
cls.tmpdir = tempfile.mkdtemp()
60+
61+
# Build a model, save a checkpoint, and freeze to .pte
62+
model_params = deepcopy(model_se_e2_a)
63+
model = get_model(model_params)
64+
wrapper = ModelWrapper(model, model_params=model_params)
65+
state_dict = wrapper.state_dict()
66+
ckpt_file = os.path.join(cls.tmpdir, "model.pt")
67+
torch.save({"model": state_dict}, ckpt_file)
68+
69+
cls.pte_file = os.path.join(cls.tmpdir, "frozen_model.pte")
70+
freeze(model=ckpt_file, output=cls.pte_file)
71+
72+
@classmethod
73+
def tearDownClass(cls) -> None:
74+
shutil.rmtree(cls.tmpdir)
75+
detail_dir = os.path.dirname(cls.detail_file)
76+
if os.path.exists(detail_dir):
77+
shutil.rmtree(detail_dir)
78+
79+
def test_dp_test_system(self) -> None:
80+
"""Test dp test with -s system path."""
81+
detail = self.detail_file + "_sys"
82+
dp_test(
83+
model=self.pte_file,
84+
system=self.data_file,
85+
datafile=None,
86+
set_prefix="set",
87+
numb_test=0,
88+
rand_seed=None,
89+
shuffle_test=False,
90+
detail_file=detail,
91+
atomic=False,
92+
)
93+
self.assertTrue(os.path.exists(detail + ".e.out"))
94+
self.assertTrue(os.path.exists(detail + ".f.out"))
95+
self.assertTrue(os.path.exists(detail + ".v.out"))
96+
97+
def test_dp_test_input_json(self) -> None:
98+
"""Test dp test with --valid-data JSON input."""
99+
config = {
100+
"model": deepcopy(model_se_e2_a),
101+
"training": {
102+
"training_data": {"systems": [self.data_file]},
103+
"validation_data": {"systems": [self.data_file]},
104+
},
105+
}
106+
input_json = os.path.join(self.tmpdir, "test_input.json")
107+
with open(input_json, "w") as fp:
108+
json.dump(config, fp, indent=4)
109+
110+
detail = self.detail_file + "_json"
111+
dp_test(
112+
model=self.pte_file,
113+
system=None,
114+
datafile=None,
115+
valid_json=input_json,
116+
set_prefix="set",
117+
numb_test=0,
118+
rand_seed=None,
119+
shuffle_test=False,
120+
detail_file=detail,
121+
atomic=False,
122+
)
123+
self.assertTrue(os.path.exists(detail + ".e.out"))
124+
125+
126+
if __name__ == "__main__":
127+
unittest.main()

0 commit comments

Comments
 (0)