Skip to content

Commit 7e9f730

Browse files
committed
fix(pt-expt): align spin metadata output definitions
1 parent add16fa commit 7e9f730

2 files changed

Lines changed: 13 additions & 16 deletions

File tree

deepmd/pt_expt/infer/deep_eval.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -143,21 +143,12 @@ def _init_from_model_json(self, model_json_str: str) -> None:
143143
self._sel = list(self._dpmodel.get_sel())
144144
self._mixed_types = bool(self._dpmodel.mixed_types())
145145
if self._is_spin:
146-
self._model_output_def = ModelOutputDef(
147-
FittingOutputDef(
148-
[
149-
OutputVariableDef(
150-
"energy",
151-
shape=[1],
152-
reducible=True,
153-
r_differentiable=True,
154-
c_differentiable=True,
155-
atomic=True,
156-
magnetic=True,
157-
)
158-
]
159-
)
160-
)
146+
spin_fitting_defs = self._dpmodel.model_output_def().def_outp.get_data()
147+
# Keep only physical fitting outputs; mask is derived by ModelOutputDef.
148+
fitting_defs = [
149+
vdef for name, vdef in spin_fitting_defs.items() if name != "mask"
150+
]
151+
self._model_output_def = ModelOutputDef(FittingOutputDef(fitting_defs))
161152
else:
162153
self._model_output_def = ModelOutputDef(self._dpmodel.atomic_output_def())
163154

deepmd/pt_expt/utils/serialization.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,15 @@ def _collect_metadata(model: torch.nn.Module, is_spin: bool = False) -> dict:
255255
The ``fitting_output_defs`` list is also included so that
256256
``ModelOutputDef`` can be reconstructed without loading the full model.
257257
"""
258-
fitting_output_def = model.atomic_output_def()
258+
if is_spin:
259+
fitting_output_def = model.model_output_def().def_outp
260+
else:
261+
fitting_output_def = model.atomic_output_def()
259262
fitting_output_defs = []
260263
for vdef in fitting_output_def.get_data().values():
264+
# Keep metadata aligned with physical fitting outputs only.
265+
if is_spin and vdef.name == "mask":
266+
continue
261267
fitting_output_defs.append(
262268
{
263269
"name": vdef.name,

0 commit comments

Comments
 (0)