@@ -160,6 +160,54 @@ def train(
160160 trainer .run ()
161161
162162
163+ def freeze (
164+ model : str ,
165+ output : str = "frozen_model.pt2" ,
166+ head : str | None = None ,
167+ ) -> None :
168+ """Freeze a pt_expt training checkpoint to .pte or .pt2 format.
169+
170+ Parameters
171+ ----------
172+ model : str
173+ Path to the training checkpoint (.pt file).
174+ output : str
175+ Path for the frozen model output (.pte or .pt2).
176+ head : str or None
177+ Head to freeze in a multi-task model (not yet supported).
178+ """
179+ import torch
180+
181+ from deepmd .pt_expt .model import (
182+ get_model ,
183+ )
184+ from deepmd .pt_expt .train .wrapper import (
185+ ModelWrapper ,
186+ )
187+ from deepmd .pt_expt .utils .env import (
188+ DEVICE ,
189+ )
190+ from deepmd .pt_expt .utils .serialization import (
191+ deserialize_to_file ,
192+ )
193+
194+ state_dict = torch .load (model , map_location = DEVICE , weights_only = True )
195+ if "model" in state_dict :
196+ state_dict = state_dict ["model" ]
197+ model_params = state_dict ["_extra_state" ]["model_params" ]
198+
199+ # Reconstruct model and load weights
200+ pt_expt_model = get_model (model_params ).to (DEVICE )
201+ wrapper = ModelWrapper (pt_expt_model )
202+ wrapper .load_state_dict (state_dict )
203+ pt_expt_model .eval ()
204+
205+ # Serialize to dict and export
206+ model_dict = pt_expt_model .serialize ()
207+ deserialize_to_file (output , {"model" : model_dict })
208+ log .info (f"Saved frozen model to { output } " )
209+
210+
163211def main (args : list [str ] | argparse .Namespace | None = None ) -> None :
164212 """Entry point for the pt_expt backend CLI.
165213
@@ -195,6 +243,18 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None:
195243 skip_neighbor_stat = FLAGS .skip_neighbor_stat ,
196244 output = FLAGS .output ,
197245 )
246+ elif FLAGS .command == "freeze" :
247+ if Path (FLAGS .checkpoint_folder ).is_dir ():
248+ checkpoint_path = Path (FLAGS .checkpoint_folder )
249+ latest_ckpt_file = (checkpoint_path / "checkpoint" ).read_text ()
250+ FLAGS .model = str (checkpoint_path .joinpath (latest_ckpt_file ))
251+ else :
252+ FLAGS .model = FLAGS .checkpoint_folder
253+ # Default to .pt2; user can specify .pte via -o flag
254+ suffix = Path (FLAGS .output ).suffix
255+ if suffix not in (".pte" , ".pt2" ):
256+ FLAGS .output = str (Path (FLAGS .output ).with_suffix (".pt2" ))
257+ freeze (model = FLAGS .model , output = FLAGS .output , head = FLAGS .head )
198258 else :
199259 raise RuntimeError (
200260 f"Unsupported command '{ FLAGS .command } ' for the pt_expt backend."
0 commit comments