Skip to content

Commit 3341fd1

Browse files
committed
use str for loss func
1 parent 945d423 commit 3341fd1

9 files changed

Lines changed: 446 additions & 176 deletions

File tree

deepmd/dpmodel/loss/ener.py

Lines changed: 135 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,75 @@ def __init__(
5050
numb_generalized_coord: int = 0,
5151
use_huber: bool = False,
5252
huber_delta: float = 0.01,
53-
use_mae_loss: bool = False,
53+
loss_func: str = "mse",
5454
f_use_norm: bool = False,
5555
**kwargs: Any,
5656
) -> None:
57+
r"""Construct a layer to compute loss on energy, force and virial.
58+
59+
Parameters
60+
----------
61+
starter_learning_rate : float
62+
The learning rate at the start of the training.
63+
start_pref_e : float
64+
The prefactor of energy loss at the start of the training.
65+
limit_pref_e : float
66+
The prefactor of energy loss at the end of the training.
67+
start_pref_f : float
68+
The prefactor of force loss at the start of the training.
69+
limit_pref_f : float
70+
The prefactor of force loss at the end of the training.
71+
start_pref_v : float
72+
The prefactor of virial loss at the start of the training.
73+
limit_pref_v : float
74+
The prefactor of virial loss at the end of the training.
75+
start_pref_ae : float
76+
The prefactor of atomic energy loss at the start of the training.
77+
limit_pref_ae : float
78+
The prefactor of atomic energy loss at the end of the training.
79+
start_pref_pf : float
80+
The prefactor of atomic prefactor force loss at the start of the training.
81+
limit_pref_pf : float
82+
The prefactor of atomic prefactor force loss at the end of the training.
83+
relative_f : float
84+
If provided, relative force error will be used in the loss. The difference
85+
of force will be normalized by the magnitude of the force in the label with
86+
a shift given by relative_f
87+
enable_atom_ener_coeff : bool
88+
if true, the energy will be computed as \sum_i c_i E_i
89+
start_pref_gf : float
90+
The prefactor of generalized force loss at the start of the training.
91+
limit_pref_gf : float
92+
The prefactor of generalized force loss at the end of the training.
93+
numb_generalized_coord : int
94+
The dimension of generalized coordinates.
95+
use_huber : bool
96+
Enables Huber loss calculation for energy/force/virial terms with user-defined threshold delta (D).
97+
The loss function smoothly transitions between L2 and L1 loss:
98+
- For absolute prediction errors within D: quadratic loss (0.5 * (error**2))
99+
- For absolute errors exceeding D: linear loss (D * (|error| - 0.5 * D))
100+
Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
101+
huber_delta : float
102+
The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss.
103+
loss_func : str
104+
Loss function type for energy, force, and virial terms.
105+
Options: 'mse' (Mean Squared Error, L2 loss, default) or 'mae' (Mean Absolute Error, L1 loss).
106+
MAE loss is less sensitive to outliers than MSE loss.
107+
Future extensions may support additional loss types.
108+
f_use_norm : bool
109+
If true, use L2 norm of force vectors for loss calculation when loss_func='mae' or use_huber is True.
110+
Instead of computing loss on force components, computes loss on ||F_pred - F_label||_2.
111+
**kwargs
112+
Other keyword arguments.
113+
"""
114+
# Validate loss_func
115+
valid_loss_funcs = ["mse", "mae"]
116+
if loss_func not in valid_loss_funcs:
117+
raise ValueError(
118+
f"Invalid loss_func '{loss_func}'. Must be one of {valid_loss_funcs}."
119+
)
120+
121+
self.loss_func = loss_func
57122
self.starter_learning_rate = starter_learning_rate
58123
self.start_pref_e = start_pref_e
59124
self.limit_pref_e = limit_pref_e
@@ -82,11 +147,10 @@ def __init__(
82147
)
83148
self.use_huber = use_huber
84149
self.huber_delta = huber_delta
85-
self.use_mae_loss = use_mae_loss
86150
self.f_use_norm = f_use_norm
87-
if self.f_use_norm and not (self.use_huber or self.use_mae_loss):
151+
if self.f_use_norm and not (self.use_huber or self.loss_func == "mae"):
88152
raise RuntimeError(
89-
"f_use_norm can only be True when use_huber or use_mae_loss is True."
153+
"f_use_norm can only be True when use_huber or loss_func='mae'."
90154
)
91155
if self.use_huber and (
92156
self.has_pf or self.has_gf or self.relative_f is not None
@@ -177,7 +241,7 @@ def call(
177241
loss = 0
178242
more_loss = {}
179243
if self.has_e:
180-
if not self.use_mae_loss:
244+
if self.loss_func == "mse":
181245
l2_ener_loss = xp.mean(xp.square(energy - energy_hat))
182246
if not self.use_huber:
183247
loss += atom_norm_ener * (pref_e * l2_ener_loss)
@@ -191,14 +255,18 @@ def call(
191255
more_loss["rmse_e"] = self.display_if_exist(
192256
xp.sqrt(l2_ener_loss) * atom_norm_ener, find_energy
193257
)
194-
else:
258+
elif self.loss_func == "mae":
195259
l1_ener_loss = xp.mean(xp.abs(energy - energy_hat))
196260
loss += atom_norm_ener * (pref_e * l1_ener_loss)
197261
more_loss["mae_e"] = self.display_if_exist(
198262
l1_ener_loss * atom_norm_ener, find_energy
199263
)
264+
else:
265+
raise NotImplementedError(
266+
f"Loss type {self.loss_func} is not implemented for energy loss."
267+
)
200268
if self.has_f:
201-
if not self.use_mae_loss:
269+
if self.loss_func == "mse":
202270
l2_force_loss = xp.mean(xp.square(diff_f))
203271
if not self.use_huber:
204272
loss += pref_f * l2_force_loss
@@ -223,18 +291,22 @@ def call(
223291
more_loss["rmse_f"] = self.display_if_exist(
224292
xp.sqrt(l2_force_loss), find_force
225293
)
226-
else:
294+
elif self.loss_func == "mae":
227295
if not self.f_use_norm:
228296
l1_force_loss = xp.mean(xp.abs(diff_f))
229297
else:
230298
force_diff_3 = xp.reshape(force_hat - force, (-1, 3))
231299
l1_force_loss = xp.mean(xp.linalg.vector_norm(force_diff_3, axis=1))
232300
loss += pref_f * l1_force_loss
233301
more_loss["mae_f"] = self.display_if_exist(l1_force_loss, find_force)
302+
else:
303+
raise NotImplementedError(
304+
f"Loss type {self.loss_func} is not implemented for force loss."
305+
)
234306
if self.has_v:
235307
virial_reshape = xp.reshape(virial, (-1,))
236308
virial_hat_reshape = xp.reshape(virial_hat, (-1,))
237-
if not self.use_mae_loss:
309+
if self.loss_func == "mse":
238310
l2_virial_loss = xp.mean(
239311
xp.square(virial_hat_reshape - virial_reshape),
240312
)
@@ -250,39 +322,71 @@ def call(
250322
more_loss["rmse_v"] = self.display_if_exist(
251323
xp.sqrt(l2_virial_loss) * atom_norm, find_virial
252324
)
253-
else:
325+
elif self.loss_func == "mae":
254326
l1_virial_loss = xp.mean(xp.abs(virial_hat_reshape - virial_reshape))
255327
loss += atom_norm * (pref_v * l1_virial_loss)
256328
more_loss["mae_v"] = self.display_if_exist(
257329
l1_virial_loss * atom_norm, find_virial
258330
)
331+
else:
332+
raise NotImplementedError(
333+
f"Loss type {self.loss_func} is not implemented for virial loss."
334+
)
259335
if self.has_ae:
260336
atom_ener_reshape = xp.reshape(atom_ener, (-1,))
261337
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, (-1,))
262-
l2_atom_ener_loss = xp.mean(
263-
xp.square(atom_ener_hat_reshape - atom_ener_reshape),
264-
)
265-
if not self.use_huber:
266-
loss += pref_ae * l2_atom_ener_loss
338+
339+
if self.loss_func == "mse":
340+
l2_atom_ener_loss = xp.mean(
341+
xp.square(atom_ener_hat_reshape - atom_ener_reshape),
342+
)
343+
if not self.use_huber:
344+
loss += pref_ae * l2_atom_ener_loss
345+
else:
346+
l_huber_loss = custom_huber_loss(
347+
atom_ener_reshape,
348+
atom_ener_hat_reshape,
349+
delta=self.huber_delta,
350+
)
351+
loss += pref_ae * l_huber_loss
352+
more_loss["rmse_ae"] = self.display_if_exist(
353+
xp.sqrt(l2_atom_ener_loss), find_atom_ener
354+
)
355+
elif self.loss_func == "mae":
356+
l1_atom_ener_loss = xp.mean(
357+
xp.abs(atom_ener_hat_reshape - atom_ener_reshape)
358+
)
359+
loss += pref_ae * l1_atom_ener_loss
360+
more_loss["mae_ae"] = self.display_if_exist(
361+
l1_atom_ener_loss, find_atom_ener
362+
)
267363
else:
268-
l_huber_loss = custom_huber_loss(
269-
atom_ener_reshape,
270-
atom_ener_hat_reshape,
271-
delta=self.huber_delta,
364+
raise NotImplementedError(
365+
f"Loss type {self.loss_func} is not implemented for atomic energy loss."
272366
)
273-
loss += pref_ae * l_huber_loss
274-
more_loss["rmse_ae"] = self.display_if_exist(
275-
xp.sqrt(l2_atom_ener_loss), find_atom_ener
276-
)
277367
if self.has_pf:
278368
atom_pref_reshape = xp.reshape(atom_pref, (-1,))
279-
l2_pref_force_loss = xp.mean(
280-
xp.multiply(xp.square(diff_f), atom_pref_reshape),
281-
)
282-
loss += pref_pf * l2_pref_force_loss
283-
more_loss["rmse_pf"] = self.display_if_exist(
284-
xp.sqrt(l2_pref_force_loss), find_atom_pref
285-
)
369+
370+
if self.loss_func == "mse":
371+
l2_pref_force_loss = xp.mean(
372+
xp.multiply(xp.square(diff_f), atom_pref_reshape),
373+
)
374+
loss += pref_pf * l2_pref_force_loss
375+
more_loss["rmse_pf"] = self.display_if_exist(
376+
xp.sqrt(l2_pref_force_loss), find_atom_pref
377+
)
378+
elif self.loss_func == "mae":
379+
l1_pref_force_loss = xp.mean(
380+
xp.multiply(xp.abs(diff_f), atom_pref_reshape)
381+
)
382+
loss += pref_pf * l1_pref_force_loss
383+
more_loss["mae_pf"] = self.display_if_exist(
384+
l1_pref_force_loss, find_atom_pref
385+
)
386+
else:
387+
raise NotImplementedError(
388+
f"Loss type {self.loss_func} is not implemented for atom prefactor force loss."
389+
)
286390
if self.has_gf:
287391
find_drdq = label_dict["find_drdq"]
288392
drdq = label_dict["drdq"]
@@ -413,7 +517,7 @@ def serialize(self) -> dict:
413517
"numb_generalized_coord": self.numb_generalized_coord,
414518
"use_huber": self.use_huber,
415519
"huber_delta": self.huber_delta,
416-
"use_mae_loss": self.use_mae_loss,
520+
"loss_func": self.loss_func,
417521
"f_use_norm": self.f_use_norm,
418522
}
419523

0 commit comments

Comments
 (0)