@@ -30,6 +30,64 @@ def custom_huber_loss(predictions: Array, targets: Array, delta: float = 1.0) ->
3030
3131
3232class EnergyLoss (Loss ):
33+ r"""Construct a layer to compute loss on energy, force and virial.
34+
35+ Parameters
36+ ----------
37+ starter_learning_rate : float
38+ The learning rate at the start of the training.
39+ start_pref_e : float
40+ The prefactor of energy loss at the start of the training.
41+ limit_pref_e : float
42+ The prefactor of energy loss at the end of the training.
43+ start_pref_f : float
44+ The prefactor of force loss at the start of the training.
45+ limit_pref_f : float
46+ The prefactor of force loss at the end of the training.
47+ start_pref_v : float
48+ The prefactor of virial loss at the start of the training.
49+ limit_pref_v : float
50+ The prefactor of virial loss at the end of the training.
51+ start_pref_ae : float
52+ The prefactor of atomic energy loss at the start of the training.
53+ limit_pref_ae : float
54+ The prefactor of atomic energy loss at the end of the training.
55+ start_pref_pf : float
56+ The prefactor of atomic prefactor force loss at the start of the training.
57+ limit_pref_pf : float
58+ The prefactor of atomic prefactor force loss at the end of the training.
59+ relative_f : float
60+ If provided, relative force error will be used in the loss. The difference
61+ of force will be normalized by the magnitude of the force in the label with
62+ a shift given by relative_f
63+ enable_atom_ener_coeff : bool
64+ if true, the energy will be computed as \sum_i c_i E_i
65+ start_pref_gf : float
66+ The prefactor of generalized force loss at the start of the training.
67+ limit_pref_gf : float
68+ The prefactor of generalized force loss at the end of the training.
69+ numb_generalized_coord : int
70+ The dimension of generalized coordinates.
71+ use_huber : bool
72+ Enables Huber loss calculation for energy/force/virial terms with user-defined threshold delta (D).
73+ The loss function smoothly transitions between L2 and L1 loss:
74+ - For absolute prediction errors within D: quadratic loss (0.5 * (error**2))
75+ - For absolute errors exceeding D: linear loss (D * |error| - 0.5 * D)
76+ Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
77+ huber_delta : float
78+ The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss.
79+ loss_func : str
80+ Loss function type for energy, force, and virial terms.
81+ Options: 'mse' (Mean Squared Error, L2 loss, default) or 'mae' (Mean Absolute Error, L1 loss).
82+ MAE loss is less sensitive to outliers compared to MSE loss.
83+ Future extensions may support additional loss types.
84+ f_use_norm : bool
85+ If true, use L2 norm of force vectors for loss calculation when loss_func='mae' or use_huber is True.
86+ Instead of computing loss on force components, computes loss on ||F_pred - F_label||_2.
87+ **kwargs
88+ Other keyword arguments.
89+ """
90+
3391 def __init__ (
3492 self ,
3593 starter_learning_rate : float ,
@@ -54,63 +112,6 @@ def __init__(
54112 f_use_norm : bool = False ,
55113 ** kwargs : Any ,
56114 ) -> None :
57- r"""Construct a layer to compute loss on energy, force and virial.
58-
59- Parameters
60- ----------
61- starter_learning_rate : float
62- The learning rate at the start of the training.
63- start_pref_e : float
64- The prefactor of energy loss at the start of the training.
65- limit_pref_e : float
66- The prefactor of energy loss at the end of the training.
67- start_pref_f : float
68- The prefactor of force loss at the start of the training.
69- limit_pref_f : float
70- The prefactor of force loss at the end of the training.
71- start_pref_v : float
72- The prefactor of virial loss at the start of the training.
73- limit_pref_v : float
74- The prefactor of virial loss at the end of the training.
75- start_pref_ae : float
76- The prefactor of atomic energy loss at the start of the training.
77- limit_pref_ae : float
78- The prefactor of atomic energy loss at the end of the training.
79- start_pref_pf : float
80- The prefactor of atomic prefactor force loss at the start of the training.
81- limit_pref_pf : float
82- The prefactor of atomic prefactor force loss at the end of the training.
83- relative_f : float
84- If provided, relative force error will be used in the loss. The difference
85- of force will be normalized by the magnitude of the force in the label with
86- a shift given by relative_f
87- enable_atom_ener_coeff : bool
88- if true, the energy will be computed as \sum_i c_i E_i
89- start_pref_gf : float
90- The prefactor of generalized force loss at the start of the training.
91- limit_pref_gf : float
92- The prefactor of generalized force loss at the end of the training.
93- numb_generalized_coord : int
94- The dimension of generalized coordinates.
95- use_huber : bool
96- Enables Huber loss calculation for energy/force/virial terms with user-defined threshold delta (D).
97- The loss function smoothly transitions between L2 and L1 loss:
98- - For absolute prediction errors within D: quadratic loss (0.5 * (error**2))
99- - For absolute errors exceeding D: linear loss (D * |error| - 0.5 * D)
100- Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
101- huber_delta : float
102- The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss.
103- loss_func : str
104- Loss function type for energy, force, and virial terms.
105- Options: 'mse' (Mean Squared Error, L2 loss, default) or 'mae' (Mean Absolute Error, L1 loss).
106- MAE loss is less sensitive to outliers compared to MSE loss.
107- Future extensions may support additional loss types.
108- f_use_norm : bool
109- If true, use L2 norm of force vectors for loss calculation when loss_func='mae' or use_huber is True.
110- Instead of computing loss on force components, computes loss on ||F_pred - F_label||_2.
111- **kwargs
112- Other keyword arguments.
113- """
114115 # Validate loss_func
115116 valid_loss_funcs = ["mse" , "mae" ]
116117 if loss_func not in valid_loss_funcs :
0 commit comments