@@ -55,7 +55,8 @@ def __init__(
5555 axis_neuron : int = 4 ,
5656 update_angle : bool = True , # angle
5757 optim_update : bool = True ,
58- no_sel : bool = False ,
58+ use_dynamic_sel : bool = False ,
59+ sel_reduce_factor : float = 10.0 ,
5960 smooth_edge_update : bool = False ,
6061 activation_function : str = "silu" ,
6162 update_style : str = "res_residual" ,
@@ -102,7 +103,10 @@ def __init__(
102103 self .prec = PRECISION_DICT [precision ]
103104 self .optim_update = optim_update
104105 self .smooth_edge_update = smooth_edge_update
105- self .no_sel = no_sel
106+ self .use_dynamic_sel = use_dynamic_sel
107+ self .sel_reduce_factor = sel_reduce_factor
108+ self .dynamic_e_sel = self .nnei / self .sel_reduce_factor
109+ self .dynamic_a_sel = self .a_sel / self .sel_reduce_factor
106110
107111 assert update_residual_init in [
108112 "norm" ,
@@ -324,7 +328,7 @@ def _cal_hg(
324328 return h2g2
325329
326330 @staticmethod
327- def _cal_hg_nosel (
331+ def _cal_hg_dynamic (
328332 flat_edge_ebd : torch .Tensor ,
329333 flat_h2 : torch .Tensor ,
330334 flat_sw : torch .Tensor ,
@@ -447,7 +451,7 @@ def symmetrization_op(
447451 g1_13 = self ._cal_grrg (h2g2 , axis_neuron )
448452 return g1_13
449453
450- def symmetrization_op_nosel (
454+ def symmetrization_op_dynamic (
451455 self ,
452456 flat_edge_ebd : torch .Tensor ,
453457 flat_h2 : torch .Tensor ,
@@ -487,7 +491,7 @@ def symmetrization_op_nosel(
487491 Atomic invariant rep, with shape nb x nloc x (axis_neuron x e_dim)
488492 """
489493 # nb x nloc x 3 x e_dim
490- h2g2 = self ._cal_hg_nosel (
494+ h2g2 = self ._cal_hg_dynamic (
491495 flat_edge_ebd ,
492496 flat_h2 ,
493497 flat_sw ,
@@ -552,7 +556,7 @@ def optim_angle_update(
552556 ) + bias
553557 return result_update
554558
555- def optim_angle_update_nosel (
559+ def optim_angle_update_dynamic (
556560 self ,
557561 flat_angle_ebd : torch .Tensor ,
558562 node_ebd : torch .Tensor ,
@@ -655,7 +659,7 @@ def optim_edge_update(
655659 ) + bias
656660 return result_update
657661
658- def optim_edge_update_nosel (
662+ def optim_edge_update_dynamic (
659663 self ,
660664 node_ebd : torch .Tensor ,
661665 node_ebd_ext : torch .Tensor ,
@@ -770,7 +774,7 @@ def forward(
770774 node_ebd , _ = torch .split (node_ebd_ext , [nloc , nall - nloc ], dim = 1 )
771775 n_edge = nlist_mask .sum ().item ()
772776 assert (nb , nloc ) == node_ebd .shape [:2 ]
773- if not self .no_sel :
777+ if not self .use_dynamic_sel :
774778 assert (nb , nloc , nnei , 3 ) == h2 .shape
775779 else :
776780 assert (n_edge , 3 ) == h2 .shape
@@ -786,7 +790,7 @@ def forward(
786790 # nb x nloc x nnei x n_dim [OR] n_edge x n_dim
787791 nei_node_ebd = (
788792 _make_nei_g1 (node_ebd_ext , nlist )
789- if not self .no_sel
793+ if not self .use_dynamic_sel
790794 else torch .index_select (
791795 node_ebd_ext .reshape (- 1 , self .n_dim ), 0 , n_ext2e_index
792796 )
@@ -810,15 +814,15 @@ def forward(
810814 sw ,
811815 self .axis_neuron ,
812816 )
813- if not self .no_sel
814- else self .symmetrization_op_nosel (
817+ if not self .use_dynamic_sel
818+ else self .symmetrization_op_dynamic (
815819 edge_ebd ,
816820 h2 ,
817821 sw ,
818822 owner = n2e_index ,
819823 num_owner = nb * nloc ,
820824 nloc = nloc ,
821- scale_factor = 1.0 / ( float ( nnei ) ** 0.5 ),
825+ scale_factor = self . dynamic_e_sel ** ( - 0.5 ),
822826 axis_neuron = self .axis_neuron ,
823827 )
824828 )
@@ -830,23 +834,23 @@ def forward(
830834 sw ,
831835 self .axis_neuron ,
832836 )
833- if not self .no_sel
834- else self .symmetrization_op_nosel (
837+ if not self .use_dynamic_sel
838+ else self .symmetrization_op_dynamic (
835839 nei_node_ebd ,
836840 h2 ,
837841 sw ,
838842 owner = n2e_index ,
839843 num_owner = nb * nloc ,
840844 nloc = nloc ,
841- scale_factor = 1.0 / ( float ( nnei ) ** 0.5 ),
845+ scale_factor = self . dynamic_e_sel ** ( - 0.5 ),
842846 axis_neuron = self .axis_neuron ,
843847 )
844848 )
845849 node_sym = self .act (self .node_sym_linear (torch .cat (node_sym_list , dim = - 1 )))
846850 n_update_list .append (node_sym )
847851
848852 if not self .optim_update :
849- if not self .no_sel :
853+ if not self .use_dynamic_sel :
850854 # nb x nloc x nnei x (n_dim * 2 + e_dim)
851855 edge_info = torch .cat (
852856 [
@@ -885,8 +889,8 @@ def forward(
885889 nlist ,
886890 "node" ,
887891 )
888- if not self .no_sel
889- else self .optim_edge_update_nosel (
892+ if not self .use_dynamic_sel
893+ else self .optim_edge_update_dynamic (
890894 node_ebd ,
891895 node_ebd_ext ,
892896 edge_ebd ,
@@ -897,15 +901,15 @@ def forward(
897901 )
898902 node_edge_update = (
899903 (torch .sum (node_edge_update * sw .unsqueeze (- 1 ), dim = - 2 ) / self .nnei )
900- if not self .no_sel
904+ if not self .use_dynamic_sel
901905 else (
902906 aggregate (
903907 node_edge_update * sw .unsqueeze (- 1 ),
904908 n2e_index ,
905909 average = False ,
906910 num_owner = nb * nloc ,
907911 ).reshape (nb , nloc , - 1 )
908- / self .nnei
912+ / self .dynamic_e_sel
909913 )
910914 )
911915
@@ -934,8 +938,8 @@ def forward(
934938 nlist ,
935939 "edge" ,
936940 )
937- if not self .no_sel
938- else self .optim_edge_update_nosel (
941+ if not self .use_dynamic_sel
942+ else self .optim_edge_update_dynamic (
939943 node_ebd ,
940944 node_ebd_ext ,
941945 edge_ebd ,
@@ -965,7 +969,7 @@ def forward(
965969 node_ebd_for_angle = node_ebd
966970 edge_ebd_for_angle = edge_ebd
967971
968- if not self .no_sel :
972+ if not self .use_dynamic_sel :
969973 # nb x nloc x a_nnei x e_dim
970974 edge_ebd_for_angle = edge_ebd_for_angle [:, :, : self .a_sel , :]
971975 # nb x nloc x a_nnei x e_dim
@@ -979,7 +983,7 @@ def forward(
979983 node_ebd_for_angle .unsqueeze (2 ).unsqueeze (2 ),
980984 (1 , 1 , self .a_sel , self .a_sel , 1 ),
981985 )
982- if not self .no_sel
986+ if not self .use_dynamic_sel
983987 else torch .index_select (
984988 node_ebd_for_angle .reshape (- 1 , self .n_a_compress_dim ),
985989 0 ,
@@ -992,15 +996,15 @@ def forward(
992996 torch .tile (
993997 edge_ebd_for_angle .unsqueeze (2 ), (1 , 1 , self .a_sel , 1 , 1 )
994998 )
995- if not self .no_sel
999+ if not self .use_dynamic_sel
9961000 else torch .index_select (edge_ebd_for_angle , 0 , eik2a_index )
9971001 )
9981002 # nb x nloc x a_nnei x (a_nnei) x e_dim [OR] n_angle x e_dim
9991003 edge_for_angle_j = (
10001004 torch .tile (
10011005 edge_ebd_for_angle .unsqueeze (3 ), (1 , 1 , 1 , self .a_sel , 1 )
10021006 )
1003- if not self .no_sel
1007+ if not self .use_dynamic_sel
10041008 else torch .index_select (edge_ebd_for_angle , 0 , eij2a_index )
10051009 )
10061010 # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) [OR] n_angle x (e_dim + e_dim)
@@ -1030,8 +1034,8 @@ def forward(
10301034 edge_ebd_for_angle ,
10311035 "edge" ,
10321036 )
1033- if not self .no_sel
1034- else self .optim_angle_update_nosel (
1037+ if not self .use_dynamic_sel
1038+ else self .optim_angle_update_dynamic (
10351039 angle_ebd ,
10361040 node_ebd_for_angle ,
10371041 edge_ebd_for_angle ,
@@ -1042,7 +1046,7 @@ def forward(
10421046 )
10431047 )
10441048
1045- if not self .no_sel :
1049+ if not self .use_dynamic_sel :
10461050 # nb x nloc x a_nnei x a_nnei x e_dim
10471051 weighted_edge_angle_update = (
10481052 edge_angle_update
@@ -1074,13 +1078,13 @@ def forward(
10741078 eij2a_index ,
10751079 average = False ,
10761080 num_owner = n_edge ,
1077- ) / (self .a_sel ** 0.5 )
1081+ ) / (self .dynamic_a_sel ** 0.5 )
10781082 if not self .smooth_edge_update :
10791083 # will be deprecated in the future
10801084 # not support dynamic index, will pass anyway
1081- if self .no_sel :
1085+ if self .use_dynamic_sel :
10821086 raise NotImplementedError (
1083- "smooth_edge_update must be True when using dynamic_sel !"
1087+ "smooth_edge_update must be True when use_dynamic_sel is True !"
10841088 )
10851089 full_mask = torch .concat (
10861090 [
@@ -1115,8 +1119,8 @@ def forward(
11151119 edge_ebd_for_angle ,
11161120 "angle" ,
11171121 )
1118- if not self .no_sel
1119- else self .optim_angle_update_nosel (
1122+ if not self .use_dynamic_sel
1123+ else self .optim_angle_update_dynamic (
11201124 angle_ebd ,
11211125 node_ebd_for_angle ,
11221126 edge_ebd_for_angle ,
0 commit comments