@@ -30,6 +30,64 @@ def custom_huber_loss(predictions: Array, targets: Array, delta: float = 1.0) ->
3030
3131
3232class EnergyLoss (Loss ):
33+ r"""Construct a layer to compute loss on energy, force and virial.
34+
35+ Parameters
36+ ----------
37+ starter_learning_rate : float
38+ The learning rate at the start of the training.
39+ start_pref_e : float
40+ The prefactor of energy loss at the start of the training.
41+ limit_pref_e : float
42+ The prefactor of energy loss at the end of the training.
43+ start_pref_f : float
44+ The prefactor of force loss at the start of the training.
45+ limit_pref_f : float
46+ The prefactor of force loss at the end of the training.
47+ start_pref_v : float
48+ The prefactor of virial loss at the start of the training.
49+ limit_pref_v : float
50+ The prefactor of virial loss at the end of the training.
51+ start_pref_ae : float
52+ The prefactor of atomic energy loss at the start of the training.
53+ limit_pref_ae : float
54+ The prefactor of atomic energy loss at the end of the training.
55+ start_pref_pf : float
56+ The prefactor of the force loss weighted by per-atom prefactors (atom_pref) at the start of the training.
57+ limit_pref_pf : float
58+ The prefactor of the force loss weighted by per-atom prefactors (atom_pref) at the end of the training.
59+ relative_f : float
60+ If provided, relative force error will be used in the loss. The difference
61+ of force will be normalized by the magnitude of the force in the label with
62+ a shift given by relative_f
63+ enable_atom_ener_coeff : bool
64+ If true, the energy will be computed as \sum_i c_i E_i.
65+ start_pref_gf : float
66+ The prefactor of generalized force loss at the start of the training.
67+ limit_pref_gf : float
68+ The prefactor of generalized force loss at the end of the training.
69+ numb_generalized_coord : int
70+ The dimension of generalized coordinates.
71+ use_huber : bool
72+ Enables Huber loss calculation for energy/force/virial terms with user-defined threshold delta (D).
73+ The loss function smoothly transitions between L2 and L1 loss:
74+ - For absolute prediction errors within D: quadratic loss (0.5 * (error**2))
75+ - For absolute errors exceeding D: linear loss (D * (|error| - 0.5 * D))
76+ Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
77+ huber_delta : float
78+ The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss.
79+ loss_func : str
80+ Loss function type for energy, force, and virial terms.
81+ Options: 'mse' (Mean Squared Error, L2 loss, default) or 'mae' (Mean Absolute Error, L1 loss).
82+ MAE loss is less sensitive to outliers compared to MSE loss.
83+ Future extensions may support additional loss types.
84+ f_use_norm : bool
85+ If true, use L2 norm of force vectors for loss calculation when loss_func='mae' or use_huber is True.
86+ Instead of computing loss on force components, computes loss on ||F_pred - F_label||_2.
87+ **kwargs
88+ Other keyword arguments.
89+ """
90+
3391 def __init__ (
3492 self ,
3593 starter_learning_rate : float ,
@@ -50,8 +108,18 @@ def __init__(
50108 numb_generalized_coord : int = 0 ,
51109 use_huber : bool = False ,
52110 huber_delta : float = 0.01 ,
111+ loss_func : str = "mse" ,
112+ f_use_norm : bool = False ,
53113 ** kwargs : Any ,
54114 ) -> None :
115+ # Validate loss_func
116+ valid_loss_funcs = ["mse" , "mae" ]
117+ if loss_func not in valid_loss_funcs :
118+ raise ValueError (
119+ f"Invalid loss_func '{ loss_func } '. Must be one of { valid_loss_funcs } ."
120+ )
121+
122+ self .loss_func = loss_func
55123 self .starter_learning_rate = starter_learning_rate
56124 self .start_pref_e = start_pref_e
57125 self .limit_pref_e = limit_pref_e
@@ -80,6 +148,11 @@ def __init__(
80148 )
81149 self .use_huber = use_huber
82150 self .huber_delta = huber_delta
151+ self .f_use_norm = f_use_norm
152+ if self .f_use_norm and not (self .use_huber or self .loss_func == "mae" ):
153+ raise RuntimeError (
154+ "f_use_norm can only be True when use_huber or loss_func='mae'."
155+ )
83156 if self .use_huber and (
84157 self .has_pf or self .has_gf or self .relative_f is not None
85158 ):
@@ -169,78 +242,152 @@ def call(
169242 loss = 0
170243 more_loss = {}
171244 if self .has_e :
172- l2_ener_loss = xp .mean (xp .square (energy - energy_hat ))
173- if not self .use_huber :
174- loss += atom_norm_ener * (pref_e * l2_ener_loss )
245+ if self .loss_func == "mse" :
246+ l2_ener_loss = xp .mean (xp .square (energy - energy_hat ))
247+ if not self .use_huber :
248+ loss += atom_norm_ener * (pref_e * l2_ener_loss )
249+ else :
250+ l_huber_loss = custom_huber_loss (
251+ atom_norm_ener * energy ,
252+ atom_norm_ener * energy_hat ,
253+ delta = self .huber_delta ,
254+ )
255+ loss += pref_e * l_huber_loss
256+ more_loss ["rmse_e" ] = self .display_if_exist (
257+ xp .sqrt (l2_ener_loss ) * atom_norm_ener , find_energy
258+ )
259+ elif self .loss_func == "mae" :
260+ l1_ener_loss = xp .mean (xp .abs (energy - energy_hat ))
261+ loss += atom_norm_ener * (pref_e * l1_ener_loss )
262+ more_loss ["mae_e" ] = self .display_if_exist (
263+ l1_ener_loss * atom_norm_ener , find_energy
264+ )
175265 else :
176- l_huber_loss = custom_huber_loss (
177- atom_norm_ener * energy ,
178- atom_norm_ener * energy_hat ,
179- delta = self .huber_delta ,
266+ raise NotImplementedError (
267+ f"Loss type { self .loss_func } is not implemented for energy loss."
180268 )
181- loss += pref_e * l_huber_loss
182- more_loss ["rmse_e" ] = self .display_if_exist (
183- xp .sqrt (l2_ener_loss ) * atom_norm_ener , find_energy
184- )
185269 if self .has_f :
186- l2_force_loss = xp .mean (xp .square (diff_f ))
187- if not self .use_huber :
188- loss += pref_f * l2_force_loss
270+ if self .loss_func == "mse" :
271+ l2_force_loss = xp .mean (xp .square (diff_f ))
272+ if not self .use_huber :
273+ loss += pref_f * l2_force_loss
274+ else :
275+ if not self .f_use_norm :
276+ l_huber_loss = custom_huber_loss (
277+ xp .reshape (force , (- 1 ,)),
278+ xp .reshape (force_hat , (- 1 ,)),
279+ delta = self .huber_delta ,
280+ )
281+ else :
282+ force_diff_3 = xp .reshape (force_hat - force , (- 1 , 3 ))
283+ force_diff_norm = xp .reshape (
284+ xp .linalg .vector_norm (force_diff_3 , axis = 1 ), (- 1 , 1 )
285+ )
286+ l_huber_loss = custom_huber_loss (
287+ force_diff_norm ,
288+ xp .zeros_like (force_diff_norm ),
289+ delta = self .huber_delta ,
290+ )
291+ loss += pref_f * l_huber_loss
292+ more_loss ["rmse_f" ] = self .display_if_exist (
293+ xp .sqrt (l2_force_loss ), find_force
294+ )
295+ elif self .loss_func == "mae" :
296+ if not self .f_use_norm :
297+ l1_force_loss = xp .mean (xp .abs (diff_f ))
298+ else :
299+ force_diff_3 = xp .reshape (force_hat - force , (- 1 , 3 ))
300+ l1_force_loss = xp .mean (xp .linalg .vector_norm (force_diff_3 , axis = 1 ))
301+ loss += pref_f * l1_force_loss
302+ more_loss ["mae_f" ] = self .display_if_exist (l1_force_loss , find_force )
189303 else :
190- l_huber_loss = custom_huber_loss (
191- xp .reshape (force , (- 1 ,)),
192- xp .reshape (force_hat , (- 1 ,)),
193- delta = self .huber_delta ,
304+ raise NotImplementedError (
305+ f"Loss type { self .loss_func } is not implemented for force loss."
194306 )
195- loss += pref_f * l_huber_loss
196- more_loss ["rmse_f" ] = self .display_if_exist (
197- xp .sqrt (l2_force_loss ), find_force
198- )
199307 if self .has_v :
200308 virial_reshape = xp .reshape (virial , (- 1 ,))
201309 virial_hat_reshape = xp .reshape (virial_hat , (- 1 ,))
202- l2_virial_loss = xp .mean (
203- xp .square (virial_hat_reshape - virial_reshape ),
204- )
205- if not self .use_huber :
206- loss += atom_norm * (pref_v * l2_virial_loss )
310+ if self .loss_func == "mse" :
311+ l2_virial_loss = xp .mean (
312+ xp .square (virial_hat_reshape - virial_reshape ),
313+ )
314+ if not self .use_huber :
315+ loss += atom_norm * (pref_v * l2_virial_loss )
316+ else :
317+ l_huber_loss = custom_huber_loss (
318+ atom_norm * virial_reshape ,
319+ atom_norm * virial_hat_reshape ,
320+ delta = self .huber_delta ,
321+ )
322+ loss += pref_v * l_huber_loss
323+ more_loss ["rmse_v" ] = self .display_if_exist (
324+ xp .sqrt (l2_virial_loss ) * atom_norm , find_virial
325+ )
326+ elif self .loss_func == "mae" :
327+ l1_virial_loss = xp .mean (xp .abs (virial_hat_reshape - virial_reshape ))
328+ loss += atom_norm * (pref_v * l1_virial_loss )
329+ more_loss ["mae_v" ] = self .display_if_exist (
330+ l1_virial_loss * atom_norm , find_virial
331+ )
207332 else :
208- l_huber_loss = custom_huber_loss (
209- atom_norm * virial_reshape ,
210- atom_norm * virial_hat_reshape ,
211- delta = self .huber_delta ,
333+ raise NotImplementedError (
334+ f"Loss type { self .loss_func } is not implemented for virial loss."
212335 )
213- loss += pref_v * l_huber_loss
214- more_loss ["rmse_v" ] = self .display_if_exist (
215- xp .sqrt (l2_virial_loss ) * atom_norm , find_virial
216- )
217336 if self .has_ae :
218337 atom_ener_reshape = xp .reshape (atom_ener , (- 1 ,))
219338 atom_ener_hat_reshape = xp .reshape (atom_ener_hat , (- 1 ,))
220- l2_atom_ener_loss = xp .mean (
221- xp .square (atom_ener_hat_reshape - atom_ener_reshape ),
222- )
223- if not self .use_huber :
224- loss += pref_ae * l2_atom_ener_loss
339+
340+ if self .loss_func == "mse" :
341+ l2_atom_ener_loss = xp .mean (
342+ xp .square (atom_ener_hat_reshape - atom_ener_reshape ),
343+ )
344+ if not self .use_huber :
345+ loss += pref_ae * l2_atom_ener_loss
346+ else :
347+ l_huber_loss = custom_huber_loss (
348+ atom_ener_reshape ,
349+ atom_ener_hat_reshape ,
350+ delta = self .huber_delta ,
351+ )
352+ loss += pref_ae * l_huber_loss
353+ more_loss ["rmse_ae" ] = self .display_if_exist (
354+ xp .sqrt (l2_atom_ener_loss ), find_atom_ener
355+ )
356+ elif self .loss_func == "mae" :
357+ l1_atom_ener_loss = xp .mean (
358+ xp .abs (atom_ener_hat_reshape - atom_ener_reshape )
359+ )
360+ loss += pref_ae * l1_atom_ener_loss
361+ more_loss ["mae_ae" ] = self .display_if_exist (
362+ l1_atom_ener_loss , find_atom_ener
363+ )
225364 else :
226- l_huber_loss = custom_huber_loss (
227- atom_ener_reshape ,
228- atom_ener_hat_reshape ,
229- delta = self .huber_delta ,
365+ raise NotImplementedError (
366+ f"Loss type { self .loss_func } is not implemented for atomic energy loss."
230367 )
231- loss += pref_ae * l_huber_loss
232- more_loss ["rmse_ae" ] = self .display_if_exist (
233- xp .sqrt (l2_atom_ener_loss ), find_atom_ener
234- )
235368 if self .has_pf :
236369 atom_pref_reshape = xp .reshape (atom_pref , (- 1 ,))
237- l2_pref_force_loss = xp .mean (
238- xp .multiply (xp .square (diff_f ), atom_pref_reshape ),
239- )
240- loss += pref_pf * l2_pref_force_loss
241- more_loss ["rmse_pf" ] = self .display_if_exist (
242- xp .sqrt (l2_pref_force_loss ), find_atom_pref
243- )
370+
371+ if self .loss_func == "mse" :
372+ l2_pref_force_loss = xp .mean (
373+ xp .multiply (xp .square (diff_f ), atom_pref_reshape ),
374+ )
375+ loss += pref_pf * l2_pref_force_loss
376+ more_loss ["rmse_pf" ] = self .display_if_exist (
377+ xp .sqrt (l2_pref_force_loss ), find_atom_pref
378+ )
379+ elif self .loss_func == "mae" :
380+ l1_pref_force_loss = xp .mean (
381+ xp .multiply (xp .abs (diff_f ), atom_pref_reshape )
382+ )
383+ loss += pref_pf * l1_pref_force_loss
384+ more_loss ["mae_pf" ] = self .display_if_exist (
385+ l1_pref_force_loss , find_atom_pref
386+ )
387+ else :
388+ raise NotImplementedError (
389+ f"Loss type { self .loss_func } is not implemented for atom prefactor force loss."
390+ )
244391 if self .has_gf :
245392 find_drdq = label_dict ["find_drdq" ]
246393 drdq = label_dict ["drdq" ]
@@ -372,6 +519,8 @@ def serialize(self) -> dict:
372519 "numb_generalized_coord" : self .numb_generalized_coord ,
373520 "use_huber" : self .use_huber ,
374521 "huber_delta" : self .huber_delta ,
522+ "loss_func" : self .loss_func ,
523+ "f_use_norm" : self .f_use_norm ,
375524 }
376525
377526 @classmethod
0 commit comments