@@ -58,6 +58,7 @@ def __init__(
5858 inference : bool = False ,
5959 use_huber : bool = False ,
6060 huber_delta : float = 0.01 ,
61+ trimmed_factor : float = 0.0 ,
6162 ** kwargs : Any ,
6263 ) -> None :
6364 r"""Construct a layer to compute loss on energy, force and virial.
@@ -151,6 +152,7 @@ def __init__(
151152 raise RuntimeError (
152153 "Huber loss is not implemented for force with atom_pref, generalized force and relative force. "
153154 )
155+ self .trimmed_factor = trimmed_factor
154156
155157 def forward (
156158 self ,
@@ -272,6 +274,16 @@ def forward(
272274 force_pred = model_pred ["force" ]
273275 force_label = label ["force" ]
274276 diff_f = (force_label - force_pred ).reshape (- 1 )
277+ force_pred_reshape = force_pred .reshape (- 1 )
278+ force_label_reshape = force_label .reshape (- 1 )
279+
280+ if self .trimmed_factor > 0.0 :
281+ num_samples = diff_f .numel ()
282+ num_keep = int (num_samples * (1 - self .trimmed_factor ))
283+ keep_values , mask = torch .topk (diff_f .abs (), k = num_keep , largest = False )
284+ diff_f = diff_f [mask ]
285+ force_pred_reshape = force_pred_reshape [mask ]
286+ force_label_reshape = force_label_reshape [mask ]
275287
276288 if self .relative_f is not None :
277289 force_label_3 = force_label .reshape (- 1 , 3 )
@@ -291,8 +303,8 @@ def forward(
291303 loss += (pref_f * l2_force_loss ).to (GLOBAL_PT_FLOAT_PRECISION )
292304 else :
293305 l_huber_loss = custom_huber_loss (
294- force_pred . reshape ( - 1 ) ,
295- force_label . reshape ( - 1 ) ,
306+ force_pred_reshape ,
307+ force_label_reshape ,
296308 delta = self .huber_delta ,
297309 )
298310 loss += pref_f * l_huber_loss
@@ -301,7 +313,9 @@ def forward(
301313 rmse_f .detach (), find_force
302314 )
303315 else :
304- l1_force_loss = F .l1_loss (force_label , force_pred , reduction = "none" )
316+ l1_force_loss = F .l1_loss (
317+ force_label_reshape , force_pred_reshape , reduction = "none"
318+ )
305319 more_loss ["mae_f" ] = self .display_if_exist (
306320 l1_force_loss .mean ().detach (), find_force
307321 )
0 commit comments