1414from deepmd .utils .data import (
1515 DataRequirementItem ,
1616)
17+ from deepmd .utils .loss import (
18+ resolve_huber_deltas ,
19+ )
1720from deepmd .utils .version import (
1821 check_version_compatibility ,
1922)
@@ -74,8 +77,10 @@ class EnergyLoss(Loss):
7477 - For absolute prediction errors within D: quadratic loss (0.5 * (error**2))
7578 - For absolute errors exceeding D: linear loss (D * |error| - 0.5 * D)
7679 Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
77- huber_delta : float
78- The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss.
80+ huber_delta : float | list[float]
81+ The threshold delta (D) used for Huber loss, controlling transition between
82+ L2 and L1 loss. It can be either one float shared by all terms or a list of
83+ three values ordered as [energy, force, virial].
7984 loss_func : str
8085 Loss function type for energy, force, and virial terms.
8186 Options: 'mse' (Mean Squared Error, L2 loss, default) or 'mae' (Mean Absolute Error, L1 loss).
@@ -84,6 +89,7 @@ class EnergyLoss(Loss):
8489 f_use_norm : bool
8590 If true, use L2 norm of force vectors for loss calculation when loss_func='mae' or use_huber is True.
8691 Instead of computing loss on force components, computes loss on ||F_pred - F_label||_2.
92+ This treats the force vector as a whole rather than three independent components.
8793 **kwargs
8894 Other keyword arguments.
8995 """
@@ -107,7 +113,7 @@ def __init__(
107113 limit_pref_gf : float = 0.0 ,
108114 numb_generalized_coord : int = 0 ,
109115 use_huber : bool = False ,
110- huber_delta : float = 0.01 ,
116+ huber_delta : float | list [ float ] = 0.01 ,
111117 loss_func : str = "mse" ,
112118 f_use_norm : bool = False ,
113119 ** kwargs : Any ,
@@ -153,6 +159,11 @@ def __init__(
153159 raise RuntimeError (
154160 "f_use_norm can only be True when use_huber or loss_func='mae'."
155161 )
162+ (
163+ self ._huber_delta_energy ,
164+ self ._huber_delta_force ,
165+ self ._huber_delta_virial ,
166+ ) = resolve_huber_deltas (huber_delta )
156167 if self .use_huber and (
157168 self .has_pf or self .has_gf or self .relative_f is not None
158169 ):
@@ -215,7 +226,10 @@ def call(
215226
216227 if self .relative_f is not None :
217228 force_hat_3 = xp .reshape (force_hat , (- 1 , 3 ))
218- norm_f = xp .reshape (xp .norm (force_hat_3 , axis = 1 ), (- 1 , 1 )) + self .relative_f
229+ norm_f = (
230+ xp .reshape (xp .linalg .vector_norm (force_hat_3 , axis = 1 ), (- 1 , 1 ))
231+ + self .relative_f
232+ )
219233 diff_f_3 = xp .reshape (diff_f , (- 1 , 3 ))
220234 diff_f_3 = diff_f_3 / norm_f
221235 diff_f = xp .reshape (diff_f_3 , (- 1 ,))
@@ -250,7 +264,7 @@ def call(
250264 l_huber_loss = custom_huber_loss (
251265 atom_norm_ener * energy ,
252266 atom_norm_ener * energy_hat ,
253- delta = self .huber_delta ,
267+ delta = self ._huber_delta_energy ,
254268 )
255269 loss += pref_e * l_huber_loss
256270 more_loss ["rmse_e" ] = self .display_if_exist (
@@ -276,7 +290,7 @@ def call(
276290 l_huber_loss = custom_huber_loss (
277291 xp .reshape (force , (- 1 ,)),
278292 xp .reshape (force_hat , (- 1 ,)),
279- delta = self .huber_delta ,
293+ delta = self ._huber_delta_force ,
280294 )
281295 else :
282296 force_diff_3 = xp .reshape (force_hat - force , (- 1 , 3 ))
@@ -286,7 +300,7 @@ def call(
286300 l_huber_loss = custom_huber_loss (
287301 force_diff_norm ,
288302 xp .zeros_like (force_diff_norm ),
289- delta = self .huber_delta ,
303+ delta = self ._huber_delta_force ,
290304 )
291305 loss += pref_f * l_huber_loss
292306 more_loss ["rmse_f" ] = self .display_if_exist (
@@ -317,7 +331,7 @@ def call(
317331 l_huber_loss = custom_huber_loss (
318332 atom_norm * virial_reshape ,
319333 atom_norm * virial_hat_reshape ,
320- delta = self .huber_delta ,
334+ delta = self ._huber_delta_virial ,
321335 )
322336 loss += pref_v * l_huber_loss
323337 more_loss ["rmse_v" ] = self .display_if_exist (
@@ -336,7 +350,6 @@ def call(
336350 if self .has_ae :
337351 atom_ener_reshape = xp .reshape (atom_ener , (- 1 ,))
338352 atom_ener_hat_reshape = xp .reshape (atom_ener_hat , (- 1 ,))
339-
340353 if self .loss_func == "mse" :
341354 l2_atom_ener_loss = xp .mean (
342355 xp .square (atom_ener_hat_reshape - atom_ener_reshape ),
@@ -347,7 +360,7 @@ def call(
347360 l_huber_loss = custom_huber_loss (
348361 atom_ener_reshape ,
349362 atom_ener_hat_reshape ,
350- delta = self .huber_delta ,
363+ delta = self ._huber_delta_energy ,
351364 )
352365 loss += pref_ae * l_huber_loss
353366 more_loss ["rmse_ae" ] = self .display_if_exist (
0 commit comments