33from typing import (
44 Any ,
55 Optional ,
6+ Union ,
67)
78
89import torch
def custom_huber_loss(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    delta: float = 1.0,
    delta2: float = 1.0,
    k1: float = 0.0,
    k2: float = 0.0,
    use_root: bool = False,
) -> torch.Tensor:
    """Mean Huber loss with an optional heavier-tailed second stage.

    For ``|error| <= delta`` the loss is quadratic; between ``delta`` and
    ``delta2`` it is the usual linear Huber branch; beyond ``delta2`` it
    switches to either a Cauchy (``log1p``) tail or a square-root tail.
    ``k1``/``k2`` are the scale and offset that make the second stage
    value- and slope-continuous at ``delta2``; they are expected to come
    from ``cauchy_params_from_huber`` / ``root_params_from_huber``.

    Parameters
    ----------
    predictions : torch.Tensor
        Model outputs, same shape as ``targets``.
    targets : torch.Tensor
        Reference values.
    delta : float
        Quadratic-to-linear switch point of the Huber loss.
    delta2 : float
        Linear-to-second-stage switch point. The second stage is only
        applied when ``delta2`` is not ``None`` and ``delta2 > delta``.
    k1 : float
        Second-stage scale (Cauchy ``c^2`` or sqrt coefficient). For the
        Cauchy tail it must be positive when the tail is active.
    k2 : float
        Second-stage additive offset (value-continuity term).
    use_root : bool
        If True, use the ``k1 * sqrt(|error|) + k2`` tail instead of the
        Cauchy tail.

    Returns
    -------
    torch.Tensor
        Scalar mean loss over all elements.
    """
    error = targets - predictions
    abs_error = torch.abs(error)
    quadratic_loss = 0.5 * torch.pow(error, 2)
    linear_loss = delta * (abs_error - 0.5 * delta)
    loss = torch.where(abs_error <= delta, quadratic_loss, linear_loss)
    if delta2 is not None and delta2 > delta:
        if not use_root:
            stage2_loss = 0.5 * k1 * torch.log1p(abs_error**2 / k1) + k2
        else:
            # torch.where evaluates BOTH branches, and d/dx sqrt(x) is
            # infinite at x == 0, so a zero error element would inject
            # NaN into the backward pass even though this branch is
            # discarded for it. Clamp the sqrt input into the region
            # where the branch is actually selected; values and
            # gradients of selected elements are unchanged.
            safe_abs_error = torch.clamp(abs_error, min=delta2)
            stage2_loss = k1 * torch.sqrt(safe_abs_error) + k2
        loss = torch.where(abs_error > delta2, stage2_loss, loss)
    return torch.mean(loss)
4550
4651
def cauchy_params_from_huber(
    delta1: float,
    delta2: float,
) -> tuple[float, float]:
    """Derive Cauchy-tail parameters that join a Huber loss smoothly.

    Returns ``(c_sq, k)`` such that the Cauchy branch
    ``0.5 * c_sq * log1p(r**2 / c_sq) + k`` matches the linear Huber
    branch at ``r = delta2`` both in value (C0) and in slope (C1).

    Parameters
    ----------
    delta1 : float
        Huber quadratic/linear switch point (also the linear slope).
    delta2 : float
        Radius where the Cauchy tail takes over. Must be strictly
        greater than ``delta1``: at equality the slope-matching scale
        diverges (division by zero below).

    Returns
    -------
    tuple[float, float]
        ``(c_sq, k)``: the Cauchy scale ``c^2`` and additive offset.

    Raises
    ------
    ValueError
        If ``delta2 <= delta1``.
    """
    # Validate explicitly: an assert is stripped under -O, and
    # delta2 == delta1 would otherwise raise ZeroDivisionError below.
    if delta2 <= delta1:
        raise ValueError(
            f"delta2 ({delta2}) must be strictly greater than delta1 ({delta1})"
        )
    # C1 continuity (slope match) fixes the Cauchy scale c^2:
    #   Huber slope at r = delta2 is delta1;
    #   Cauchy slope is r / (1 + r^2/c^2);
    #   matching at r = delta2: delta1 = delta2 / (1 + delta2^2/c^2).
    c_sq = delta2**2 / (delta2 / delta1 - 1.0)
    # Value of the linear Huber branch at delta2.
    huber_at_delta2 = delta1 * delta2 - 0.5 * delta1**2
    # Value of the (unshifted) Cauchy branch at delta2.
    cauchy_at_delta2 = 0.5 * c_sq * math.log(1.0 + delta2**2 / c_sq)
    # C0 continuity: offset k so both branches agree at delta2.
    k = huber_at_delta2 - cauchy_at_delta2
    return c_sq, k
71+
72+
def root_params_from_huber(
    delta1: float,
    delta2: float,
) -> tuple[float, float]:
    """Derive sqrt-tail parameters that join a Huber loss smoothly.

    Returns ``(K1, K2)`` such that the tail ``K1 * sqrt(r) + K2``
    matches the linear Huber branch at ``r = delta2`` both in value
    (C0) and in slope (C1).

    Parameters
    ----------
    delta1 : float
        Huber quadratic/linear switch point (also the linear slope).
    delta2 : float
        Radius where the sqrt tail takes over; must satisfy
        ``delta2 >= delta1``.

    Returns
    -------
    tuple[float, float]
        ``(K1, K2)``: the sqrt scale and additive offset.

    Raises
    ------
    ValueError
        If ``delta2 < delta1``.
    """
    # Validate explicitly rather than with assert (stripped under -O).
    if delta2 < delta1:
        raise ValueError(
            f"delta2 ({delta2}) must not be smaller than delta1 ({delta1})"
        )
    # C1 continuity (slope match):
    #   Huber slope at r = delta2 is delta1;
    #   sqrt-tail slope is 0.5 * K1 / sqrt(r);
    #   matching at r = delta2 gives K1 = 2 * delta1 * sqrt(delta2).
    K1 = 2.0 * delta1 * math.sqrt(delta2)
    # Value of the linear Huber branch at delta2.
    huber_at_delta2 = delta1 * delta2 - 0.5 * delta1**2
    # C0 continuity: K2 makes the tail equal the Huber value at delta2.
    # Algebraically this simplifies to K2 = -delta1*delta2 - 0.5*delta1^2.
    K2 = huber_at_delta2 - K1 * math.sqrt(delta2)
    return K1, K2
92+
93+
4794class EnergyStdLoss (TaskLoss ):
4895 def __init__ (
4996 self ,
@@ -66,11 +113,12 @@ def __init__(
66113 use_l1_all : bool = False ,
67114 inference : bool = False ,
68115 use_huber : bool = False ,
69- huber_delta : float = 0.01 ,
70- huber_two_stage_delta : Optional [float ] = None ,
116+ huber_delta : Union [ float , list [ float ]] = 0.01 ,
117+ huber_two_stage_delta : Optional [Union [ float , list [ float ]] ] = None ,
71118 trimmed_factor : float = 0.0 ,
72119 adaptive_loss : bool = False ,
73120 learnable_pref : bool = False ,
121+ huber_two_stage_use_root : bool = False ,
74122 ** kwargs : Any ,
75123 ) -> None :
76124 r"""Construct a layer to compute loss on energy, force and virial.
@@ -157,31 +205,33 @@ def __init__(
157205 self .use_l1_all = use_l1_all
158206 self .inference = inference
159207 self .use_huber = use_huber
160- self .huber_delta = huber_delta
161- self .huber_two_stage_delta = huber_two_stage_delta
208+ self .huber_delta = (
209+ [huber_delta ] if isinstance (huber_delta , float ) else huber_delta
210+ )
211+ self .huber_two_stage_delta = (
212+ [huber_two_stage_delta ]
213+ if isinstance (huber_two_stage_delta , float )
214+ else huber_two_stage_delta
215+ )
216+ self .huber_two_stage_use_root = huber_two_stage_use_root
162217 if self .use_huber and self .huber_two_stage_delta is not None :
163- assert huber_two_stage_delta >= huber_delta
164- # --- 保证 L1 -> Cauchy 连续性所需的参数计算 ---
165- # 1. 根据 C1 连续性 (梯度匹配) 求解 Cauchy 尺度 c^2
166- # 目标: dL/dr (L1) = delta1; dL/dr (Cauchy) = r / (1 + r^2/c^2)
167- # 匹配点 r=delta2: delta1 = delta2 / (1 + delta2^2/c^2)
168- c_sq = huber_two_stage_delta ** 2 / (
169- huber_two_stage_delta / huber_delta - 1.0
170- )
171- self .c_sq = c_sq
172- # 2. 计算 L2/L1 在 delta2 处的值 (C_match)
173- # C_match = delta1 * delta2 - 0.5 * delta1^2
174- L_L1_at_delta2 = huber_delta * huber_two_stage_delta - 0.5 * huber_delta ** 2
175- # 3. 计算 Cauchy 在 delta2 处的值 (L_Cauchy)
176- L_Cauchy_at_delta2 = (
177- 0.5 * self .c_sq * math .log (1.0 + huber_two_stage_delta ** 2 / self .c_sq )
218+ assert len (self .huber_two_stage_delta ) == len (self .huber_delta ), (
219+ "When using two-stage Huber loss, the length of huber_two_stage_delta must match that of huber_delta."
178220 )
179- # 4. 计算偏移量 K (K = C_match - L_Cauchy)
180- # 保证 L3(r) = L_Cauchy(r) + K 在 delta2 处值连续
181- self .k = L_L1_at_delta2 - L_Cauchy_at_delta2
221+ self .k1 = []
222+ self .k2 = []
223+ for i , delta in enumerate (self .huber_delta ):
224+ two_stage_delta = self .huber_two_stage_delta [i ]
225+ k1 , k2 = (
226+ cauchy_params_from_huber (delta , two_stage_delta )
227+ if not self .huber_two_stage_use_root
228+ else root_params_from_huber (delta , two_stage_delta )
229+ )
230+ self .k1 .append (k1 )
231+ self .k2 .append (k2 )
182232 else :
183- self .c_sq = 0.0
184- self .k = 0.0
233+ self .k1 = [ 0.0 ]
234+ self .k2 = [ 0.0 ]
185235 if self .use_huber and (
186236 self .has_pf or self .has_gf or self .relative_f is not None
187237 ):
@@ -259,6 +309,7 @@ def forward(
259309
260310 loss = torch .zeros (1 , dtype = env .GLOBAL_PT_FLOAT_PRECISION , device = env .DEVICE )[0 ]
261311 more_loss = {}
312+ huber_index = 0
262313 # more_loss['log_keys'] = [] # showed when validation on the fly
263314 # more_loss['test_keys'] = [] # showed when doing dp test
264315 atom_norm = 1.0 / natoms
@@ -324,14 +375,17 @@ def forward(
324375 if not self .use_huber :
325376 loss += atom_norm * (pref_e * l2_ener_loss )
326377 else :
378+ used_index = min (huber_index , len (self .huber_delta ) - 1 )
327379 l_huber_loss = custom_huber_loss (
328380 atom_norm * model_pred ["energy" ],
329381 atom_norm * label ["energy" ],
330- delta = self .huber_delta ,
331- two_stage_delta = self .huber_two_stage_delta ,
332- c_sq = self .c_sq ,
333- k = self .k ,
382+ delta = self .huber_delta [used_index ],
383+ delta2 = self .huber_two_stage_delta [used_index ],
384+ k1 = self .k1 [used_index ],
385+ k2 = self .k2 [used_index ],
386+ use_root = self .huber_two_stage_use_root ,
334387 )
388+ huber_index += 1
335389 loss += pref_e * l_huber_loss
336390 rmse_e = l2_ener_loss .sqrt () * atom_norm
337391 more_loss ["rmse_e" ] = self .display_if_exist (
@@ -439,14 +493,17 @@ def forward(
439493 GLOBAL_PT_FLOAT_PRECISION
440494 )
441495 else :
496+ used_index = min (huber_index , len (self .huber_delta ) - 1 )
442497 l_huber_loss = custom_huber_loss (
443498 force_pred_reshape ,
444499 force_label_reshape ,
445- delta = self .huber_delta ,
446- two_stage_delta = self .huber_two_stage_delta ,
447- c_sq = self .c_sq ,
448- k = self .k ,
500+ delta = self .huber_delta [used_index ],
501+ delta2 = self .huber_two_stage_delta [used_index ],
502+ k1 = self .k1 [used_index ],
503+ k2 = self .k2 [used_index ],
504+ use_root = self .huber_two_stage_use_root ,
449505 )
506+ huber_index += 1
450507 loss += pref_f * l_huber_loss
451508 rmse_f = l2_force_loss .sqrt ()
452509 more_loss ["rmse_f" ] = self .display_if_exist (
@@ -520,14 +577,17 @@ def forward(
520577 if not self .use_huber :
521578 loss += atom_norm * (pref_v * l2_virial_loss )
522579 else :
580+ used_index = min (huber_index , len (self .huber_delta ) - 1 )
523581 l_huber_loss = custom_huber_loss (
524582 atom_norm * model_pred ["virial" ].reshape (- 1 ),
525583 atom_norm * label ["virial" ].reshape (- 1 ),
526- delta = self .huber_delta ,
527- two_stage_delta = self .huber_two_stage_delta ,
528- c_sq = self .c_sq ,
529- k = self .k ,
584+ delta = self .huber_delta [used_index ],
585+ delta2 = self .huber_two_stage_delta [used_index ],
586+ k1 = self .k1 [used_index ],
587+ k2 = self .k2 [used_index ],
588+ use_root = self .huber_two_stage_use_root ,
530589 )
590+ huber_index += 1
531591 loss += pref_v * l_huber_loss
532592 rmse_v = l2_virial_loss .sqrt () * atom_norm
533593 more_loss ["rmse_v" ] = self .display_if_exist (rmse_v .detach (), find_virial )
@@ -552,14 +612,17 @@ def forward(
552612 if not self .use_huber :
553613 loss += (pref_ae * l2_atom_ener_loss ).to (GLOBAL_PT_FLOAT_PRECISION )
554614 else :
615+ used_index = min (huber_index , len (self .huber_delta ) - 1 )
555616 l_huber_loss = custom_huber_loss (
556617 atom_ener_reshape ,
557618 atom_ener_label_reshape ,
558- delta = self .huber_delta ,
559- two_stage_delta = self .huber_two_stage_delta ,
560- c_sq = self .c_sq ,
561- k = self .k ,
619+ delta = self .huber_delta [used_index ],
620+ delta2 = self .huber_two_stage_delta [used_index ],
621+ k1 = self .k1 [used_index ],
622+ k2 = self .k2 [used_index ],
623+ use_root = self .huber_two_stage_use_root ,
562624 )
625+ huber_index += 1
563626 loss += pref_ae * l_huber_loss
564627 rmse_ae = l2_atom_ener_loss .sqrt ()
565628 more_loss ["rmse_ae" ] = self .display_if_exist (
0 commit comments