Skip to content

Commit dbc652e

Browse files
committed
feat(loss): support three-value huber delta
Allow huber_delta to accept either one shared float or a three-value list for energy, force and virial.
1 parent 3f52fa9 commit dbc652e

6 files changed

Lines changed: 116 additions & 34 deletions

File tree

deepmd/dpmodel/loss/ener.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
from deepmd.utils.data import (
1515
DataRequirementItem,
1616
)
17+
from deepmd.utils.loss import (
18+
resolve_huber_deltas,
19+
)
1720
from deepmd.utils.version import (
1821
check_version_compatibility,
1922
)
@@ -74,8 +77,10 @@ class EnergyLoss(Loss):
7477
- For absolute prediction errors within D: quadratic loss (0.5 * (error**2))
7578
- For absolute errors exceeding D: linear loss (D * (|error| - 0.5 * D))
7679
Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
77-
huber_delta : float
78-
The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss.
80+
huber_delta : float | list[float]
81+
The threshold delta (D) used for Huber loss, controlling transition between
82+
L2 and L1 loss. It can be either one float shared by all terms or a list of
83+
three values ordered as [energy, force, virial]; the energy value is also used for the atomic energy term.
7984
loss_func : str
8085
Loss function type for energy, force, and virial terms.
8186
Options: 'mse' (Mean Squared Error, L2 loss, default) or 'mae' (Mean Absolute Error, L1 loss).
@@ -84,6 +89,7 @@ class EnergyLoss(Loss):
8489
f_use_norm : bool
8590
If true, use L2 norm of force vectors for loss calculation when loss_func='mae' or use_huber is True.
8691
Instead of computing loss on force components, computes loss on ||F_pred - F_label||_2.
92+
This treats the force vector as a whole rather than three independent components.
8793
**kwargs
8894
Other keyword arguments.
8995
"""
@@ -107,7 +113,7 @@ def __init__(
107113
limit_pref_gf: float = 0.0,
108114
numb_generalized_coord: int = 0,
109115
use_huber: bool = False,
110-
huber_delta: float = 0.01,
116+
huber_delta: float | list[float] = 0.01,
111117
loss_func: str = "mse",
112118
f_use_norm: bool = False,
113119
**kwargs: Any,
@@ -153,6 +159,11 @@ def __init__(
153159
raise RuntimeError(
154160
"f_use_norm can only be True when use_huber or loss_func='mae'."
155161
)
162+
(
163+
self._huber_delta_energy,
164+
self._huber_delta_force,
165+
self._huber_delta_virial,
166+
) = resolve_huber_deltas(huber_delta)
156167
if self.use_huber and (
157168
self.has_pf or self.has_gf or self.relative_f is not None
158169
):
@@ -215,7 +226,10 @@ def call(
215226

216227
if self.relative_f is not None:
217228
force_hat_3 = xp.reshape(force_hat, (-1, 3))
218-
norm_f = xp.reshape(xp.norm(force_hat_3, axis=1), (-1, 1)) + self.relative_f
229+
norm_f = (
230+
xp.reshape(xp.linalg.vector_norm(force_hat_3, axis=1), (-1, 1))
231+
+ self.relative_f
232+
)
219233
diff_f_3 = xp.reshape(diff_f, (-1, 3))
220234
diff_f_3 = diff_f_3 / norm_f
221235
diff_f = xp.reshape(diff_f_3, (-1,))
@@ -250,7 +264,7 @@ def call(
250264
l_huber_loss = custom_huber_loss(
251265
atom_norm_ener * energy,
252266
atom_norm_ener * energy_hat,
253-
delta=self.huber_delta,
267+
delta=self._huber_delta_energy,
254268
)
255269
loss += pref_e * l_huber_loss
256270
more_loss["rmse_e"] = self.display_if_exist(
@@ -276,7 +290,7 @@ def call(
276290
l_huber_loss = custom_huber_loss(
277291
xp.reshape(force, (-1,)),
278292
xp.reshape(force_hat, (-1,)),
279-
delta=self.huber_delta,
293+
delta=self._huber_delta_force,
280294
)
281295
else:
282296
force_diff_3 = xp.reshape(force_hat - force, (-1, 3))
@@ -286,7 +300,7 @@ def call(
286300
l_huber_loss = custom_huber_loss(
287301
force_diff_norm,
288302
xp.zeros_like(force_diff_norm),
289-
delta=self.huber_delta,
303+
delta=self._huber_delta_force,
290304
)
291305
loss += pref_f * l_huber_loss
292306
more_loss["rmse_f"] = self.display_if_exist(
@@ -317,7 +331,7 @@ def call(
317331
l_huber_loss = custom_huber_loss(
318332
atom_norm * virial_reshape,
319333
atom_norm * virial_hat_reshape,
320-
delta=self.huber_delta,
334+
delta=self._huber_delta_virial,
321335
)
322336
loss += pref_v * l_huber_loss
323337
more_loss["rmse_v"] = self.display_if_exist(
@@ -336,7 +350,6 @@ def call(
336350
if self.has_ae:
337351
atom_ener_reshape = xp.reshape(atom_ener, (-1,))
338352
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, (-1,))
339-
340353
if self.loss_func == "mse":
341354
l2_atom_ener_loss = xp.mean(
342355
xp.square(atom_ener_hat_reshape - atom_ener_reshape),
@@ -347,7 +360,7 @@ def call(
347360
l_huber_loss = custom_huber_loss(
348361
atom_ener_reshape,
349362
atom_ener_hat_reshape,
350-
delta=self.huber_delta,
363+
delta=self._huber_delta_energy,
351364
)
352365
loss += pref_ae * l_huber_loss
353366
more_loss["rmse_ae"] = self.display_if_exist(

deepmd/pd/loss/ener.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from deepmd.utils.data import (
1919
DataRequirementItem,
2020
)
21+
from deepmd.utils.loss import (
22+
resolve_huber_deltas,
23+
)
2124
from deepmd.utils.version import (
2225
check_version_compatibility,
2326
)
@@ -56,7 +59,7 @@ def __init__(
5659
loss_func: str = "mse",
5760
inference: bool = False,
5861
use_huber: bool = False,
59-
huber_delta: float = 0.01,
62+
huber_delta: float | list[float] = 0.01,
6063
f_use_norm: bool = False,
6164
**kwargs: Any,
6265
) -> None:
@@ -109,8 +112,10 @@ def __init__(
109112
- For absolute prediction errors within D: quadratic loss (0.5 * (error**2))
110113
- For absolute errors exceeding D: linear loss (D * (|error| - 0.5 * D))
111114
Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
112-
huber_delta : float
113-
The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss.
115+
huber_delta : float | list[float]
116+
The threshold delta (D) used for Huber loss, controlling transition between
117+
L2 and L1 loss. It can be either one float shared by all terms or a list of
118+
three values ordered as [energy, force, virial]; the energy value is also used for the atomic energy term.
114119
f_use_norm : bool
115120
If True, use L2 norm of force vectors for loss calculation.
116121
Not implemented in PD backend, only for serialization compatibility.
@@ -156,6 +161,11 @@ def __init__(
156161
self.inference = inference
157162
self.use_huber = use_huber
158163
self.huber_delta = huber_delta
164+
(
165+
self._huber_delta_energy,
166+
self._huber_delta_force,
167+
self._huber_delta_virial,
168+
) = resolve_huber_deltas(huber_delta)
159169
if self.use_huber and (
160170
self.has_pf or self.has_gf or self.relative_f is not None
161171
):
@@ -238,7 +248,7 @@ def forward(
238248
l_huber_loss = custom_huber_loss(
239249
atom_norm * energy_pred,
240250
atom_norm * energy_label,
241-
delta=self.huber_delta,
251+
delta=self._huber_delta_energy,
242252
)
243253
loss += pref_e * l_huber_loss
244254
rmse_e = l2_ener_loss.sqrt() * atom_norm
@@ -305,7 +315,7 @@ def forward(
305315
l_huber_loss = custom_huber_loss(
306316
force_pred.reshape([-1]),
307317
force_label.reshape([-1]),
308-
delta=self.huber_delta,
318+
delta=self._huber_delta_force,
309319
)
310320
loss += pref_f * l_huber_loss
311321
rmse_f = l2_force_loss.sqrt()
@@ -409,7 +419,7 @@ def forward(
409419
l_huber_loss = custom_huber_loss(
410420
atom_norm * model_pred["virial"].reshape([-1]),
411421
atom_norm * label["virial"].reshape([-1]),
412-
delta=self.huber_delta,
422+
delta=self._huber_delta_virial,
413423
)
414424
loss += pref_v * l_huber_loss
415425
rmse_v = l2_virial_loss.sqrt() * atom_norm
@@ -440,7 +450,7 @@ def forward(
440450
l_huber_loss = custom_huber_loss(
441451
atom_ener_reshape,
442452
atom_ener_label_reshape,
443-
delta=self.huber_delta,
453+
delta=self._huber_delta_energy,
444454
)
445455
loss += pref_ae * l_huber_loss
446456
rmse_ae = l2_atom_ener_loss.sqrt()

deepmd/pt/loss/ener.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from deepmd.utils.data import (
1919
DataRequirementItem,
2020
)
21+
from deepmd.utils.loss import (
22+
resolve_huber_deltas,
23+
)
2124
from deepmd.utils.version import (
2225
check_version_compatibility,
2326
)
@@ -57,7 +60,7 @@ def __init__(
5760
inference: bool = False,
5861
use_huber: bool = False,
5962
f_use_norm: bool = False,
60-
huber_delta: float = 0.01,
63+
huber_delta: float | list[float] = 0.01,
6164
**kwargs: Any,
6265
) -> None:
6366
r"""Construct a layer to compute loss on energy, force and virial.
@@ -112,8 +115,11 @@ def __init__(
112115
f_use_norm : bool
113116
If true, use L2 norm of force vectors for loss calculation when loss_func='mae' or use_huber is True.
114117
Instead of computing loss on force components, computes loss on ||F_pred - F_label||_2.
115-
huber_delta : float
116-
The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss.
118+
This treats the force vector as a whole rather than three independent components.
119+
huber_delta : float | list[float]
120+
The threshold delta (D) used for Huber loss, controlling transition between
121+
L2 and L1 loss. It can be either one float shared by all terms or a list of
122+
three values ordered as [energy, force, virial]; the energy value is also used for the atomic energy term.
117123
**kwargs
118124
Other keyword arguments.
119125
"""
@@ -162,6 +168,11 @@ def __init__(
162168
"f_use_norm can only be True when use_huber or loss_func='mae'."
163169
)
164170
self.huber_delta = huber_delta
171+
(
172+
self._huber_delta_energy,
173+
self._huber_delta_force,
174+
self._huber_delta_virial,
175+
) = resolve_huber_deltas(huber_delta)
165176
if self.use_huber and (
166177
self.has_pf or self.has_gf or self.relative_f is not None
167178
):
@@ -244,7 +255,7 @@ def forward(
244255
l_huber_loss = custom_huber_loss(
245256
atom_norm * energy_pred,
246257
atom_norm * energy_label,
247-
delta=self.huber_delta,
258+
delta=self._huber_delta_energy,
248259
)
249260
loss += pref_e * l_huber_loss
250261
rmse_e = l2_ener_loss.sqrt() * atom_norm
@@ -308,7 +319,7 @@ def forward(
308319
l_huber_loss = custom_huber_loss(
309320
force_pred.reshape(-1),
310321
force_label.reshape(-1),
311-
delta=self.huber_delta,
322+
delta=self._huber_delta_force,
312323
)
313324
else:
314325
force_diff_norm = torch.linalg.vector_norm(
@@ -320,7 +331,7 @@ def forward(
320331
l_huber_loss = custom_huber_loss(
321332
force_diff_norm,
322333
torch.zeros_like(force_diff_norm),
323-
delta=self.huber_delta,
334+
delta=self._huber_delta_force,
324335
)
325336
loss += pref_f * l_huber_loss
326337
rmse_f = l2_force_loss.sqrt()
@@ -426,7 +437,7 @@ def forward(
426437
l_huber_loss = custom_huber_loss(
427438
atom_norm * model_pred["virial"].reshape(-1),
428439
atom_norm * label["virial"].reshape(-1),
429-
delta=self.huber_delta,
440+
delta=self._huber_delta_virial,
430441
)
431442
loss += pref_v * l_huber_loss
432443
rmse_v = l2_virial_loss.sqrt() * atom_norm
@@ -474,7 +485,7 @@ def forward(
474485
l_huber_loss = custom_huber_loss(
475486
atom_ener_reshape,
476487
atom_ener_label_reshape,
477-
delta=self.huber_delta,
488+
delta=self._huber_delta_energy,
478489
)
479490
loss += pref_ae * l_huber_loss
480491
rmse_ae = l2_atom_ener_loss.sqrt()

deepmd/tf/loss/ener.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from deepmd.utils.data import (
1818
DataRequirementItem,
1919
)
20+
from deepmd.utils.loss import (
21+
resolve_huber_deltas,
22+
)
2023
from deepmd.utils.version import (
2124
check_version_compatibility,
2225
)
@@ -87,8 +90,10 @@ class EnerStdLoss(Loss):
8790
- For absolute prediction errors within D: quadratic loss (0.5 * (error**2))
8891
- For absolute errors exceeding D: linear loss (D * (|error| - 0.5 * D))
8992
Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
90-
huber_delta : float
91-
The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss.
93+
huber_delta : float | list[float]
94+
The threshold delta (D) used for Huber loss, controlling transition between
95+
L2 and L1 loss. It can be either one float shared by all terms or a list of
96+
three values ordered as [energy, force, virial]; the energy value is also used for the atomic energy term.
9297
loss_func : str
9398
Loss function type. Options: 'mse' or 'mae'.
9499
Not implemented in TF backend, only for serialization compatibility.
@@ -118,7 +123,7 @@ def __init__(
118123
limit_pref_gf: float = 0.0,
119124
numb_generalized_coord: int = 0,
120125
use_huber: bool = False,
121-
huber_delta: float = 0.01,
126+
huber_delta: float | list[float] = 0.01,
122127
loss_func: str = "mse",
123128
f_use_norm: bool = False,
124129
**kwargs: Any,
@@ -162,6 +167,11 @@ def __init__(
162167
)
163168
self.use_huber = use_huber
164169
self.huber_delta = huber_delta
170+
(
171+
self._huber_delta_energy,
172+
self._huber_delta_force,
173+
self._huber_delta_virial,
174+
) = resolve_huber_deltas(huber_delta)
165175
if self.use_huber and (
166176
self.has_pf or self.has_gf or self.relative_f is not None
167177
):
@@ -351,7 +361,7 @@ def build(
351361
l_huber_loss = custom_huber_loss(
352362
atom_norm_ener * energy,
353363
atom_norm_ener * energy_hat,
354-
delta=self.huber_delta,
364+
delta=self._huber_delta_energy,
355365
)
356366
loss += pref_e * l_huber_loss
357367
more_loss["l2_ener_loss"] = self.display_if_exist(l2_ener_loss, find_energy)
@@ -362,7 +372,7 @@ def build(
362372
l_huber_loss = custom_huber_loss(
363373
tf.reshape(force, [-1]),
364374
tf.reshape(force_hat, [-1]),
365-
delta=self.huber_delta,
375+
delta=self._huber_delta_force,
366376
)
367377
loss += pref_f * l_huber_loss
368378
more_loss["l2_force_loss"] = self.display_if_exist(
@@ -375,7 +385,7 @@ def build(
375385
l_huber_loss = custom_huber_loss(
376386
atom_norm * tf.reshape(virial, [-1]),
377387
atom_norm * tf.reshape(virial_hat, [-1]),
378-
delta=self.huber_delta,
388+
delta=self._huber_delta_virial,
379389
)
380390
loss += pref_v * l_huber_loss
381391
more_loss["l2_virial_loss"] = self.display_if_exist(
@@ -388,7 +398,7 @@ def build(
388398
l_huber_loss = custom_huber_loss(
389399
tf.reshape(atom_ener, [-1]),
390400
tf.reshape(atom_ener_hat, [-1]),
391-
delta=self.huber_delta,
401+
delta=self._huber_delta_energy,
392402
)
393403
loss += pref_ae * l_huber_loss
394404
more_loss["l2_atom_ener_loss"] = self.display_if_exist(

deepmd/utils/argcheck.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3108,7 +3108,11 @@ def loss_ener() -> list[Argument]:
31083108
"- For absolute errors exceeding D: linear loss D * (\\|error\\| - 0.5 * D) \n\n"
31093109
"Formula: loss = 0.5 * (error**2) if \\|error\\| <= D else D * (\\|error\\| - 0.5 * D). "
31103110
)
3111-
doc_huber_delta = "The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss. "
3111+
doc_huber_delta = (
3112+
"The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss. "
3113+
"It can be either one float shared by all terms or a list of "
3114+
"three values ordered as [energy, force, virial]. "
3115+
)
31123116
doc_loss_func = (
31133117
"Loss function type for energy, force, and virial terms. "
31143118
"Options: 'mse' (Mean Squared Error, L2 loss, default) or 'mae' (Mean Absolute Error, L1 loss). "
@@ -3258,7 +3262,7 @@ def loss_ener() -> list[Argument]:
32583262
),
32593263
Argument(
32603264
"huber_delta",
3261-
float,
3265+
[float, list[float]],
32623266
optional=True,
32633267
default=0.01,
32643268
doc=doc_huber_delta,

0 commit comments

Comments
 (0)