@@ -50,10 +50,75 @@ def __init__(
5050 numb_generalized_coord : int = 0 ,
5151 use_huber : bool = False ,
5252 huber_delta : float = 0.01 ,
53- use_mae_loss : bool = False ,
53+ loss_func : str = "mse" ,
5454 f_use_norm : bool = False ,
5555 ** kwargs : Any ,
5656 ) -> None :
57+ r"""Construct a layer to compute loss on energy, force and virial.
58+
59+ Parameters
60+ ----------
61+ starter_learning_rate : float
62+ The learning rate at the start of the training.
63+ start_pref_e : float
64+ The prefactor of energy loss at the start of the training.
65+ limit_pref_e : float
66+ The prefactor of energy loss at the end of the training.
67+ start_pref_f : float
68+ The prefactor of force loss at the start of the training.
69+ limit_pref_f : float
70+ The prefactor of force loss at the end of the training.
71+ start_pref_v : float
72+ The prefactor of virial loss at the start of the training.
73+ limit_pref_v : float
74+ The prefactor of virial loss at the end of the training.
75+ start_pref_ae : float
76+ The prefactor of atomic energy loss at the start of the training.
77+ limit_pref_ae : float
78+ The prefactor of atomic energy loss at the end of the training.
79+ start_pref_pf : float
80+ The prefactor of the atom-prefactor-weighted force loss at the start of the training.
81+ limit_pref_pf : float
82+ The prefactor of the atom-prefactor-weighted force loss at the end of the training.
83+ relative_f : float
84+ If provided, relative force error will be used in the loss. The difference
85+ of force will be normalized by the magnitude of the force in the label with
86+ a shift given by relative_f
87+ enable_atom_ener_coeff : bool
88+ If true, the energy will be computed as \sum_i c_i E_i.
89+ start_pref_gf : float
90+ The prefactor of generalized force loss at the start of the training.
91+ limit_pref_gf : float
92+ The prefactor of generalized force loss at the end of the training.
93+ numb_generalized_coord : int
94+ The dimension of generalized coordinates.
95+ use_huber : bool
96+ Enables Huber loss calculation for energy/force/virial terms with user-defined threshold delta (D).
97+ The loss function smoothly transitions between L2 and L1 loss:
98+ - For absolute prediction errors within D: quadratic loss (0.5 * (error**2))
99+ - For absolute errors exceeding D: linear loss (D * (|error| - 0.5 * D))
100+ Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
101+ huber_delta : float
102+ The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss.
103+ loss_func : str
104+ Loss function type for energy, force, virial, atomic energy, and atom prefactor force terms.
105+ Options: 'mse' (Mean Squared Error, L2 loss, default) or 'mae' (Mean Absolute Error, L1 loss).
106+ MAE loss is less sensitive to outliers compared to MSE loss.
107+ Future extensions may support additional loss types.
108+ f_use_norm : bool
109+ If true, use L2 norm of force vectors for loss calculation when loss_func='mae' or use_huber is True.
110+ Instead of computing loss on force components, computes loss on ||F_pred - F_label||_2.
111+ **kwargs
112+ Other keyword arguments.
113+ """
114+ # Validate loss_func
115+ valid_loss_funcs = ["mse" , "mae" ]
116+ if loss_func not in valid_loss_funcs :
117+ raise ValueError (
118+ f"Invalid loss_func '{ loss_func } '. Must be one of { valid_loss_funcs } ."
119+ )
120+
121+ self .loss_func = loss_func
57122 self .starter_learning_rate = starter_learning_rate
58123 self .start_pref_e = start_pref_e
59124 self .limit_pref_e = limit_pref_e
@@ -82,11 +147,10 @@ def __init__(
82147 )
83148 self .use_huber = use_huber
84149 self .huber_delta = huber_delta
85- self .use_mae_loss = use_mae_loss
86150 self .f_use_norm = f_use_norm
87- if self .f_use_norm and not (self .use_huber or self .use_mae_loss ):
151+ if self .f_use_norm and not (self .use_huber or self .loss_func == "mae" ):
88152 raise RuntimeError (
89- "f_use_norm can only be True when use_huber or use_mae_loss is True ."
153+ "f_use_norm can only be True when use_huber or loss_func='mae' ."
90154 )
91155 if self .use_huber and (
92156 self .has_pf or self .has_gf or self .relative_f is not None
@@ -177,7 +241,7 @@ def call(
177241 loss = 0
178242 more_loss = {}
179243 if self .has_e :
180- if not self .use_mae_loss :
244+ if self .loss_func == "mse" :
181245 l2_ener_loss = xp .mean (xp .square (energy - energy_hat ))
182246 if not self .use_huber :
183247 loss += atom_norm_ener * (pref_e * l2_ener_loss )
@@ -191,14 +255,18 @@ def call(
191255 more_loss ["rmse_e" ] = self .display_if_exist (
192256 xp .sqrt (l2_ener_loss ) * atom_norm_ener , find_energy
193257 )
194- else :
258+ elif self . loss_func == "mae" :
195259 l1_ener_loss = xp .mean (xp .abs (energy - energy_hat ))
196260 loss += atom_norm_ener * (pref_e * l1_ener_loss )
197261 more_loss ["mae_e" ] = self .display_if_exist (
198262 l1_ener_loss * atom_norm_ener , find_energy
199263 )
264+ else :
265+ raise NotImplementedError (
266+ f"Loss type { self .loss_func } is not implemented for energy loss."
267+ )
200268 if self .has_f :
201- if not self .use_mae_loss :
269+ if self .loss_func == "mse" :
202270 l2_force_loss = xp .mean (xp .square (diff_f ))
203271 if not self .use_huber :
204272 loss += pref_f * l2_force_loss
@@ -223,18 +291,22 @@ def call(
223291 more_loss ["rmse_f" ] = self .display_if_exist (
224292 xp .sqrt (l2_force_loss ), find_force
225293 )
226- else :
294+ elif self . loss_func == "mae" :
227295 if not self .f_use_norm :
228296 l1_force_loss = xp .mean (xp .abs (diff_f ))
229297 else :
230298 force_diff_3 = xp .reshape (force_hat - force , (- 1 , 3 ))
231299 l1_force_loss = xp .mean (xp .linalg .vector_norm (force_diff_3 , axis = 1 ))
232300 loss += pref_f * l1_force_loss
233301 more_loss ["mae_f" ] = self .display_if_exist (l1_force_loss , find_force )
302+ else :
303+ raise NotImplementedError (
304+ f"Loss type { self .loss_func } is not implemented for force loss."
305+ )
234306 if self .has_v :
235307 virial_reshape = xp .reshape (virial , (- 1 ,))
236308 virial_hat_reshape = xp .reshape (virial_hat , (- 1 ,))
237- if not self .use_mae_loss :
309+ if self .loss_func == "mse" :
238310 l2_virial_loss = xp .mean (
239311 xp .square (virial_hat_reshape - virial_reshape ),
240312 )
@@ -250,39 +322,71 @@ def call(
250322 more_loss ["rmse_v" ] = self .display_if_exist (
251323 xp .sqrt (l2_virial_loss ) * atom_norm , find_virial
252324 )
253- else :
325+ elif self . loss_func == "mae" :
254326 l1_virial_loss = xp .mean (xp .abs (virial_hat_reshape - virial_reshape ))
255327 loss += atom_norm * (pref_v * l1_virial_loss )
256328 more_loss ["mae_v" ] = self .display_if_exist (
257329 l1_virial_loss * atom_norm , find_virial
258330 )
331+ else :
332+ raise NotImplementedError (
333+ f"Loss type { self .loss_func } is not implemented for virial loss."
334+ )
259335 if self .has_ae :
260336 atom_ener_reshape = xp .reshape (atom_ener , (- 1 ,))
261337 atom_ener_hat_reshape = xp .reshape (atom_ener_hat , (- 1 ,))
262- l2_atom_ener_loss = xp .mean (
263- xp .square (atom_ener_hat_reshape - atom_ener_reshape ),
264- )
265- if not self .use_huber :
266- loss += pref_ae * l2_atom_ener_loss
338+
339+ if self .loss_func == "mse" :
340+ l2_atom_ener_loss = xp .mean (
341+ xp .square (atom_ener_hat_reshape - atom_ener_reshape ),
342+ )
343+ if not self .use_huber :
344+ loss += pref_ae * l2_atom_ener_loss
345+ else :
346+ l_huber_loss = custom_huber_loss (
347+ atom_ener_reshape ,
348+ atom_ener_hat_reshape ,
349+ delta = self .huber_delta ,
350+ )
351+ loss += pref_ae * l_huber_loss
352+ more_loss ["rmse_ae" ] = self .display_if_exist (
353+ xp .sqrt (l2_atom_ener_loss ), find_atom_ener
354+ )
355+ elif self .loss_func == "mae" :
356+ l1_atom_ener_loss = xp .mean (
357+ xp .abs (atom_ener_hat_reshape - atom_ener_reshape )
358+ )
359+ loss += pref_ae * l1_atom_ener_loss
360+ more_loss ["mae_ae" ] = self .display_if_exist (
361+ l1_atom_ener_loss , find_atom_ener
362+ )
267363 else :
268- l_huber_loss = custom_huber_loss (
269- atom_ener_reshape ,
270- atom_ener_hat_reshape ,
271- delta = self .huber_delta ,
364+ raise NotImplementedError (
365+ f"Loss type { self .loss_func } is not implemented for atomic energy loss."
272366 )
273- loss += pref_ae * l_huber_loss
274- more_loss ["rmse_ae" ] = self .display_if_exist (
275- xp .sqrt (l2_atom_ener_loss ), find_atom_ener
276- )
277367 if self .has_pf :
278368 atom_pref_reshape = xp .reshape (atom_pref , (- 1 ,))
279- l2_pref_force_loss = xp .mean (
280- xp .multiply (xp .square (diff_f ), atom_pref_reshape ),
281- )
282- loss += pref_pf * l2_pref_force_loss
283- more_loss ["rmse_pf" ] = self .display_if_exist (
284- xp .sqrt (l2_pref_force_loss ), find_atom_pref
285- )
369+
370+ if self .loss_func == "mse" :
371+ l2_pref_force_loss = xp .mean (
372+ xp .multiply (xp .square (diff_f ), atom_pref_reshape ),
373+ )
374+ loss += pref_pf * l2_pref_force_loss
375+ more_loss ["rmse_pf" ] = self .display_if_exist (
376+ xp .sqrt (l2_pref_force_loss ), find_atom_pref
377+ )
378+ elif self .loss_func == "mae" :
379+ l1_pref_force_loss = xp .mean (
380+ xp .multiply (xp .abs (diff_f ), atom_pref_reshape )
381+ )
382+ loss += pref_pf * l1_pref_force_loss
383+ more_loss ["mae_pf" ] = self .display_if_exist (
384+ l1_pref_force_loss , find_atom_pref
385+ )
386+ else :
387+ raise NotImplementedError (
388+ f"Loss type { self .loss_func } is not implemented for atom prefactor force loss."
389+ )
286390 if self .has_gf :
287391 find_drdq = label_dict ["find_drdq" ]
288392 drdq = label_dict ["drdq" ]
@@ -413,7 +517,7 @@ def serialize(self) -> dict:
413517 "numb_generalized_coord" : self .numb_generalized_coord ,
414518 "use_huber" : self .use_huber ,
415519 "huber_delta" : self .huber_delta ,
416- "use_mae_loss " : self .use_mae_loss ,
520+ "loss_func " : self .loss_func ,
417521 "f_use_norm" : self .f_use_norm ,
418522 }
419523
0 commit comments