Skip to content

Commit cea6899

Browse files
committed
Update ener.py
1 parent ab95efe commit cea6899

1 file changed

Lines changed: 58 additions & 57 deletions

File tree

deepmd/dpmodel/loss/ener.py

Lines changed: 58 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,64 @@ def custom_huber_loss(predictions: Array, targets: Array, delta: float = 1.0) ->
3030

3131

3232
class EnergyLoss(Loss):
33+
r"""Construct a layer to compute loss on energy, force and virial.
34+
35+
Parameters
36+
----------
37+
starter_learning_rate : float
38+
The learning rate at the start of the training.
39+
start_pref_e : float
40+
The prefactor of energy loss at the start of the training.
41+
limit_pref_e : float
42+
The prefactor of energy loss at the end of the training.
43+
start_pref_f : float
44+
The prefactor of force loss at the start of the training.
45+
limit_pref_f : float
46+
The prefactor of force loss at the end of the training.
47+
start_pref_v : float
48+
The prefactor of virial loss at the start of the training.
49+
limit_pref_v : float
50+
The prefactor of virial loss at the end of the training.
51+
start_pref_ae : float
52+
The prefactor of atomic energy loss at the start of the training.
53+
limit_pref_ae : float
54+
The prefactor of atomic energy loss at the end of the training.
55+
start_pref_pf : float
56+
The prefactor of atomic prefactor force loss at the start of the training.
57+
limit_pref_pf : float
58+
The prefactor of atomic prefactor force loss at the end of the training.
59+
relative_f : float
60+
If provided, relative force error will be used in the loss. The difference
61+
of force will be normalized by the magnitude of the force in the label with
62+
a shift given by relative_f
63+
enable_atom_ener_coeff : bool
64+
if true, the energy will be computed as \sum_i c_i E_i
65+
start_pref_gf : float
66+
The prefactor of generalized force loss at the start of the training.
67+
limit_pref_gf : float
68+
The prefactor of generalized force loss at the end of the training.
69+
numb_generalized_coord : int
70+
The dimension of generalized coordinates.
71+
use_huber : bool
72+
Enables Huber loss calculation for energy/force/virial terms with user-defined threshold delta (D).
73+
The loss function smoothly transitions between L2 and L1 loss:
74+
- For absolute prediction errors within D: quadratic loss (0.5 * (error**2))
75+
- For absolute errors exceeding D: linear loss (D * |error| - 0.5 * D)
76+
Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
77+
huber_delta : float
78+
The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss.
79+
loss_func : str
80+
Loss function type for energy, force, and virial terms.
81+
Options: 'mse' (Mean Squared Error, L2 loss, default) or 'mae' (Mean Absolute Error, L1 loss).
82+
MAE loss is less sensitive to outliers compared to MSE loss.
83+
Future extensions may support additional loss types.
84+
f_use_norm : bool
85+
If true, use L2 norm of force vectors for loss calculation when loss_func='mae' or use_huber is True.
86+
Instead of computing loss on force components, computes loss on ||F_pred - F_label||_2.
87+
**kwargs
88+
Other keyword arguments.
89+
"""
90+
3391
def __init__(
3492
self,
3593
starter_learning_rate: float,
@@ -54,63 +112,6 @@ def __init__(
54112
f_use_norm: bool = False,
55113
**kwargs: Any,
56114
) -> None:
57-
r"""Construct a layer to compute loss on energy, force and virial.
58-
59-
Parameters
60-
----------
61-
starter_learning_rate : float
62-
The learning rate at the start of the training.
63-
start_pref_e : float
64-
The prefactor of energy loss at the start of the training.
65-
limit_pref_e : float
66-
The prefactor of energy loss at the end of the training.
67-
start_pref_f : float
68-
The prefactor of force loss at the start of the training.
69-
limit_pref_f : float
70-
The prefactor of force loss at the end of the training.
71-
start_pref_v : float
72-
The prefactor of virial loss at the start of the training.
73-
limit_pref_v : float
74-
The prefactor of virial loss at the end of the training.
75-
start_pref_ae : float
76-
The prefactor of atomic energy loss at the start of the training.
77-
limit_pref_ae : float
78-
The prefactor of atomic energy loss at the end of the training.
79-
start_pref_pf : float
80-
The prefactor of atomic prefactor force loss at the start of the training.
81-
limit_pref_pf : float
82-
The prefactor of atomic prefactor force loss at the end of the training.
83-
relative_f : float
84-
If provided, relative force error will be used in the loss. The difference
85-
of force will be normalized by the magnitude of the force in the label with
86-
a shift given by relative_f
87-
enable_atom_ener_coeff : bool
88-
if true, the energy will be computed as \sum_i c_i E_i
89-
start_pref_gf : float
90-
The prefactor of generalized force loss at the start of the training.
91-
limit_pref_gf : float
92-
The prefactor of generalized force loss at the end of the training.
93-
numb_generalized_coord : int
94-
The dimension of generalized coordinates.
95-
use_huber : bool
96-
Enables Huber loss calculation for energy/force/virial terms with user-defined threshold delta (D).
97-
The loss function smoothly transitions between L2 and L1 loss:
98-
- For absolute prediction errors within D: quadratic loss (0.5 * (error**2))
99-
- For absolute errors exceeding D: linear loss (D * |error| - 0.5 * D)
100-
Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
101-
huber_delta : float
102-
The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss.
103-
loss_func : str
104-
Loss function type for energy, force, and virial terms.
105-
Options: 'mse' (Mean Squared Error, L2 loss, default) or 'mae' (Mean Absolute Error, L1 loss).
106-
MAE loss is less sensitive to outliers compared to MSE loss.
107-
Future extensions may support additional loss types.
108-
f_use_norm : bool
109-
If true, use L2 norm of force vectors for loss calculation when loss_func='mae' or use_huber is True.
110-
Instead of computing loss on force components, computes loss on ||F_pred - F_label||_2.
111-
**kwargs
112-
Other keyword arguments.
113-
"""
114115
# Validate loss_func
115116
valid_loss_funcs = ["mse", "mae"]
116117
if loss_func not in valid_loss_funcs:

0 commit comments

Comments
 (0)