@@ -145,6 +145,7 @@ def __init__(
145145 remove_vaccum_contribution : Optional [list [bool ]] = None ,
146146 type_map : Optional [list [str ]] = None ,
147147 use_aparam_as_mask : bool = False ,
148+ default_fparam : Optional [list ] = None ,
148149 ** kwargs ,
149150 ) -> None :
150151 super ().__init__ ()
@@ -156,6 +157,7 @@ def __init__(
156157 self .resnet_dt = resnet_dt
157158 self .numb_fparam = numb_fparam
158159 self .numb_aparam = numb_aparam
160+ self .default_fparam = default_fparam
159161 self .dim_case_embd = dim_case_embd
160162 self .activation_function = activation_function
161163 self .precision = precision
@@ -217,6 +219,20 @@ def __init__(
217219 else :
218220 self .case_embd = None
219221
222+ if self .default_fparam is not None :
223+ if self .numb_fparam > 0 :
224+ assert (
225+ len (self .default_fparam ) == self .numb_fparam
226+ ), "default_fparam length mismatch!"
227+ self .register_buffer (
228+ "default_fparam_tensor" ,
229+ torch .tensor (
230+ np .array (self .default_fparam ), dtype = self .prec , device = device
231+ ),
232+ )
233+ else :
234+ self .default_fparam_tensor = None
235+
220236 in_dim = (
221237 self .dim_descrpt
222238 + self .numb_fparam
def get_dim_fparam(self) -> int:
    """Return the dimension (count) of frame parameters used by this atomic model."""
    fparam_dim = self.numb_fparam
    return fparam_dim
def has_default_fparam(self) -> bool:
    """Check whether default frame parameters (``default_fparam``) are set.

    Returns
    -------
    bool
        True if ``default_fparam`` was provided at construction time and can
        be used as a fallback when no ``fparam`` is passed to the forward call.
    """
    return self.default_fparam is not None
def get_dim_aparam(self) -> int:
    """Return the dimension (count) of atomic parameters used by this atomic model."""
    aparam_dim = self.numb_aparam
    return aparam_dim
@@ -427,6 +446,13 @@ def _forward_common(
427446 ):
428447 # cast the input to internal precsion
429448 xx = descriptor .to (self .prec )
449+ nf , nloc , nd = xx .shape
450+
451+ if self .numb_fparam > 0 and fparam is None :
452+ # use default fparam
453+ assert self .default_fparam_tensor is not None
454+ fparam = torch .tile (self .default_fparam_tensor .unsqueeze (0 ), [nf , 1 ])
455+
430456 fparam = fparam .to (self .prec ) if fparam is not None else None
431457 aparam = aparam .to (self .prec ) if aparam is not None else None
432458
@@ -439,7 +465,6 @@ def _forward_common(
439465 xx_zeros = torch .zeros_like (xx )
440466 else :
441467 xx_zeros = None
442- nf , nloc , nd = xx .shape
443468 net_dim_out = self ._net_out_dim ()
444469
445470 if nd != self .dim_descrpt :
0 commit comments