@@ -94,6 +94,9 @@ class GeneralFitting(NativeOP, BaseFitting):
9494 A list of strings. Give the name to each type of atoms.
9595 seed: Optional[Union[int, list[int]]]
9696 Random seed for initializing the network parameters.
97+ default_fparam: list[float], optional
98+ The default frame parameter. If set, when `fparam.npy` files are not included in the data system,
99+ this value will be used as the default value for the frame parameter in the fitting net.
97100 """
98101
99102 def __init__ (
@@ -120,6 +123,7 @@ def __init__(
120123 remove_vaccum_contribution : Optional [list [bool ]] = None ,
121124 type_map : Optional [list [str ]] = None ,
122125 seed : Optional [Union [int , list [int ]]] = None ,
126+ default_fparam : Optional [list [float ]] = None ,
123127 ) -> None :
124128 self .var_name = var_name
125129 self .ntypes = ntypes
@@ -129,6 +133,7 @@ def __init__(
129133 self .numb_fparam = numb_fparam
130134 self .numb_aparam = numb_aparam
131135 self .dim_case_embd = dim_case_embd
136+ self .default_fparam = default_fparam
132137 self .rcond = rcond
133138 self .tot_ener_zero = tot_ener_zero
134139 self .trainable = trainable
@@ -177,6 +182,15 @@ def __init__(
177182 self .case_embd = np .zeros (self .dim_case_embd , dtype = self .prec )
178183 else :
179184 self .case_embd = None
185+
186+ if self .default_fparam is not None :
187+ if self .numb_fparam > 0 :
188+ assert len (self .default_fparam ) == self .numb_fparam , (
189+ "default_fparam length mismatch!"
190+ )
191+ self .default_fparam_tensor = np .array (self .default_fparam , dtype = self .prec )
192+ else :
193+ self .default_fparam_tensor = None
180194 # init networks
181195 in_dim = (
182196 self .dim_descrpt
@@ -217,6 +231,10 @@ def get_dim_aparam(self) -> int:
217231 """Get the number (dimension) of atomic parameters of this atomic model."""
218232 return self .numb_aparam
219233
234+ def has_default_fparam (self ) -> bool :
235+ """Check if the fitting has default frame parameters."""
236+ return self .default_fparam is not None
237+
220238 def get_sel_type (self ) -> list [int ]:
221239 """Get the selected atom types of this model.
222240
@@ -315,6 +333,7 @@ def serialize(self) -> dict:
315333 "numb_fparam" : self .numb_fparam ,
316334 "numb_aparam" : self .numb_aparam ,
317335 "dim_case_embd" : self .dim_case_embd ,
336+ "default_fparam" : self .default_fparam ,
318337 "rcond" : self .rcond ,
319338 "activation_function" : self .activation_function ,
320339 "precision" : self .precision ,
@@ -403,6 +422,14 @@ def _call_common(
403422 xx_zeros = xp .zeros_like (xx )
404423 else :
405424 xx_zeros = None
425+
426+ if self .numb_fparam > 0 and fparam is None :
427+ # use default fparam
428+ assert self .default_fparam_tensor is not None
429+ fparam = xp .tile (
430+ xp .reshape (self .default_fparam_tensor , (1 , self .numb_fparam )), (nf , 1 )
431+ )
432+
406433 # check fparam dim, concate to input descriptor
407434 if self .numb_fparam > 0 :
408435 assert fparam is not None , "fparam should not be None"
0 commit comments