@@ -54,6 +54,7 @@ def __init__(
5454 use_l1_all : bool = False ,
5555 inference = False ,
5656 use_huber = False ,
57+ use_default_pf = False ,
5758 huber_delta = 0.01 ,
5859 ** kwargs ,
5960 ) -> None :
@@ -131,6 +132,7 @@ def __init__(
131132 self .limit_pref_pf = limit_pref_pf
132133 self .start_pref_gf = start_pref_gf
133134 self .limit_pref_gf = limit_pref_gf
135+ self .use_default_pf = use_default_pf
134136 self .relative_f = relative_f
135137 self .enable_atom_ener_coeff = enable_atom_ener_coeff
136138 self .numb_generalized_coord = numb_generalized_coord
@@ -301,7 +303,9 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
301303
302304 if self .has_pf and "atom_pref" in label :
303305 atom_pref = label ["atom_pref" ]
304- find_atom_pref = label .get ("find_atom_pref" , 0.0 )
306+ find_atom_pref = (
307+ label .get ("find_atom_pref" , 0.0 ) if not self .use_default_pf else 1.0
308+ )
305309 pref_pf = pref_pf * find_atom_pref
306310 atom_pref_reshape = atom_pref .reshape (- 1 )
307311 l2_pref_force_loss = (torch .square (diff_f ) * atom_pref_reshape ).mean ()
@@ -410,7 +414,7 @@ def label_requirement(self) -> list[DataRequirementItem]:
410414 high_prec = True ,
411415 )
412416 )
413- if self .has_f :
417+ if self .has_f or self . has_pf or self . relative_f is not None or self . has_gf :
414418 label_requirement .append (
415419 DataRequirementItem (
416420 "force" ,
@@ -449,6 +453,7 @@ def label_requirement(self) -> list[DataRequirementItem]:
449453 must = False ,
450454 high_prec = False ,
451455 repeat = 3 ,
456+ default = 1.0 ,
452457 )
453458 )
454459 if self .has_gf > 0 :
0 commit comments