@@ -123,6 +123,26 @@ class RepFlowArgs:
123123 smooth_edge_update : bool, optional
124124 Whether to make edge update smooth.
125125 If True, the edge update from angle message will not use self as padding.
126+ use_exp_switch : bool, optional
127+ Whether to use an exponential switch function instead of a polynomial one in the neighbor update.
128+ The exponential switch function ensures neighbor contributions smoothly diminish as the interatomic distance
129+ `r` approaches the cutoff radius `rcut`. Specifically, the function is defined as:
130+ s(r) = \\exp(-\\exp(20 * (r - rcut_smth) / rcut_smth)) for 0 < r \\leq rcut, and s(r) = 0 for r > rcut.
131+ Here, `rcut_smth` is an adjustable smoothing factor and `rcut_smth` should be chosen carefully
132+ according to `rcut`, ensuring s(r) approaches zero smoothly at the cutoff.
133+ Typical recommended values are `rcut_smth` = 5.3 for `rcut` = 6.0, and 3.5 for `rcut` = 4.0.
134+ use_dynamic_sel : bool, optional
135+ Whether to dynamically select neighbors within the cutoff radius.
136+ If True, the exact number of neighbors within the cutoff radius is used
137+ without padding to a fixed selection numbers.
138+ When enabled, users can safely set larger values for `e_sel` or `a_sel` (e.g., 1200 or 300, respectively)
139+ to guarantee capturing all neighbors within the cutoff radius.
140+ Note that when using dynamic selection, the `smooth_edge_update` must be True.
141+ sel_reduce_factor : float, optional
142+ Reduction factor applied to neighbor-scale normalization when `use_dynamic_sel` is True.
143+ In the dynamic selection case, neighbor-scale normalization will use `e_sel / sel_reduce_factor`
144+ or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values,
145+ accommodating larger selection numbers.
126146 """
127147
128148 def __init__ (
@@ -150,6 +170,9 @@ def __init__(
150170 skip_stat : bool = False ,
151171 optim_update : bool = True ,
152172 smooth_edge_update : bool = False ,
173+ use_exp_switch : bool = False ,
174+ use_dynamic_sel : bool = False ,
175+ sel_reduce_factor : float = 10.0 ,
153176 ) -> None :
154177 self .n_dim = n_dim
155178 self .e_dim = e_dim
@@ -176,6 +199,9 @@ def __init__(
176199 self .a_compress_use_split = a_compress_use_split
177200 self .optim_update = optim_update
178201 self .smooth_edge_update = smooth_edge_update
202+ self .use_exp_switch = use_exp_switch
203+ self .use_dynamic_sel = use_dynamic_sel
204+ self .sel_reduce_factor = sel_reduce_factor
179205
180206 def __getitem__ (self , key ):
181207 if hasattr (self , key ):
@@ -207,6 +233,9 @@ def serialize(self) -> dict:
207233 "fix_stat_std" : self .fix_stat_std ,
208234 "optim_update" : self .optim_update ,
209235 "smooth_edge_update" : self .smooth_edge_update ,
236+ "use_exp_switch" : self .use_exp_switch ,
237+ "use_dynamic_sel" : self .use_dynamic_sel ,
238+ "sel_reduce_factor" : self .sel_reduce_factor ,
210239 }
211240
212241 @classmethod
@@ -303,6 +332,9 @@ def init_subclass_params(sub_data, sub_class):
303332 fix_stat_std = self .repflow_args .fix_stat_std ,
304333 optim_update = self .repflow_args .optim_update ,
305334 smooth_edge_update = self .repflow_args .smooth_edge_update ,
335+ use_exp_switch = self .repflow_args .use_exp_switch ,
336+ use_dynamic_sel = self .repflow_args .use_dynamic_sel ,
337+ sel_reduce_factor = self .repflow_args .sel_reduce_factor ,
306338 exclude_types = exclude_types ,
307339 env_protection = env_protection ,
308340 precision = precision ,
0 commit comments