Skip to content

Commit 3f52fa9

Browse files
authored
feat(dp, pt): add force l2 norm loss & mae loss (#5294)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Introduced loss_func ("mse" or "mae") to select MSE vs MAE for energy/force/virial/atom losses. * Added f_use_norm to enable vector‑norm MAE behavior when allowed. * **Validation** * Enforced that f_use_norm is only valid when use_huber is enabled or loss_func="mae"; invalid combos are rejected. * **Tests** * Extended loss tests and skipping logic to cover loss_func and f_use_norm combinations. * **Documentation** * Updated docs to describe loss_func and resulting metric names (rmse_* vs mae_*). * **Chores** * New options are persisted in serialized configurations. * **Notes** * Some backends currently only support "mse" (MAE not yet available everywhere). <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent b2c8511 commit 3f52fa9

File tree

9 files changed

+635
-185
lines changed

9 files changed

+635
-185
lines changed

deepmd/dpmodel/loss/ener.py

Lines changed: 204 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,64 @@ def custom_huber_loss(predictions: Array, targets: Array, delta: float = 1.0) ->
3030

3131

3232
class EnergyLoss(Loss):
33+
r"""Construct a layer to compute loss on energy, force and virial.
34+
35+
Parameters
36+
----------
37+
starter_learning_rate : float
38+
The learning rate at the start of the training.
39+
start_pref_e : float
40+
The prefactor of energy loss at the start of the training.
41+
limit_pref_e : float
42+
The prefactor of energy loss at the end of the training.
43+
start_pref_f : float
44+
The prefactor of force loss at the start of the training.
45+
limit_pref_f : float
46+
The prefactor of force loss at the end of the training.
47+
start_pref_v : float
48+
The prefactor of virial loss at the start of the training.
49+
limit_pref_v : float
50+
The prefactor of virial loss at the end of the training.
51+
start_pref_ae : float
52+
The prefactor of atomic energy loss at the start of the training.
53+
limit_pref_ae : float
54+
The prefactor of atomic energy loss at the end of the training.
55+
start_pref_pf : float
56+
The prefactor of atomic prefactor force loss at the start of the training.
57+
limit_pref_pf : float
58+
The prefactor of atomic prefactor force loss at the end of the training.
59+
relative_f : float
60+
If provided, relative force error will be used in the loss. The difference
61+
of force will be normalized by the magnitude of the force in the label with
62+
a shift given by relative_f
63+
enable_atom_ener_coeff : bool
64+
If true, the energy will be computed as \sum_i c_i E_i
65+
start_pref_gf : float
66+
The prefactor of generalized force loss at the start of the training.
67+
limit_pref_gf : float
68+
The prefactor of generalized force loss at the end of the training.
69+
numb_generalized_coord : int
70+
The dimension of generalized coordinates.
71+
use_huber : bool
72+
Enables Huber loss calculation for energy/force/virial terms with user-defined threshold delta (D).
73+
The loss function smoothly transitions between L2 and L1 loss:
74+
- For absolute prediction errors within D: quadratic loss (0.5 * (error**2))
75+
- For absolute errors exceeding D: linear loss (D * (|error| - 0.5 * D))
76+
Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
77+
huber_delta : float
78+
The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss.
79+
loss_func : str
80+
Loss function type for energy, force, and virial terms.
81+
Options: 'mse' (Mean Squared Error, L2 loss, default) or 'mae' (Mean Absolute Error, L1 loss).
82+
MAE loss is less sensitive to outliers compared to MSE loss.
83+
Future extensions may support additional loss types.
84+
f_use_norm : bool
85+
If true, use L2 norm of force vectors for loss calculation when loss_func='mae' or use_huber is True.
86+
Instead of computing loss on force components, computes loss on ||F_pred - F_label||_2.
87+
**kwargs
88+
Other keyword arguments.
89+
"""
90+
3391
def __init__(
3492
self,
3593
starter_learning_rate: float,
@@ -50,8 +108,18 @@ def __init__(
50108
numb_generalized_coord: int = 0,
51109
use_huber: bool = False,
52110
huber_delta: float = 0.01,
111+
loss_func: str = "mse",
112+
f_use_norm: bool = False,
53113
**kwargs: Any,
54114
) -> None:
115+
# Validate loss_func
116+
valid_loss_funcs = ["mse", "mae"]
117+
if loss_func not in valid_loss_funcs:
118+
raise ValueError(
119+
f"Invalid loss_func '{loss_func}'. Must be one of {valid_loss_funcs}."
120+
)
121+
122+
self.loss_func = loss_func
55123
self.starter_learning_rate = starter_learning_rate
56124
self.start_pref_e = start_pref_e
57125
self.limit_pref_e = limit_pref_e
@@ -80,6 +148,11 @@ def __init__(
80148
)
81149
self.use_huber = use_huber
82150
self.huber_delta = huber_delta
151+
self.f_use_norm = f_use_norm
152+
if self.f_use_norm and not (self.use_huber or self.loss_func == "mae"):
153+
raise RuntimeError(
154+
"f_use_norm can only be True when use_huber or loss_func='mae'."
155+
)
83156
if self.use_huber and (
84157
self.has_pf or self.has_gf or self.relative_f is not None
85158
):
@@ -169,78 +242,152 @@ def call(
169242
loss = 0
170243
more_loss = {}
171244
if self.has_e:
172-
l2_ener_loss = xp.mean(xp.square(energy - energy_hat))
173-
if not self.use_huber:
174-
loss += atom_norm_ener * (pref_e * l2_ener_loss)
245+
if self.loss_func == "mse":
246+
l2_ener_loss = xp.mean(xp.square(energy - energy_hat))
247+
if not self.use_huber:
248+
loss += atom_norm_ener * (pref_e * l2_ener_loss)
249+
else:
250+
l_huber_loss = custom_huber_loss(
251+
atom_norm_ener * energy,
252+
atom_norm_ener * energy_hat,
253+
delta=self.huber_delta,
254+
)
255+
loss += pref_e * l_huber_loss
256+
more_loss["rmse_e"] = self.display_if_exist(
257+
xp.sqrt(l2_ener_loss) * atom_norm_ener, find_energy
258+
)
259+
elif self.loss_func == "mae":
260+
l1_ener_loss = xp.mean(xp.abs(energy - energy_hat))
261+
loss += atom_norm_ener * (pref_e * l1_ener_loss)
262+
more_loss["mae_e"] = self.display_if_exist(
263+
l1_ener_loss * atom_norm_ener, find_energy
264+
)
175265
else:
176-
l_huber_loss = custom_huber_loss(
177-
atom_norm_ener * energy,
178-
atom_norm_ener * energy_hat,
179-
delta=self.huber_delta,
266+
raise NotImplementedError(
267+
f"Loss type {self.loss_func} is not implemented for energy loss."
180268
)
181-
loss += pref_e * l_huber_loss
182-
more_loss["rmse_e"] = self.display_if_exist(
183-
xp.sqrt(l2_ener_loss) * atom_norm_ener, find_energy
184-
)
185269
if self.has_f:
186-
l2_force_loss = xp.mean(xp.square(diff_f))
187-
if not self.use_huber:
188-
loss += pref_f * l2_force_loss
270+
if self.loss_func == "mse":
271+
l2_force_loss = xp.mean(xp.square(diff_f))
272+
if not self.use_huber:
273+
loss += pref_f * l2_force_loss
274+
else:
275+
if not self.f_use_norm:
276+
l_huber_loss = custom_huber_loss(
277+
xp.reshape(force, (-1,)),
278+
xp.reshape(force_hat, (-1,)),
279+
delta=self.huber_delta,
280+
)
281+
else:
282+
force_diff_3 = xp.reshape(force_hat - force, (-1, 3))
283+
force_diff_norm = xp.reshape(
284+
xp.linalg.vector_norm(force_diff_3, axis=1), (-1, 1)
285+
)
286+
l_huber_loss = custom_huber_loss(
287+
force_diff_norm,
288+
xp.zeros_like(force_diff_norm),
289+
delta=self.huber_delta,
290+
)
291+
loss += pref_f * l_huber_loss
292+
more_loss["rmse_f"] = self.display_if_exist(
293+
xp.sqrt(l2_force_loss), find_force
294+
)
295+
elif self.loss_func == "mae":
296+
if not self.f_use_norm:
297+
l1_force_loss = xp.mean(xp.abs(diff_f))
298+
else:
299+
force_diff_3 = xp.reshape(force_hat - force, (-1, 3))
300+
l1_force_loss = xp.mean(xp.linalg.vector_norm(force_diff_3, axis=1))
301+
loss += pref_f * l1_force_loss
302+
more_loss["mae_f"] = self.display_if_exist(l1_force_loss, find_force)
189303
else:
190-
l_huber_loss = custom_huber_loss(
191-
xp.reshape(force, (-1,)),
192-
xp.reshape(force_hat, (-1,)),
193-
delta=self.huber_delta,
304+
raise NotImplementedError(
305+
f"Loss type {self.loss_func} is not implemented for force loss."
194306
)
195-
loss += pref_f * l_huber_loss
196-
more_loss["rmse_f"] = self.display_if_exist(
197-
xp.sqrt(l2_force_loss), find_force
198-
)
199307
if self.has_v:
200308
virial_reshape = xp.reshape(virial, (-1,))
201309
virial_hat_reshape = xp.reshape(virial_hat, (-1,))
202-
l2_virial_loss = xp.mean(
203-
xp.square(virial_hat_reshape - virial_reshape),
204-
)
205-
if not self.use_huber:
206-
loss += atom_norm * (pref_v * l2_virial_loss)
310+
if self.loss_func == "mse":
311+
l2_virial_loss = xp.mean(
312+
xp.square(virial_hat_reshape - virial_reshape),
313+
)
314+
if not self.use_huber:
315+
loss += atom_norm * (pref_v * l2_virial_loss)
316+
else:
317+
l_huber_loss = custom_huber_loss(
318+
atom_norm * virial_reshape,
319+
atom_norm * virial_hat_reshape,
320+
delta=self.huber_delta,
321+
)
322+
loss += pref_v * l_huber_loss
323+
more_loss["rmse_v"] = self.display_if_exist(
324+
xp.sqrt(l2_virial_loss) * atom_norm, find_virial
325+
)
326+
elif self.loss_func == "mae":
327+
l1_virial_loss = xp.mean(xp.abs(virial_hat_reshape - virial_reshape))
328+
loss += atom_norm * (pref_v * l1_virial_loss)
329+
more_loss["mae_v"] = self.display_if_exist(
330+
l1_virial_loss * atom_norm, find_virial
331+
)
207332
else:
208-
l_huber_loss = custom_huber_loss(
209-
atom_norm * virial_reshape,
210-
atom_norm * virial_hat_reshape,
211-
delta=self.huber_delta,
333+
raise NotImplementedError(
334+
f"Loss type {self.loss_func} is not implemented for virial loss."
212335
)
213-
loss += pref_v * l_huber_loss
214-
more_loss["rmse_v"] = self.display_if_exist(
215-
xp.sqrt(l2_virial_loss) * atom_norm, find_virial
216-
)
217336
if self.has_ae:
218337
atom_ener_reshape = xp.reshape(atom_ener, (-1,))
219338
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, (-1,))
220-
l2_atom_ener_loss = xp.mean(
221-
xp.square(atom_ener_hat_reshape - atom_ener_reshape),
222-
)
223-
if not self.use_huber:
224-
loss += pref_ae * l2_atom_ener_loss
339+
340+
if self.loss_func == "mse":
341+
l2_atom_ener_loss = xp.mean(
342+
xp.square(atom_ener_hat_reshape - atom_ener_reshape),
343+
)
344+
if not self.use_huber:
345+
loss += pref_ae * l2_atom_ener_loss
346+
else:
347+
l_huber_loss = custom_huber_loss(
348+
atom_ener_reshape,
349+
atom_ener_hat_reshape,
350+
delta=self.huber_delta,
351+
)
352+
loss += pref_ae * l_huber_loss
353+
more_loss["rmse_ae"] = self.display_if_exist(
354+
xp.sqrt(l2_atom_ener_loss), find_atom_ener
355+
)
356+
elif self.loss_func == "mae":
357+
l1_atom_ener_loss = xp.mean(
358+
xp.abs(atom_ener_hat_reshape - atom_ener_reshape)
359+
)
360+
loss += pref_ae * l1_atom_ener_loss
361+
more_loss["mae_ae"] = self.display_if_exist(
362+
l1_atom_ener_loss, find_atom_ener
363+
)
225364
else:
226-
l_huber_loss = custom_huber_loss(
227-
atom_ener_reshape,
228-
atom_ener_hat_reshape,
229-
delta=self.huber_delta,
365+
raise NotImplementedError(
366+
f"Loss type {self.loss_func} is not implemented for atomic energy loss."
230367
)
231-
loss += pref_ae * l_huber_loss
232-
more_loss["rmse_ae"] = self.display_if_exist(
233-
xp.sqrt(l2_atom_ener_loss), find_atom_ener
234-
)
235368
if self.has_pf:
236369
atom_pref_reshape = xp.reshape(atom_pref, (-1,))
237-
l2_pref_force_loss = xp.mean(
238-
xp.multiply(xp.square(diff_f), atom_pref_reshape),
239-
)
240-
loss += pref_pf * l2_pref_force_loss
241-
more_loss["rmse_pf"] = self.display_if_exist(
242-
xp.sqrt(l2_pref_force_loss), find_atom_pref
243-
)
370+
371+
if self.loss_func == "mse":
372+
l2_pref_force_loss = xp.mean(
373+
xp.multiply(xp.square(diff_f), atom_pref_reshape),
374+
)
375+
loss += pref_pf * l2_pref_force_loss
376+
more_loss["rmse_pf"] = self.display_if_exist(
377+
xp.sqrt(l2_pref_force_loss), find_atom_pref
378+
)
379+
elif self.loss_func == "mae":
380+
l1_pref_force_loss = xp.mean(
381+
xp.multiply(xp.abs(diff_f), atom_pref_reshape)
382+
)
383+
loss += pref_pf * l1_pref_force_loss
384+
more_loss["mae_pf"] = self.display_if_exist(
385+
l1_pref_force_loss, find_atom_pref
386+
)
387+
else:
388+
raise NotImplementedError(
389+
f"Loss type {self.loss_func} is not implemented for atom prefactor force loss."
390+
)
244391
if self.has_gf:
245392
find_drdq = label_dict["find_drdq"]
246393
drdq = label_dict["drdq"]
@@ -372,6 +519,8 @@ def serialize(self) -> dict:
372519
"numb_generalized_coord": self.numb_generalized_coord,
373520
"use_huber": self.use_huber,
374521
"huber_delta": self.huber_delta,
522+
"loss_func": self.loss_func,
523+
"f_use_norm": self.f_use_norm,
375524
}
376525

377526
@classmethod

0 commit comments

Comments
 (0)