Skip to content

Commit 80fc539

Browse files
wanghan-iapcmHan Wang
andauthored
feat(pt_expt): add DeepSpin support for pt_expt backend (#5370)
## Summary - Add DeepSpin (spin model) support for the pt_expt (.pt2) backend: C++ API, Python inference, LAMMPS pair style, and model deviation - Fix `add_mag` bug in dpmodel `SpinModel`: atomic virial used `add_mag=False` while PT backend used `add_mag=True`, causing inconsistent per-atom virial between .pth and .pt2 - Add `"type": "spin_ener"` to spin model serialization for correct deserialization dispatch across all backends - Replace committed `.pth` model file with canonical `.yaml` (dpmodel serialization); `.pth` and `.pt2` are now regenerated from `.yaml` via `convert_backend` at test time ## Test plan - [x] C++ spin tests: 38 passed (12 TF-only skipped) - [x] LAMMPS .pt2 spin tests: 7 passed (PBC + NoPBC + model_devi) - [x] LAMMPS .pth spin tests: 3 passed (single-model; model_devi needs TF .pb) - [x] .pth vs .pt2 agreement: all quantities diff ~1e-15 - [ ] CI: full test suite <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Spin-capable models: public queries for model spin presence and per-type spin usage; PyTorch-exportable spin backend and new export/run backend for spin models; serialized spin model discriminator added. * Output pipeline extended to include magnetic derivatives, magnetic virial counterparts, and per-atom magnetic masks. * **Tests** * Added extensive unit and regression tests, new test suites and generators for spin models, updated golden reference data, and model-generation scripts. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent d7ee913 commit 80fc539

39 files changed

+13647
-841
lines changed

deepmd/dpmodel/infer/deep_eval.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,16 @@ def get_has_efield(self) -> bool:
166166
"""Check if the model has efield."""
167167
return False
168168

169+
def get_has_spin(self) -> bool:
170+
"""Check if the model has spin atom types."""
171+
return hasattr(self.dp, "spin")
172+
173+
def get_use_spin(self) -> list[bool]:
174+
"""Get the per-type spin usage of this model."""
175+
if hasattr(self.dp, "spin"):
176+
return self.dp.spin.use_spin.tolist()
177+
return []
178+
169179
def get_ntypes_spin(self) -> int:
170180
"""Get the number of spin atom types of this model."""
171181
return 0

deepmd/dpmodel/model/base_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,14 @@ def deserialize(cls, data: dict) -> "BaseBaseModel":
127127
model_type = data.get("type", "standard")
128128
if model_type == "standard":
129129
model_type = data.get("fitting", {}).get("type", "ener")
130+
if model_type == "spin_ener":
131+
# SpinModel is not a BaseModel subclass and cannot be
132+
# registered via the plugin registry. Dispatch directly.
133+
from deepmd.dpmodel.model.spin_model import (
134+
SpinModel,
135+
)
136+
137+
return SpinModel.deserialize(data)
130138
return cls.get_class_by_type(model_type).deserialize(data)
131139
raise NotImplementedError(f"Not implemented in class {cls.__name__}")
132140

deepmd/dpmodel/model/spin_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,12 +549,15 @@ def __getattr__(self, name: str) -> Any:
549549

550550
def serialize(self) -> dict:
551551
return {
552+
"type": "spin_ener",
552553
"backbone_model": self.backbone_model.serialize(),
553554
"spin": self.spin.serialize(),
554555
}
555556

556557
@classmethod
557558
def deserialize(cls, data: dict) -> "SpinModel":
559+
data = data.copy()
560+
data.pop("type", None)
558561
backbone_model_obj = make_model(
559562
DPAtomicModel, T_Bases=(NativeOP, BaseModel)
560563
).deserialize(data["backbone_model"])
@@ -646,7 +649,7 @@ def call_common(
646649
) = self.process_spin_output(
647650
atype,
648651
model_ret[f"{var_name}_derv_c"],
649-
add_mag=False,
652+
add_mag=True,
650653
virtual_scale=False,
651654
)
652655
# Always compute mask_mag from atom types (even when forces are unavailable)
@@ -823,7 +826,7 @@ def call_common_lower(
823826
extended_atype,
824827
model_ret[f"{var_name}_derv_c"],
825828
nloc,
826-
add_mag=False,
829+
add_mag=True,
827830
virtual_scale=False,
828831
)
829832
# Always compute mask_mag from atom types (even when forces are unavailable)

deepmd/dpmodel/model/transform_output.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
ModelOutputDef,
1717
OutputVariableDef,
1818
get_deriv_name,
19+
get_deriv_name_mag,
1920
get_hessian_name,
2021
get_reduce_name,
2122
)
@@ -128,6 +129,21 @@ def communicate_extended_output(
128129
model_ret[kk_derv_r],
129130
)
130131
new_ret[kk_derv_r] = force
132+
if vdef.magnetic:
133+
kk_derv_r_mag = get_deriv_name_mag(kk)[0]
134+
if model_ret.get(kk_derv_r_mag) is not None:
135+
force_mag = xp.zeros(
136+
vldims + derv_r_ext_dims,
137+
dtype=vv.dtype,
138+
device=device,
139+
)
140+
force_mag = xp_scatter_sum(
141+
force_mag,
142+
1,
143+
mapping,
144+
model_ret[kk_derv_r_mag],
145+
)
146+
new_ret[kk_derv_r_mag] = force_mag
131147
else:
132148
# name holders
133149
new_ret[kk_derv_r] = None
@@ -235,10 +251,29 @@ def communicate_extended_output(
235251
)
236252
new_ret[kk_derv_c] = virial
237253
new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1)
254+
if vdef.magnetic:
255+
kk_derv_c_mag = get_deriv_name_mag(kk)[1]
256+
if model_ret.get(kk_derv_c_mag) is not None:
257+
virial_mag = xp.zeros(
258+
vldims + derv_c_ext_dims,
259+
dtype=vv.dtype,
260+
device=device,
261+
)
262+
virial_mag = xp_scatter_sum(
263+
virial_mag,
264+
1,
265+
mapping,
266+
model_ret[kk_derv_c_mag],
267+
)
268+
new_ret[kk_derv_c_mag] = virial_mag
238269
else:
239270
new_ret[kk_derv_c] = None
240271
new_ret[kk_derv_c + "_redu"] = None
241272
if not do_atomic_virial:
242273
# pop atomic virial, because it is not correctly calculated.
243274
new_ret.pop(kk_derv_c)
275+
# Slice mask_mag from extended to local atoms
276+
if "mask_mag" in model_ret:
277+
nloc = new_ret[next(iter(model_output_def.keys_outp()))].shape[1]
278+
new_ret["mask_mag"] = model_ret["mask_mag"][:, :nloc]
244279
return new_ret

deepmd/infer/deep_eval.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,17 @@ def get_has_spin(self) -> bool:
323323
"""Check if the model has spin atom types."""
324324
return False
325325

326+
def get_use_spin(self) -> list[bool]:
327+
"""Get the per-type spin usage of this model.
328+
329+
Returns
330+
-------
331+
list[bool]
332+
A list of bool indicating whether each atom type uses spin.
333+
Empty list if the model does not have spin.
334+
"""
335+
return []
336+
326337
def get_has_hessian(self) -> bool:
327338
"""Check if the model has hessian."""
328339
return False
@@ -705,6 +716,18 @@ def has_spin(self) -> bool:
705716
"""Check if the model has spin."""
706717
return self.deep_eval.get_has_spin()
707718

719+
@property
720+
def use_spin(self) -> list[bool]:
721+
"""Get the per-type spin usage of this model.
722+
723+
Returns
724+
-------
725+
list[bool]
726+
A list of bool indicating whether each atom type uses spin.
727+
Empty list if the model does not have spin.
728+
"""
729+
return self.deep_eval.get_use_spin()
730+
708731
@property
709732
def has_hessian(self) -> bool:
710733
"""Check if the model has hessian."""

deepmd/pd/infer/deep_eval.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,13 @@ def get_has_spin(self) -> bool:
297297
"""Check if the model has spin atom types."""
298298
return self._has_spin
299299

300+
def get_use_spin(self) -> list[bool]:
301+
"""Get the per-type spin usage of this model."""
302+
if self._has_spin:
303+
model = self.dp.model["Default"]
304+
return model.spin.use_spin.tolist()
305+
return []
306+
300307
def get_has_hessian(self) -> bool:
301308
"""Check if the model has hessian."""
302309
return self._has_hessian

deepmd/pt/infer/deep_eval.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,13 @@ def get_has_spin(self) -> bool:
314314
"""Check if the model has spin atom types."""
315315
return self._has_spin
316316

317+
def get_use_spin(self) -> list[bool]:
318+
"""Get the per-type spin usage of this model."""
319+
if self._has_spin:
320+
model = self.dp.model["Default"]
321+
return model.spin.use_spin.tolist()
322+
return []
323+
317324
def get_has_hessian(self) -> bool:
318325
"""Check if the model has hessian."""
319326
return self._has_hessian

deepmd/pt/model/model/spin_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,12 +637,15 @@ def forward_common_lower(
637637

638638
def serialize(self) -> dict:
639639
return {
640+
"type": "spin_ener",
640641
"backbone_model": self.backbone_model.serialize(),
641642
"spin": self.spin.serialize(),
642643
}
643644

644645
@classmethod
645646
def deserialize(cls, data: dict[str, Any]) -> "SpinModel":
647+
data = data.copy()
648+
data.pop("type", None)
646649
backbone_model_obj = make_model(DPAtomicModel).deserialize(
647650
data["backbone_model"]
648651
)

deepmd/pt/utils/serialization.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,15 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
7272
"""
7373
if not model_file.endswith(".pth"):
7474
raise ValueError("PyTorch backend only supports converting .pth file")
75-
model = BaseModel.deserialize(data["model"])
75+
model_data = data["model"]
76+
if model_data.get("type") == "spin_ener":
77+
from deepmd.pt.model.model.spin_model import (
78+
SpinEnergyModel,
79+
)
80+
81+
model = SpinEnergyModel.deserialize(model_data)
82+
else:
83+
model = BaseModel.deserialize(model_data)
7684
# JIT will happy in this way...
7785
model.model_def_script = json.dumps(data["model_def_script"])
7886
if "min_nbor_dist" in data.get("@variables", {}):

0 commit comments

Comments
 (0)