Skip to content

Commit dd0a752

Browse files
committed
add separate delta for f/e; make delta2 root
1 parent a2833d0 commit dd0a752

2 files changed

Lines changed: 117 additions & 48 deletions

File tree

deepmd/pt/loss/ener.py

Lines changed: 109 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import (
44
Any,
55
Optional,
6+
Union,
67
)
78

89
import torch
def custom_huber_loss(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    delta: float = 1.0,
    delta2: Optional[float] = 1.0,
    k1: float = 0.0,
    k2: float = 0.0,
    use_root: bool = False,
) -> torch.Tensor:
    """Two-stage Huber-style loss with an optional heavy-tailed second stage.

    Stage 1 is the standard Huber loss with threshold ``delta``: quadratic
    for ``|error| <= delta``, linear beyond it. When ``delta2`` is given and
    strictly larger than ``delta``, residuals with ``|error| > delta2`` are
    scored by a slower-growing branch that further down-weights outliers:

    - Cauchy branch (default): ``0.5 * k1 * log1p(error**2 / k1) + k2``
    - Root branch (``use_root=True``): ``k1 * sqrt(|error|) + k2``

    ``k1`` and ``k2`` are expected to be precomputed (see
    ``cauchy_params_from_huber`` / ``root_params_from_huber``) so that the
    loss stays C0/C1 continuous at ``delta2``.

    Parameters
    ----------
    predictions : torch.Tensor
        Model predictions.
    targets : torch.Tensor
        Reference values, broadcastable against ``predictions``.
    delta : float
        Huber threshold separating the quadratic and linear regimes.
    delta2 : float or None
        Threshold activating the second stage; only takes effect when
        ``delta2 > delta``. ``None`` disables the second stage.
    k1 : float
        Scale of the second-stage branch (Cauchy ``c**2``, or root slope).
        Must be positive when the Cauchy branch is active.
    k2 : float
        Offset of the second-stage branch (value continuity at ``delta2``).
    use_root : bool
        Select the sqrt branch instead of the Cauchy branch.

    Returns
    -------
    torch.Tensor
        Scalar mean loss over all elements.
    """
    error = targets - predictions
    abs_error = torch.abs(error)
    quadratic_loss = 0.5 * torch.pow(error, 2)
    linear_loss = delta * (abs_error - 0.5 * delta)
    loss = torch.where(abs_error <= delta, quadratic_loss, linear_loss)
    if delta2 is not None and delta2 > delta:
        if not use_root:
            # Cauchy (Lorentzian) tail; well-defined everywhere for k1 > 0.
            stage2_loss = 0.5 * k1 * torch.log1p(abs_error**2 / k1) + k2
        else:
            # Clamp below delta2 before sqrt: both branches of torch.where
            # are differentiated, and d/dx sqrt(x) -> inf at 0 would leak
            # NaN gradients (0 * inf) into zero-error elements. The clamp
            # leaves the selected region (abs_error > delta2) unchanged.
            stage2_loss = k1 * torch.sqrt(torch.clamp(abs_error, min=delta2)) + k2
        loss = torch.where(abs_error > delta2, stage2_loss, loss)
    return torch.mean(loss)
4550

4651

52+
def cauchy_params_from_huber(
    delta1: float,
    delta2: float,
) -> tuple[float, float]:
    """Derive Cauchy-tail parameters that match a Huber loss at ``delta2``.

    Returns ``(c_sq, k)`` such that the Cauchy branch
    ``0.5 * c_sq * log(1 + r**2 / c_sq) + k`` joins the linear Huber branch
    of slope ``delta1`` continuously in both slope (C1) and value (C0) at
    ``r = delta2``.

    Parameters
    ----------
    delta1 : float
        Huber threshold (slope of the linear regime).
    delta2 : float
        Switch point to the Cauchy tail; must satisfy ``delta2 > delta1``
        (``delta2 == delta1`` divides by zero — there is no finite Cauchy
        scale with slope ``delta1`` at its own threshold).

    Returns
    -------
    tuple[float, float]
        ``(c_sq, k)``: Cauchy scale ``c**2`` and additive offset.
    """
    assert delta2 >= delta1
    # C1 continuity (gradient match): the linear Huber slope is delta1 and
    # the Cauchy slope is r / (1 + r**2 / c_sq). Matching at r = delta2:
    #   delta1 = delta2 / (1 + delta2**2 / c_sq)  =>  solve for c_sq.
    c_sq = delta2**2 / (delta2 / delta1 - 1.0)
    # Value of the linear Huber branch at delta2.
    l1_at_delta2 = delta1 * delta2 - 0.5 * delta1**2
    # Value of the unshifted Cauchy branch at delta2 (log1p for accuracy
    # when delta2**2 / c_sq is small).
    cauchy_at_delta2 = 0.5 * c_sq * math.log1p(delta2**2 / c_sq)
    # C0 continuity (value match): shift the Cauchy branch by k so that
    # L(r) = 0.5*c_sq*log1p(r**2/c_sq) + k agrees with the Huber value.
    k = l1_at_delta2 - cauchy_at_delta2
    return c_sq, k
71+
72+
73+
def root_params_from_huber(
    delta1: float,
    delta2: float,
) -> tuple[float, float]:
    """Derive sqrt-tail parameters that match a Huber loss at ``delta2``.

    Returns ``(K1, K2)`` such that the branch ``K1 * sqrt(r) + K2`` joins
    the linear Huber branch of slope ``delta1`` continuously in both slope
    (C1) and value (C0) at ``r = delta2``.

    Parameters
    ----------
    delta1 : float
        Huber threshold (slope of the linear regime).
    delta2 : float
        Switch point to the sqrt tail; must satisfy ``delta2 >= delta1``.

    Returns
    -------
    tuple[float, float]
        ``(K1, K2)``: multiplicative scale and additive offset of the tail.
    """
    assert delta2 >= delta1
    root_delta2 = math.sqrt(delta2)
    # C1 continuity (gradient match): d/dr [K1 * sqrt(r)] = 0.5 * K1 / sqrt(r)
    # must equal the linear Huber slope delta1 at r = delta2, so
    # K1 = 2 * delta1 * sqrt(delta2).
    K1 = 2.0 * delta1 * root_delta2
    # C0 continuity (value match): K1 * sqrt(delta2) + K2 must equal the
    # linear Huber value delta1 * delta2 - 0.5 * delta1**2 at r = delta2;
    # this simplifies to K2 = -delta1 * delta2 - 0.5 * delta1**2.
    K2 = (delta1 * delta2 - 0.5 * delta1**2) - K1 * root_delta2
    return K1, K2
92+
93+
4794
class EnergyStdLoss(TaskLoss):
4895
def __init__(
4996
self,
@@ -66,11 +113,12 @@ def __init__(
66113
use_l1_all: bool = False,
67114
inference: bool = False,
68115
use_huber: bool = False,
69-
huber_delta: float = 0.01,
70-
huber_two_stage_delta: Optional[float] = None,
116+
huber_delta: Union[float, list[float]] = 0.01,
117+
huber_two_stage_delta: Optional[Union[float, list[float]]] = None,
71118
trimmed_factor: float = 0.0,
72119
adaptive_loss: bool = False,
73120
learnable_pref: bool = False,
121+
huber_two_stage_use_root: bool = False,
74122
**kwargs: Any,
75123
) -> None:
76124
r"""Construct a layer to compute loss on energy, force and virial.
@@ -157,31 +205,33 @@ def __init__(
157205
self.use_l1_all = use_l1_all
158206
self.inference = inference
159207
self.use_huber = use_huber
160-
self.huber_delta = huber_delta
161-
self.huber_two_stage_delta = huber_two_stage_delta
208+
self.huber_delta = (
209+
[huber_delta] if isinstance(huber_delta, float) else huber_delta
210+
)
211+
self.huber_two_stage_delta = (
212+
[huber_two_stage_delta]
213+
if isinstance(huber_two_stage_delta, float)
214+
else huber_two_stage_delta
215+
)
216+
self.huber_two_stage_use_root = huber_two_stage_use_root
162217
if self.use_huber and self.huber_two_stage_delta is not None:
163-
assert huber_two_stage_delta >= huber_delta
164-
# --- 保证 L1 -> Cauchy 连续性所需的参数计算 ---
165-
# 1. 根据 C1 连续性 (梯度匹配) 求解 Cauchy 尺度 c^2
166-
# 目标: dL/dr (L1) = delta1; dL/dr (Cauchy) = r / (1 + r^2/c^2)
167-
# 匹配点 r=delta2: delta1 = delta2 / (1 + delta2^2/c^2)
168-
c_sq = huber_two_stage_delta**2 / (
169-
huber_two_stage_delta / huber_delta - 1.0
170-
)
171-
self.c_sq = c_sq
172-
# 2. 计算 L2/L1 在 delta2 处的值 (C_match)
173-
# C_match = delta1 * delta2 - 0.5 * delta1^2
174-
L_L1_at_delta2 = huber_delta * huber_two_stage_delta - 0.5 * huber_delta**2
175-
# 3. 计算 Cauchy 在 delta2 处的值 (L_Cauchy)
176-
L_Cauchy_at_delta2 = (
177-
0.5 * self.c_sq * math.log(1.0 + huber_two_stage_delta**2 / self.c_sq)
218+
assert len(self.huber_two_stage_delta) == len(self.huber_delta), (
219+
"When using two-stage Huber loss, the length of huber_two_stage_delta must match that of huber_delta."
178220
)
179-
# 4. 计算偏移量 K (K = C_match - L_Cauchy)
180-
# 保证 L3(r) = L_Cauchy(r) + K 在 delta2 处值连续
181-
self.k = L_L1_at_delta2 - L_Cauchy_at_delta2
221+
self.k1 = []
222+
self.k2 = []
223+
for i, delta in enumerate(self.huber_delta):
224+
two_stage_delta = self.huber_two_stage_delta[i]
225+
k1, k2 = (
226+
cauchy_params_from_huber(delta, two_stage_delta)
227+
if not self.huber_two_stage_use_root
228+
else root_params_from_huber(delta, two_stage_delta)
229+
)
230+
self.k1.append(k1)
231+
self.k2.append(k2)
182232
else:
183-
self.c_sq = 0.0
184-
self.k = 0.0
233+
self.k1 = [0.0]
234+
self.k2 = [0.0]
185235
if self.use_huber and (
186236
self.has_pf or self.has_gf or self.relative_f is not None
187237
):
@@ -259,6 +309,7 @@ def forward(
259309

260310
loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
261311
more_loss = {}
312+
huber_index = 0
262313
# more_loss['log_keys'] = [] # showed when validation on the fly
263314
# more_loss['test_keys'] = [] # showed when doing dp test
264315
atom_norm = 1.0 / natoms
@@ -324,14 +375,17 @@ def forward(
324375
if not self.use_huber:
325376
loss += atom_norm * (pref_e * l2_ener_loss)
326377
else:
378+
used_index = min(huber_index, len(self.huber_delta) - 1)
327379
l_huber_loss = custom_huber_loss(
328380
atom_norm * model_pred["energy"],
329381
atom_norm * label["energy"],
330-
delta=self.huber_delta,
331-
two_stage_delta=self.huber_two_stage_delta,
332-
c_sq=self.c_sq,
333-
k=self.k,
382+
delta=self.huber_delta[used_index],
383+
delta2=self.huber_two_stage_delta[used_index],
384+
k1=self.k1[used_index],
385+
k2=self.k2[used_index],
386+
use_root=self.huber_two_stage_use_root,
334387
)
388+
huber_index += 1
335389
loss += pref_e * l_huber_loss
336390
rmse_e = l2_ener_loss.sqrt() * atom_norm
337391
more_loss["rmse_e"] = self.display_if_exist(
@@ -439,14 +493,17 @@ def forward(
439493
GLOBAL_PT_FLOAT_PRECISION
440494
)
441495
else:
496+
used_index = min(huber_index, len(self.huber_delta) - 1)
442497
l_huber_loss = custom_huber_loss(
443498
force_pred_reshape,
444499
force_label_reshape,
445-
delta=self.huber_delta,
446-
two_stage_delta=self.huber_two_stage_delta,
447-
c_sq=self.c_sq,
448-
k=self.k,
500+
delta=self.huber_delta[used_index],
501+
delta2=self.huber_two_stage_delta[used_index],
502+
k1=self.k1[used_index],
503+
k2=self.k2[used_index],
504+
use_root=self.huber_two_stage_use_root,
449505
)
506+
huber_index += 1
450507
loss += pref_f * l_huber_loss
451508
rmse_f = l2_force_loss.sqrt()
452509
more_loss["rmse_f"] = self.display_if_exist(
@@ -520,14 +577,17 @@ def forward(
520577
if not self.use_huber:
521578
loss += atom_norm * (pref_v * l2_virial_loss)
522579
else:
580+
used_index = min(huber_index, len(self.huber_delta) - 1)
523581
l_huber_loss = custom_huber_loss(
524582
atom_norm * model_pred["virial"].reshape(-1),
525583
atom_norm * label["virial"].reshape(-1),
526-
delta=self.huber_delta,
527-
two_stage_delta=self.huber_two_stage_delta,
528-
c_sq=self.c_sq,
529-
k=self.k,
584+
delta=self.huber_delta[used_index],
585+
delta2=self.huber_two_stage_delta[used_index],
586+
k1=self.k1[used_index],
587+
k2=self.k2[used_index],
588+
use_root=self.huber_two_stage_use_root,
530589
)
590+
huber_index += 1
531591
loss += pref_v * l_huber_loss
532592
rmse_v = l2_virial_loss.sqrt() * atom_norm
533593
more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
@@ -552,14 +612,17 @@ def forward(
552612
if not self.use_huber:
553613
loss += (pref_ae * l2_atom_ener_loss).to(GLOBAL_PT_FLOAT_PRECISION)
554614
else:
615+
used_index = min(huber_index, len(self.huber_delta) - 1)
555616
l_huber_loss = custom_huber_loss(
556617
atom_ener_reshape,
557618
atom_ener_label_reshape,
558-
delta=self.huber_delta,
559-
two_stage_delta=self.huber_two_stage_delta,
560-
c_sq=self.c_sq,
561-
k=self.k,
619+
delta=self.huber_delta[used_index],
620+
delta2=self.huber_two_stage_delta[used_index],
621+
k1=self.k1[used_index],
622+
k2=self.k2[used_index],
623+
use_root=self.huber_two_stage_use_root,
562624
)
625+
huber_index += 1
563626
loss += pref_ae * l_huber_loss
564627
rmse_ae = l2_atom_ener_loss.sqrt()
565628
more_loss["rmse_ae"] = self.display_if_exist(

deepmd/utils/argcheck.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2918,17 +2918,23 @@ def loss_ener() -> list[Argument]:
29182918
),
29192919
Argument(
29202920
"huber_delta",
2921-
float,
2921+
[float, list],
29222922
optional=True,
29232923
default=0.01,
29242924
doc=doc_huber_delta,
29252925
),
29262926
Argument(
29272927
"huber_two_stage_delta",
2928-
float,
2928+
[float, list],
29292929
optional=True,
29302930
default=None,
29312931
),
2932+
Argument(
2933+
"huber_two_stage_use_root",
2934+
bool,
2935+
optional=True,
2936+
default=False,
2937+
),
29322938
Argument(
29332939
"trimmed_factor",
29342940
float,

0 commit comments

Comments
 (0)