Skip to content

Commit 7f547b8

Browse files
committed
add use_default_pf
1 parent e2777c0 commit 7f547b8

2 files changed

Lines changed: 13 additions & 2 deletions

File tree

deepmd/pt/loss/ener.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(
5454
use_l1_all: bool = False,
5555
inference=False,
5656
use_huber=False,
57+
use_default_pf=False,
5758
huber_delta=0.01,
5859
**kwargs,
5960
) -> None:
@@ -131,6 +132,7 @@ def __init__(
131132
self.limit_pref_pf = limit_pref_pf
132133
self.start_pref_gf = start_pref_gf
133134
self.limit_pref_gf = limit_pref_gf
135+
self.use_default_pf = use_default_pf
134136
self.relative_f = relative_f
135137
self.enable_atom_ener_coeff = enable_atom_ener_coeff
136138
self.numb_generalized_coord = numb_generalized_coord
@@ -301,7 +303,9 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
301303

302304
if self.has_pf and "atom_pref" in label:
303305
atom_pref = label["atom_pref"]
304-
find_atom_pref = label.get("find_atom_pref", 0.0)
306+
find_atom_pref = (
307+
label.get("find_atom_pref", 0.0) if not self.use_default_pf else 1.0
308+
)
305309
pref_pf = pref_pf * find_atom_pref
306310
atom_pref_reshape = atom_pref.reshape(-1)
307311
l2_pref_force_loss = (torch.square(diff_f) * atom_pref_reshape).mean()
@@ -410,7 +414,7 @@ def label_requirement(self) -> list[DataRequirementItem]:
410414
high_prec=True,
411415
)
412416
)
413-
if self.has_f:
417+
if self.has_f or self.has_pf or self.relative_f is not None or self.has_gf:
414418
label_requirement.append(
415419
DataRequirementItem(
416420
"force",
@@ -449,6 +453,7 @@ def label_requirement(self) -> list[DataRequirementItem]:
449453
must=False,
450454
high_prec=False,
451455
repeat=3,
456+
default=1.0,
452457
)
453458
)
454459
if self.has_gf > 0:

deepmd/utils/argcheck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2659,6 +2659,12 @@ def loss_ener():
26592659
default=0.00,
26602660
doc=doc_limit_pref_pf,
26612661
),
2662+
Argument(
2663+
"use_default_pf",
2664+
bool,
2665+
optional=True,
2666+
default=False,
2667+
),
26622668
Argument("relative_f", [float, None], optional=True, doc=doc_relative_f),
26632669
Argument(
26642670
"enable_atom_ener_coeff",

0 commit comments

Comments
 (0)