Skip to content

Commit cdd9b1a

Browse files
wanghan-iapcm (Han Wang)
authored
feat(pt_expt): add missing losses (spin, DOS, tensor, property) (#5345)
## Summary - Add `EnergySpinLoss`, `DOSLoss`, `TensorLoss`, `PropertyLoss` to dpmodel (array_api) and pt_expt (`@torch_module` wrapper) backends - Add `serialize()`/`deserialize()` to PT loss classes (`ener_spin`, `dos`, `tensor`, `property`) for cross-backend consistency testing - Add `mae=True` support to dpmodel `EnergyLoss` and `EnergySpinLoss` (extra MAE metrics for `dp test`) - Extend pt_expt `get_loss()` factory to handle all loss types - Add cross-backend consistency tests parameterized over `loss_func` and `mae` - Add pt_expt unit tests for all 4 new losses ## Details ### dpmodel losses (`deepmd/dpmodel/loss/`) Array-API compatible implementations ported from `deepmd/pt/loss/`. Key adaptations: - Boolean fancy indexing replaced with mask multiplication + manual mean - `torch.cumsum` → `xp.cumulative_sum` - No input dict mutation (PT mutates `model_pred` in-place) ### pt_expt wrappers (`deepmd/pt_expt/loss/`) Thin `@torch_module` wrappers inheriting from dpmodel classes, following existing `EnergyLoss` pattern. ### Known limitations 1. **PropertyLoss normalization**: `out_std`/`out_bias` must be provided explicitly (dpmodel losses can't access the model at forward time). Defaults to identity normalization if not provided. 2. **`inference` parameter**: Not ported from PT losses — it only suppresses `l2_*` intermediate metrics and is never actually used (`inference=True` is never constructed anywhere in the codebase). 3. **`denoise` loss**: Not ported (internal DPA pretraining, not user-facing). 
## Test plan - [x] `python -m pytest source/tests/pt_expt/loss/ -v` — all 42 tests pass - [x] `python -m pytest source/tests/consistent/loss/ -v` — all 342 tests pass (358 skipped for unavailable backends) - [x] `python -m pytest source/tests/universal/dpmodel/loss/ -v` — existing tests still pass <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added DOS, EnergySpin, Property, and Tensor loss types; energy/force losses can optionally report MAE. * Loss factory extended to select these types and incorporate model metadata. * **Integration** * Loss implementations exposed to PyTorch-experimental with serialization/deserialization support for round-trips. * **Tests** * Extensive cross-backend/unit tests for new losses, metrics, masking behaviors, and serialization. * **Chores** * Public exports updated to include the new loss classes. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 2a8fc21 commit cdd9b1a

28 files changed

+2903
-16
lines changed

deepmd/dpmodel/loss/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1 +1,24 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.dpmodel.loss.dos import (
3+
DOSLoss,
4+
)
5+
from deepmd.dpmodel.loss.ener import (
6+
EnergyLoss,
7+
)
8+
from deepmd.dpmodel.loss.ener_spin import (
9+
EnergySpinLoss,
10+
)
11+
from deepmd.dpmodel.loss.property import (
12+
PropertyLoss,
13+
)
14+
from deepmd.dpmodel.loss.tensor import (
15+
TensorLoss,
16+
)
17+
18+
__all__ = [
19+
"DOSLoss",
20+
"EnergyLoss",
21+
"EnergySpinLoss",
22+
"PropertyLoss",
23+
"TensorLoss",
24+
]

deepmd/dpmodel/loss/dos.py

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Any,
4+
)
5+
6+
import array_api_compat
7+
8+
from deepmd.dpmodel.array_api import (
9+
Array,
10+
)
11+
from deepmd.dpmodel.loss.loss import (
12+
Loss,
13+
)
14+
from deepmd.utils.data import (
15+
DataRequirementItem,
16+
)
17+
from deepmd.utils.version import (
18+
check_version_compatibility,
19+
)
20+
21+
22+
class DOSLoss(Loss):
    r"""Loss on DOS (density of states) for both local and global predictions.

    Array-API compatible port of the PT DOS loss (see commit summary):
    boolean fancy indexing is replaced by mask multiplication plus a manual
    mean, and ``torch.cumsum`` by ``xp.cumulative_sum``.

    Parameters
    ----------
    starter_learning_rate : float
        The learning rate at the start of the training.
    numb_dos : int
        The number of DOS components.
    start_pref_dos : float
        The prefactor of global DOS loss at the start of the training.
    limit_pref_dos : float
        The prefactor of global DOS loss at the end of the training.
    start_pref_cdf : float
        The prefactor of global CDF loss at the start of the training.
    limit_pref_cdf : float
        The prefactor of global CDF loss at the end of the training.
    start_pref_ados : float
        The prefactor of atomic DOS loss at the start of the training.
    limit_pref_ados : float
        The prefactor of atomic DOS loss at the end of the training.
    start_pref_acdf : float
        The prefactor of atomic CDF loss at the start of the training.
    limit_pref_acdf : float
        The prefactor of atomic CDF loss at the end of the training.
    **kwargs
        Other keyword arguments.
    """

    def __init__(
        self,
        starter_learning_rate: float,
        numb_dos: int,
        start_pref_dos: float = 1.00,
        limit_pref_dos: float = 1.00,
        start_pref_cdf: float = 1000,
        limit_pref_cdf: float = 1.00,
        start_pref_ados: float = 0.0,
        limit_pref_ados: float = 0.0,
        start_pref_acdf: float = 0.0,
        limit_pref_acdf: float = 0.0,
        **kwargs: Any,
    ) -> None:
        self.starter_learning_rate = starter_learning_rate
        self.numb_dos = numb_dos
        self.start_pref_dos = start_pref_dos
        self.limit_pref_dos = limit_pref_dos
        self.start_pref_cdf = start_pref_cdf
        self.limit_pref_cdf = limit_pref_cdf
        self.start_pref_ados = start_pref_ados
        self.limit_pref_ados = limit_pref_ados
        self.start_pref_acdf = start_pref_acdf
        self.limit_pref_acdf = limit_pref_acdf

        assert (
            self.start_pref_dos >= 0.0
            and self.limit_pref_dos >= 0.0
            and self.start_pref_cdf >= 0.0
            and self.limit_pref_cdf >= 0.0
            and self.start_pref_ados >= 0.0
            and self.limit_pref_ados >= 0.0
            and self.start_pref_acdf >= 0.0
            and self.limit_pref_acdf >= 0.0
        ), "Can not assign negative weight to `pref` and `pref_atomic`"

        # A term is active when either its start or limit prefactor is nonzero;
        # the effective prefactor is interpolated between the two in `call`.
        self.has_dos = start_pref_dos != 0.0 or limit_pref_dos != 0.0
        self.has_cdf = start_pref_cdf != 0.0 or limit_pref_cdf != 0.0
        self.has_ados = start_pref_ados != 0.0 or limit_pref_ados != 0.0
        self.has_acdf = start_pref_acdf != 0.0 or limit_pref_acdf != 0.0

        assert self.has_dos or self.has_cdf or self.has_ados or self.has_acdf, (
            "Can not assign zero weight to all pref terms"
        )

    def call(
        self,
        learning_rate: float,
        natoms: int,
        model_dict: dict[str, Array],
        label_dict: dict[str, Array],
        mae: bool = False,
    ) -> tuple[Array, dict[str, Array]]:
        """Calculate loss from model results and labeled results.

        Parameters
        ----------
        learning_rate : float
            Current learning rate; used to interpolate each prefactor
            between its start and limit values.
        natoms : int
            Number of atoms per frame (used to reshape atomic quantities).
        model_dict : dict[str, Array]
            Model predictions; may contain "atom_dos", "dos", and an
            optional "mask" of valid atoms.
        label_dict : dict[str, Array]
            Labels; may contain "atom_dos", "dos", and "find_*" flags
            indicating whether each label is present in the data.
        mae : bool
            Accepted for signature consistency with other losses; not used
            by the DOS terms below.

        Returns
        -------
        loss : Array
            Scalar total loss.
        more_loss : dict[str, Array]
            Per-term RMSE metrics for display.
        """
        # Get array namespace from any available tensor
        first_key = next(iter(model_dict))
        xp = array_api_compat.array_namespace(model_dict[first_key])

        # Linear interpolation of each prefactor from start (coef=1) toward
        # limit (coef->0) as the learning rate decays.
        coef = learning_rate / self.starter_learning_rate
        pref_dos = (
            self.limit_pref_dos + (self.start_pref_dos - self.limit_pref_dos) * coef
        )
        pref_cdf = (
            self.limit_pref_cdf + (self.start_pref_cdf - self.limit_pref_cdf) * coef
        )
        pref_ados = (
            self.limit_pref_ados + (self.start_pref_ados - self.limit_pref_ados) * coef
        )
        pref_acdf = (
            self.limit_pref_acdf + (self.start_pref_acdf - self.limit_pref_acdf) * coef
        )

        loss = 0
        more_loss = {}

        # --- Atomic (local) DOS term ---
        if self.has_ados and "atom_dos" in model_dict and "atom_dos" in label_dict:
            # find_* flag scales the prefactor to 0 when the label is absent.
            find_local = label_dict.get("find_atom_dos", 0.0)
            pref_ados = pref_ados * find_local
            local_pred = xp.reshape(model_dict["atom_dos"], (-1, natoms, self.numb_dos))
            local_label = xp.reshape(
                label_dict["atom_dos"], (-1, natoms, self.numb_dos)
            )
            diff = xp.reshape(local_pred - local_label, (-1, self.numb_dos))
            if "mask" in model_dict:
                # Mask multiplication + manual mean instead of boolean fancy
                # indexing (not available in the array API standard).
                mask = xp.reshape(model_dict["mask"], (-1,))
                mask_float = xp.astype(mask, diff.dtype)
                diff = diff * mask_float[:, None]
                n_valid = xp.sum(mask_float)
                l2_local_loss_dos = xp.sum(xp.square(diff)) / (n_valid * self.numb_dos)
            else:
                l2_local_loss_dos = xp.mean(xp.square(diff))
            loss += pref_ados * l2_local_loss_dos
            more_loss["rmse_local_dos"] = self.display_if_exist(
                xp.sqrt(l2_local_loss_dos), find_local
            )

        # --- Atomic (local) CDF term: cumulative sum of DOS along the grid ---
        if self.has_acdf and "atom_dos" in model_dict and "atom_dos" in label_dict:
            find_local = label_dict.get("find_atom_dos", 0.0)
            pref_acdf = pref_acdf * find_local
            local_pred_cdf = xp.cumulative_sum(
                xp.reshape(model_dict["atom_dos"], (-1, natoms, self.numb_dos)),
                axis=-1,
            )
            local_label_cdf = xp.cumulative_sum(
                xp.reshape(label_dict["atom_dos"], (-1, natoms, self.numb_dos)),
                axis=-1,
            )
            diff = xp.reshape(local_pred_cdf - local_label_cdf, (-1, self.numb_dos))
            if "mask" in model_dict:
                mask = xp.reshape(model_dict["mask"], (-1,))
                mask_float = xp.astype(mask, diff.dtype)
                diff = diff * mask_float[:, None]
                n_valid = xp.sum(mask_float)
                l2_local_loss_cdf = xp.sum(xp.square(diff)) / (n_valid * self.numb_dos)
            else:
                l2_local_loss_cdf = xp.mean(xp.square(diff))
            loss += pref_acdf * l2_local_loss_cdf
            more_loss["rmse_local_cdf"] = self.display_if_exist(
                xp.sqrt(l2_local_loss_cdf), find_local
            )

        # --- Global DOS term ---
        if self.has_dos and "dos" in model_dict and "dos" in label_dict:
            find_global = label_dict.get("find_dos", 0.0)
            pref_dos = pref_dos * find_global
            global_pred = xp.reshape(model_dict["dos"], (-1, self.numb_dos))
            global_label = xp.reshape(label_dict["dos"], (-1, self.numb_dos))
            diff = global_pred - global_label
            if "mask" in model_dict:
                # Weight each frame's squared error by its number of valid
                # atoms; `atom_num` is then reduced to a scalar mean for the
                # per-atom RMSE display below.
                atom_num = xp.sum(model_dict["mask"], axis=-1, keepdims=True)
                l2_global_loss_dos = xp.mean(
                    xp.sum(xp.square(diff) * atom_num, axis=0) / xp.sum(atom_num)
                )
                atom_num = xp.mean(xp.astype(atom_num, diff.dtype))
            else:
                atom_num = natoms
                l2_global_loss_dos = xp.mean(xp.square(diff))
            loss += pref_dos * l2_global_loss_dos
            more_loss["rmse_global_dos"] = self.display_if_exist(
                xp.sqrt(l2_global_loss_dos) / atom_num, find_global
            )

        # --- Global CDF term ---
        if self.has_cdf and "dos" in model_dict and "dos" in label_dict:
            find_global = label_dict.get("find_dos", 0.0)
            pref_cdf = pref_cdf * find_global
            global_pred_cdf = xp.cumulative_sum(
                xp.reshape(model_dict["dos"], (-1, self.numb_dos)), axis=-1
            )
            global_label_cdf = xp.cumulative_sum(
                xp.reshape(label_dict["dos"], (-1, self.numb_dos)), axis=-1
            )
            diff = global_pred_cdf - global_label_cdf
            if "mask" in model_dict:
                atom_num = xp.sum(model_dict["mask"], axis=-1, keepdims=True)
                l2_global_loss_cdf = xp.mean(
                    xp.sum(xp.square(diff) * atom_num, axis=0) / xp.sum(atom_num)
                )
                atom_num = xp.mean(xp.astype(atom_num, diff.dtype))
            else:
                atom_num = natoms
                l2_global_loss_cdf = xp.mean(xp.square(diff))
            loss += pref_cdf * l2_global_loss_cdf
            more_loss["rmse_global_cdf"] = self.display_if_exist(
                xp.sqrt(l2_global_loss_cdf) / atom_num, find_global
            )

        more_loss["rmse"] = xp.sqrt(loss)
        return loss, more_loss

    @property
    def label_requirement(self) -> list[DataRequirementItem]:
        """Return data label requirements needed for this loss calculation."""
        label_requirement = []
        if self.has_ados or self.has_acdf:
            label_requirement.append(
                DataRequirementItem(
                    "atom_dos",
                    ndof=self.numb_dos,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_dos or self.has_cdf:
            label_requirement.append(
                DataRequirementItem(
                    "dos",
                    ndof=self.numb_dos,
                    atomic=False,
                    must=False,
                    high_prec=False,
                )
            )
        return label_requirement

    def serialize(self) -> dict:
        """Serialize the loss module.

        Returns a plain dict of constructor arguments plus "@class" and
        "@version" tags, consumed by `deserialize` for cross-backend
        round-trips.
        """
        return {
            "@class": "DOSLoss",
            "@version": 1,
            "starter_learning_rate": self.starter_learning_rate,
            "numb_dos": self.numb_dos,
            "start_pref_dos": self.start_pref_dos,
            "limit_pref_dos": self.limit_pref_dos,
            "start_pref_cdf": self.start_pref_cdf,
            "limit_pref_cdf": self.limit_pref_cdf,
            "start_pref_ados": self.start_pref_ados,
            "limit_pref_ados": self.limit_pref_ados,
            "start_pref_acdf": self.start_pref_acdf,
            "limit_pref_acdf": self.limit_pref_acdf,
        }

    @classmethod
    def deserialize(cls, data: dict) -> "DOSLoss":
        """Deserialize the loss module.

        Copies the input dict before popping the "@version"/"@class" tags so
        the caller's dict is not mutated.
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version"), 1, 1)
        data.pop("@class")
        return cls(**data)

deepmd/dpmodel/loss/ener.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def call(
166166
natoms: int,
167167
model_dict: dict[str, Array],
168168
label_dict: dict[str, Array],
169+
mae: bool = False,
169170
) -> tuple[Array, dict[str, Array]]:
170171
"""Calculate loss from model results and labeled results."""
171172
energy = model_dict["energy"]
@@ -266,6 +267,11 @@ def call(
266267
raise NotImplementedError(
267268
f"Loss type {self.loss_func} is not implemented for energy loss."
268269
)
270+
if mae:
271+
mae_e = xp.mean(xp.abs(energy - energy_hat)) * atom_norm_ener
272+
more_loss["mae_e"] = self.display_if_exist(mae_e, find_energy)
273+
mae_e_all = xp.mean(xp.abs(energy - energy_hat))
274+
more_loss["mae_e_all"] = self.display_if_exist(mae_e_all, find_energy)
269275
if self.has_f:
270276
if self.loss_func == "mse":
271277
l2_force_loss = xp.mean(xp.square(diff_f))
@@ -304,6 +310,9 @@ def call(
304310
raise NotImplementedError(
305311
f"Loss type {self.loss_func} is not implemented for force loss."
306312
)
313+
if mae:
314+
mae_f = xp.mean(xp.abs(diff_f))
315+
more_loss["mae_f"] = self.display_if_exist(mae_f, find_force)
307316
if self.has_v:
308317
virial_reshape = xp.reshape(virial, (-1,))
309318
virial_hat_reshape = xp.reshape(virial_hat, (-1,))
@@ -333,6 +342,9 @@ def call(
333342
raise NotImplementedError(
334343
f"Loss type {self.loss_func} is not implemented for virial loss."
335344
)
345+
if mae:
346+
mae_v = xp.mean(xp.abs(virial_hat_reshape - virial_reshape)) * atom_norm
347+
more_loss["mae_v"] = self.display_if_exist(mae_v, find_virial)
336348
if self.has_ae:
337349
atom_ener_reshape = xp.reshape(atom_ener, (-1,))
338350
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, (-1,))

0 commit comments

Comments
 (0)