@@ -227,6 +227,7 @@ def __init__(
227227 remove_vaccum_contribution : Optional [list [bool ]] = None ,
228228 type_map : Optional [list [str ]] = None ,
229229 use_aparam_as_mask : bool = False ,
230+ default_fparam : Optional [list ] = None ,
230231 ** kwargs ,
231232 ) -> None :
232233 super ().__init__ ()
@@ -238,6 +239,7 @@ def __init__(
238239 self .resnet_dt = resnet_dt
239240 self .numb_fparam = numb_fparam
240241 self .numb_aparam = numb_aparam
242+ self .default_fparam = default_fparam
241243 self .dim_case_embd = dim_case_embd
242244 self .activation_function = activation_function
243245 self .precision = precision
@@ -299,6 +301,20 @@ def __init__(
299301 else :
300302 self .case_embd = None
301303
304+ if self .default_fparam is not None :
305+ if self .numb_fparam > 0 :
306+ assert (
307+ len (self .default_fparam ) == self .numb_fparam
308+ ), "default_fparam length mismatch!"
309+ self .register_buffer (
310+ "default_fparam_tensor" ,
311+ torch .tensor (
312+ np .array (self .default_fparam ), dtype = self .prec , device = device
313+ ),
314+ )
315+ else :
316+ self .default_fparam_tensor = None
317+
302318 in_dim = (
303319 self .dim_descrpt
304320 + self .numb_fparam
@@ -415,6 +431,9 @@ def get_dim_fparam(self) -> int:
415431 """Get the number (dimension) of frame parameters of this atomic model."""
416432 return self .numb_fparam
417433
434+ def has_default_fparam (self ) -> bool :
435+ return self .default_fparam is not None
436+
418437 def get_dim_aparam (self ) -> int :
419438 """Get the number (dimension) of atomic parameters of this atomic model."""
420439 return self .numb_aparam
@@ -509,6 +528,13 @@ def _forward_common(
509528 ):
510529 # cast the input to internal precsion
511530 xx = descriptor .to (self .prec )
531+ nf , nloc , nd = xx .shape
532+
533+ if self .numb_fparam > 0 and fparam is None :
534+ # use default fparam
535+ assert self .default_fparam_tensor is not None
536+ fparam = torch .tile (self .default_fparam_tensor .unsqueeze (0 ), [nf , 1 ])
537+
512538 fparam = fparam .to (self .prec ) if fparam is not None else None
513539 aparam = aparam .to (self .prec ) if aparam is not None else None
514540
@@ -521,7 +547,6 @@ def _forward_common(
521547 xx_zeros = torch .zeros_like (xx )
522548 else :
523549 xx_zeros = None
524- nf , nloc , nd = xx .shape
525550 net_dim_out = self ._net_out_dim ()
526551
527552 if nd != self .dim_descrpt :