Skip to content

Commit 6145fe8

Browse files
author
Han Wang
committed
feat(pt_expt): add dp freeze support for pt_expt backend (.pte/.pt2)
Add freeze() function that loads a training checkpoint, reconstructs the model, serializes it, and exports via the existing deserialize_to_file pipeline. Wire the freeze command in the main() CLI dispatcher with checkpoint-dir resolution and automatic .pt2 suffix defaulting.
1 parent ac78e07 commit 6145fe8

2 files changed

Lines changed: 411 additions & 0 deletions

File tree

deepmd/pt_expt/entrypoints/main.py

Lines changed: 60 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -160,6 +160,54 @@ def train(
160160
trainer.run()
161161

162162

163+
def freeze(
    model: str,
    output: str = "frozen_model.pt2",
    head: str | None = None,
) -> None:
    """Freeze a pt_expt training checkpoint to .pte or .pt2 format.

    The checkpoint is loaded, the model is rebuilt from the parameters
    stored alongside the weights, the weights are restored, and the model
    is serialized and handed to ``deserialize_to_file`` for export.

    Parameters
    ----------
    model : str
        Path to the training checkpoint (.pt file).
    output : str
        Path for the frozen model output (.pte or .pt2).
    head : str or None
        Head to freeze in a multi-task model (not yet supported). A
        non-``None`` value is ignored and a warning is logged.
    """
    import torch

    from deepmd.pt_expt.model import (
        get_model,
    )
    from deepmd.pt_expt.train.wrapper import (
        ModelWrapper,
    )
    from deepmd.pt_expt.utils.env import (
        DEVICE,
    )
    from deepmd.pt_expt.utils.serialization import (
        deserialize_to_file,
    )

    # Head selection is not implemented; tell the user rather than silently
    # dropping the request.
    if head is not None:
        log.warning(
            "Multi-task head selection is not yet supported by the pt_expt "
            "freeze command; ignoring head=%r.",
            head,
        )

    state_dict = torch.load(model, map_location=DEVICE, weights_only=True)
    # The wrapper state may be nested under a top-level "model" key.
    if "model" in state_dict:
        state_dict = state_dict["model"]
    # The model definition is read from the wrapper's extra state.
    model_params = state_dict["_extra_state"]["model_params"]

    # Reconstruct model and load weights
    pt_expt_model = get_model(model_params).to(DEVICE)
    wrapper = ModelWrapper(pt_expt_model)
    wrapper.load_state_dict(state_dict)
    pt_expt_model.eval()

    # Serialize to dict and export
    model_dict = pt_expt_model.serialize()
    deserialize_to_file(output, {"model": model_dict})
    log.info(f"Saved frozen model to {output}")
209+
210+
163211
def main(args: list[str] | argparse.Namespace | None = None) -> None:
164212
"""Entry point for the pt_expt backend CLI.
165213
@@ -195,6 +243,18 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None:
195243
skip_neighbor_stat=FLAGS.skip_neighbor_stat,
196244
output=FLAGS.output,
197245
)
246+
elif FLAGS.command == "freeze":
247+
if Path(FLAGS.checkpoint_folder).is_dir():
248+
checkpoint_path = Path(FLAGS.checkpoint_folder)
249+
latest_ckpt_file = (checkpoint_path / "checkpoint").read_text()
250+
FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file))
251+
else:
252+
FLAGS.model = FLAGS.checkpoint_folder
253+
# Default to .pt2; user can specify .pte via -o flag
254+
suffix = Path(FLAGS.output).suffix
255+
if suffix not in (".pte", ".pt2"):
256+
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pt2"))
257+
freeze(model=FLAGS.model, output=FLAGS.output, head=FLAGS.head)
198258
else:
199259
raise RuntimeError(
200260
f"Unsupported command '{FLAGS.command}' for the pt_expt backend."

0 commit comments

Comments
 (0)