@@ -52,6 +52,8 @@ def __init__(
5252 axis_neuron : int = 4 ,
5353 update_angle : bool = True ,
5454 optim_update : bool = True ,
55+ use_dynamic_sel : bool = False ,
56+ sel_reduce_factor : float = 10.0 ,
5557 smooth_edge_update : bool = False ,
5658 activation_function : str = "silu" ,
5759 update_style : str = "res_residual" ,
@@ -98,6 +100,10 @@ def __init__(
98100 self .prec = PRECISION_DICT [precision ]
99101 self .optim_update = optim_update
100102 self .smooth_edge_update = smooth_edge_update
103+ self .use_dynamic_sel = use_dynamic_sel
104+ self .sel_reduce_factor = sel_reduce_factor
105+ self .dynamic_e_sel = self .nnei / self .sel_reduce_factor
106+ self .dynamic_a_sel = self .a_sel / self .sel_reduce_factor
101107
102108 assert update_residual_init in [
103109 "norm" ,
@@ -812,7 +818,7 @@ def serialize(self) -> dict:
812818 """
813819 data = {
814820 "@class" : "RepFlowLayer" ,
815- "@version" : 1 ,
821+ "@version" : 2 ,
816822 "e_rcut" : self .e_rcut ,
817823 "e_rcut_smth" : self .e_rcut_smth ,
818824 "e_sel" : self .e_sel ,
@@ -836,6 +842,8 @@ def serialize(self) -> dict:
836842 "precision" : self .precision ,
837843 "optim_update" : self .optim_update ,
838844 "smooth_edge_update" : self .smooth_edge_update ,
845+ "use_dynamic_sel" : self .use_dynamic_sel ,
846+ "sel_reduce_factor" : self .sel_reduce_factor ,
839847 "node_self_mlp" : self .node_self_mlp .serialize (),
840848 "node_sym_linear" : self .node_sym_linear .serialize (),
841849 "node_edge_linear" : self .node_edge_linear .serialize (),
@@ -878,7 +886,7 @@ def deserialize(cls, data: dict) -> "RepFlowLayer":
878886 The dict to deserialize from.
879887 """
880888 data = data .copy ()
881- check_version_compatibility (data .pop ("@version" ), 1 , 1 )
889+ check_version_compatibility (data .pop ("@version" ), 2 , 1 )
882890 data .pop ("@class" )
883891 update_angle = data ["update_angle" ]
884892 a_compress_rate = data ["a_compress_rate" ]
0 commit comments