Skip to content

Commit 98aee78

Browse files
author
Han Wang
committed
feat(pt_expt): support .pt training checkpoints in DeepEval
`dp --pt-expt test -m foo.pt` previously rejected `.pt` files (only `.pt2` / `.pte` were supported), and `dp --pt test -m foo.pt` on a pt_expt-trained checkpoint silently loaded random weights because the state-dict layout (dpmodel `.w`/`.b` keys) doesn't match the legacy pt backend's expectations. - `Backend.detect_backend_by_model` sniffs `.pt` content so files with `.w`/`.b` keys (pt_expt) route to the pt_expt DeepEval and files with `.matrix`/`.bias` keys (pt) keep routing to pt. - `pt_expt.DeepEval._load_pt` reconstructs the model from `_extra_state["model_params"]`, loads the state-dict via `ModelWrapper`, and exposes an eager `forward_common_lower` runner with the same signature as the AOTI/exported module so the existing `eval()` path is unchanged. Spin-aware and non-spin variants; multi-task `.pt` selects a head and remaps keys. - `pt_expt.get_model` learns `get_spin_model` (mirrors dpmodel) so spin checkpoints can be reconstructed from `model_params`. - Tests cover dispatch sniffing, single-task / multi-task / spin / spin-multi-task `.pt` parity vs eager forward, fparam / aparam, and `.pt` vs `.pte` cross-format consistency at 1e-10.
1 parent d14233e commit 98aee78

4 files changed

Lines changed: 1016 additions & 3 deletions

File tree

deepmd/backend/backend.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,33 @@ def detect_backend_by_model(filename: str) -> type["Backend"]:
101101
filename : str
102102
The model file name
103103
"""
104-
filename = str(filename).lower()
104+
filename_lower = str(filename).lower()
105+
# `.pt` is shared between the pt and pt_expt backends. They use
106+
# different parameter naming (pt: `.matrix`/`.bias`, pt_expt:
107+
# `.w`/`.b`), so peek at the state-dict keys to disambiguate.
108+
if filename_lower.endswith(".pt"):
109+
try:
110+
import torch
111+
112+
sd = torch.load(filename, map_location="cpu", weights_only=False)
113+
if isinstance(sd, dict) and "model" in sd:
114+
sd = sd["model"]
115+
keys = list(sd.keys()) if hasattr(sd, "keys") else []
116+
has_pt_expt = any(k.endswith(".w") or k.endswith(".b") for k in keys)
117+
has_pt = any(k.endswith(".matrix") or k.endswith(".bias") for k in keys)
118+
if has_pt_expt and not has_pt:
119+
target_name = "pt-expt"
120+
else:
121+
target_name = "pt"
122+
for key, backend in Backend.get_backends().items():
123+
if key == target_name:
124+
return backend
125+
except Exception:
126+
# Fall through to suffix matching if sniffing fails.
127+
pass
105128
for backend in Backend.get_backends().values():
106129
for suffix in backend.suffixes:
107-
if filename.endswith(suffix):
130+
if filename_lower.endswith(suffix):
108131
return backend
109132
raise ValueError(f"Cannot detect the backend of the model file {filename}.")
110133

deepmd/pt_expt/infer/deep_eval.py

Lines changed: 163 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,16 @@ def __init__(
9999

100100
if self._is_pt2:
101101
self._load_pt2(model_file)
102-
else:
102+
elif model_file.endswith(".pte"):
103103
self._load_pte(model_file)
104+
elif model_file.endswith(".pt"):
105+
self._load_pt(model_file, head=kwargs.get("head"))
106+
else:
107+
raise ValueError(
108+
f"Unsupported model file '{model_file}' for the pt_expt "
109+
"backend: expected `.pt2` / `.pte` (deployable archives) or "
110+
"`.pt` (training checkpoint)."
111+
)
104112

105113
if isinstance(auto_batch_size, bool):
106114
if auto_batch_size:
@@ -206,6 +214,160 @@ def _load_pt2(self, model_file: str) -> None:
206214
self._pt2_runner = aoti_load_package(model_file)
207215
self.exported_module = None
208216

217+
def _load_pt(self, model_file: str, head: str | None = None) -> None:
218+
"""Load a `.pt` training checkpoint (eager mode, no torch.export)."""
219+
from copy import (
220+
deepcopy,
221+
)
222+
223+
from deepmd.pt.utils.env import (
224+
DEVICE,
225+
)
226+
from deepmd.pt_expt.model import (
227+
get_model,
228+
)
229+
230+
state_dict = torch.load(model_file, map_location=DEVICE, weights_only=False)
231+
if "model" in state_dict:
232+
state_dict = state_dict["model"]
233+
model_params = deepcopy(state_dict["_extra_state"]["model_params"])
234+
235+
if "model_dict" in model_params:
236+
# Multi-task: pick the requested head (defaults to "Default" if present).
237+
heads = list(model_params["model_dict"].keys())
238+
if head is None:
239+
if "Default" in heads:
240+
head = "Default"
241+
else:
242+
raise ValueError(
243+
f"Multi-task checkpoint '{model_file}' has heads "
244+
f"{heads}; pass --head to select one."
245+
)
246+
if head not in heads:
247+
raise ValueError(
248+
f"Head '{head}' not found in checkpoint '{model_file}'. "
249+
f"Available heads: {heads}."
250+
)
251+
head_params = model_params["model_dict"][head]
252+
# Restrict state_dict to the chosen head and rename to "Default".
253+
head_state = {"_extra_state": state_dict["_extra_state"]}
254+
for key, value in state_dict.items():
255+
prefix = f"model.{head}."
256+
if key.startswith(prefix):
257+
head_state[key.replace(prefix, "model.Default.")] = (
258+
value.clone() if torch.is_tensor(value) else value
259+
)
260+
state_dict = head_state
261+
model_params = head_params
262+
263+
model = get_model(deepcopy(model_params)).to(DEVICE)
264+
265+
# Load weights into a {"Default": model} wrapper to match the
266+
# `model.Default.*` key prefix used in the saved state_dict.
267+
from deepmd.pt_expt.train.wrapper import (
268+
ModelWrapper,
269+
)
270+
271+
wrapper = ModelWrapper(model)
272+
wrapper.load_state_dict(state_dict)
273+
model = wrapper.model["Default"].eval()
274+
275+
self._dpmodel = model
276+
self._is_spin = (
277+
model_params.get("type") == "spin_ener" or "spin" in model_params
278+
)
279+
self.rcut = model.get_rcut()
280+
self.type_map = model.get_type_map()
281+
if self._is_spin:
282+
self._model_output_def = ModelOutputDef(
283+
FittingOutputDef(
284+
[
285+
OutputVariableDef(
286+
"energy",
287+
shape=[1],
288+
reducible=True,
289+
r_differentiable=True,
290+
c_differentiable=True,
291+
atomic=True,
292+
magnetic=True,
293+
)
294+
]
295+
)
296+
)
297+
else:
298+
self._model_output_def = ModelOutputDef(model.atomic_output_def())
299+
self._model_def_script = model_params
300+
# Populate metadata so eval helpers (e.g. default_fparam fallback)
301+
# behave the same as the .pt2/.pte path. Mirrors the fields that
302+
# `_collect_metadata` writes into metadata.json.
303+
self.metadata = {
304+
"type_map": model.get_type_map(),
305+
"rcut": model.get_rcut(),
306+
"sel": model.get_sel(),
307+
"dim_fparam": model.get_dim_fparam(),
308+
"dim_aparam": model.get_dim_aparam(),
309+
"mixed_types": model.mixed_types(),
310+
"has_default_fparam": model.has_default_fparam(),
311+
"default_fparam": model.get_default_fparam(),
312+
"is_spin": self._is_spin,
313+
}
314+
if self._is_spin:
315+
self.metadata["ntypes_spin"] = model.spin.get_ntypes_spin()
316+
self.metadata["use_spin"] = [bool(v) for v in model.spin.use_spin]
317+
318+
# Eager runner with the same signature as the .pt2/.pte exported module.
319+
# Use forward_common_lower (not forward_lower) to match the export-time
320+
# output keys ("energy", "energy_redu", "energy_derv_r", ...) that
321+
# communicate_extended_output downstream consumes.
322+
# Non-spin: (ext_coord, ext_atype, nlist, mapping, fparam, aparam)
323+
# Spin: (ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam)
324+
if self._is_spin:
325+
326+
def _eager_runner_spin(
327+
ext_coord: torch.Tensor,
328+
ext_atype: torch.Tensor,
329+
ext_spin: torch.Tensor,
330+
nlist: torch.Tensor,
331+
mapping: torch.Tensor | None,
332+
fparam: torch.Tensor | None,
333+
aparam: torch.Tensor | None,
334+
) -> dict[str, torch.Tensor]:
335+
ext_coord = ext_coord.detach().requires_grad_(True)
336+
return model.forward_common_lower(
337+
ext_coord,
338+
ext_atype,
339+
ext_spin,
340+
nlist,
341+
mapping,
342+
fparam=fparam,
343+
aparam=aparam,
344+
do_atomic_virial=True,
345+
)
346+
347+
self.exported_module = _eager_runner_spin
348+
else:
349+
350+
def _eager_runner(
351+
ext_coord: torch.Tensor,
352+
ext_atype: torch.Tensor,
353+
nlist: torch.Tensor,
354+
mapping: torch.Tensor | None,
355+
fparam: torch.Tensor | None,
356+
aparam: torch.Tensor | None,
357+
) -> dict[str, torch.Tensor]:
358+
ext_coord = ext_coord.detach().requires_grad_(True)
359+
return model.forward_common_lower(
360+
ext_coord,
361+
ext_atype,
362+
nlist,
363+
mapping,
364+
fparam=fparam,
365+
aparam=aparam,
366+
do_atomic_virial=True,
367+
)
368+
369+
self.exported_module = _eager_runner
370+
209371
def get_rcut(self) -> float:
210372
"""Get the cutoff radius of this model."""
211373
return self.rcut

deepmd/pt_expt/model/get_model.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@
3737
from deepmd.pt_expt.model.property_model import (
3838
PropertyModel,
3939
)
40+
from deepmd.pt_expt.model.spin_ener_model import (
41+
SpinEnergyModel,
42+
)
43+
from deepmd.utils.spin import (
44+
Spin,
45+
)
4046

4147

4248
def _get_standard_model_components(
@@ -162,6 +168,36 @@ def get_linear_model(model_params: dict) -> BaseModel:
162168
)
163169

164170

171+
def get_spin_model(data: dict) -> SpinEnergyModel:
172+
"""Build a pt_expt spin energy model from a config dictionary.
173+
174+
Mirrors :func:`deepmd.dpmodel.model.model.get_spin_model`: expands the
175+
type map and descriptor sel for virtual spin atoms, then wraps the
176+
backbone EnergyModel as a :class:`SpinEnergyModel`.
177+
"""
178+
data = copy.deepcopy(data)
179+
data["type_map"] += [item + "_spin" for item in data["type_map"]]
180+
spin = Spin(
181+
use_spin=data["spin"]["use_spin"],
182+
virtual_scale=data["spin"]["virtual_scale"],
183+
)
184+
pair_exclude_types = spin.get_pair_exclude_types(
185+
exclude_types=data.get("pair_exclude_types", None)
186+
)
187+
data["pair_exclude_types"] = pair_exclude_types
188+
data["descriptor"]["exclude_types"] = pair_exclude_types
189+
atom_exclude_types = spin.get_atom_exclude_types(
190+
exclude_types=data.get("atom_exclude_types", None)
191+
)
192+
data["atom_exclude_types"] = atom_exclude_types
193+
if "env_protection" not in data["descriptor"]:
194+
data["descriptor"]["env_protection"] = 1e-6
195+
if data["descriptor"]["type"] in ["se_e2_a"]:
196+
data["descriptor"]["sel"] += data["descriptor"]["sel"]
197+
backbone_model = get_standard_model(data)
198+
return SpinEnergyModel(backbone_model=backbone_model, spin=spin)
199+
200+
165201
def get_model(data: dict) -> BaseModel:
166202
"""Get a model from a config dictionary.
167203
@@ -172,6 +208,8 @@ def get_model(data: dict) -> BaseModel:
172208
"""
173209
model_type = data.get("type", "standard")
174210
if model_type == "standard":
211+
if "spin" in data:
212+
return get_spin_model(data)
175213
return get_standard_model(data)
176214
elif model_type == "linear_ener":
177215
return get_linear_model(data)

0 commit comments

Comments
 (0)