Skip to content

Commit ea77526

Browse files
committed
add trimmed_factor for loss
1 parent 6c95f37 commit ea77526

2 files changed

Lines changed: 23 additions & 3 deletions

File tree

deepmd/pt/loss/ener.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(
5858
inference: bool = False,
5959
use_huber: bool = False,
6060
huber_delta: float = 0.01,
61+
trimmed_factor: float = 0.0,
6162
**kwargs: Any,
6263
) -> None:
6364
r"""Construct a layer to compute loss on energy, force and virial.
@@ -151,6 +152,7 @@ def __init__(
151152
raise RuntimeError(
152153
"Huber loss is not implemented for force with atom_pref, generalized force and relative force. "
153154
)
155+
self.trimmed_factor = trimmed_factor
154156

155157
def forward(
156158
self,
@@ -272,6 +274,16 @@ def forward(
272274
force_pred = model_pred["force"]
273275
force_label = label["force"]
274276
diff_f = (force_label - force_pred).reshape(-1)
277+
force_pred_reshape = force_pred.reshape(-1)
278+
force_label_reshape = force_label.reshape(-1)
279+
280+
if self.trimmed_factor > 0.0:
281+
num_samples = diff_f.numel()
282+
num_keep = int(num_samples * (1 - self.trimmed_factor))
283+
keep_values, mask = torch.topk(diff_f.abs(), k=num_keep, largest=False)
284+
diff_f = diff_f[mask]
285+
force_pred_reshape = force_pred_reshape[mask]
286+
force_label_reshape = force_label_reshape[mask]
275287

276288
if self.relative_f is not None:
277289
force_label_3 = force_label.reshape(-1, 3)
@@ -291,8 +303,8 @@ def forward(
291303
loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
292304
else:
293305
l_huber_loss = custom_huber_loss(
294-
force_pred.reshape(-1),
295-
force_label.reshape(-1),
306+
force_pred_reshape,
307+
force_label_reshape,
296308
delta=self.huber_delta,
297309
)
298310
loss += pref_f * l_huber_loss
@@ -301,7 +313,9 @@ def forward(
301313
rmse_f.detach(), find_force
302314
)
303315
else:
304-
l1_force_loss = F.l1_loss(force_label, force_pred, reduction="none")
316+
l1_force_loss = F.l1_loss(
317+
force_label_reshape, force_pred_reshape, reduction="none"
318+
)
305319
more_loss["mae_f"] = self.display_if_exist(
306320
l1_force_loss.mean().detach(), find_force
307321
)

deepmd/utils/argcheck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2923,6 +2923,12 @@ def loss_ener() -> list[Argument]:
29232923
default=0.01,
29242924
doc=doc_huber_delta,
29252925
),
2926+
Argument(
2927+
"trimmed_factor",
2928+
float,
2929+
optional=True,
2930+
default=0.0,
2931+
),
29262932
]
29272933

29282934

0 commit comments

Comments
 (0)