@@ -77,6 +77,9 @@ def __init__(
7777 edge_attn_hidden : int = 32 ,
7878 edge_attn_head : int = 4 ,
7979 edge_attn_use_ln : bool = True ,
80+ edge_rbf_dot_self : bool = False ,
81+ edge_rbf_dot_message : bool = False ,
82+ rbf_dim : int = 8 ,
8083 activation_function : str = "silu" ,
8184 update_style : str = "res_residual" ,
8285 update_residual : float = 0.1 ,
@@ -152,6 +155,30 @@ def __init__(
152155 self .edge_attn_hidden = edge_attn_hidden
153156 self .edge_attn_head = edge_attn_head
154157 self .edge_attn_use_ln = edge_attn_use_ln
158+ self .edge_rbf_dot_self = edge_rbf_dot_self
159+ self .edge_rbf_dot_message = edge_rbf_dot_message
160+ self .rbf_dim = rbf_dim
161+
162+ if self .edge_rbf_dot_self or self .edge_rbf_dot_message :
163+ self .rbf_mlp = MLPLayer (
164+ rbf_dim ,
165+ self .e_dim ,
166+ precision = precision ,
167+ seed = child_seed (seed , 30 ),
168+ )
169+ else :
170+ self .rbf_mlp = None
171+
172+ if self .edge_rbf_dot_message :
173+ self .rbf_mlp_message = MLPLayer (
174+ rbf_dim ,
175+ self .n_dim ,
176+ precision = precision ,
177+ seed = child_seed (seed , 31 ),
178+ )
179+ else :
180+ self .rbf_mlp_message = None
181+
155182 if self .edge_use_attn :
156183 assert (
157184 not self .use_dynamic_sel
@@ -889,6 +916,7 @@ def forward(
889916 dihedral_index : Optional [torch .Tensor ] = None , # n_dihedral x 2
890917 dihedral_ebd : Optional [torch .Tensor ] = None , # n_dihedral x d_dim
891918 d_sw : Optional [torch .Tensor ] = None , # n_dihedral
919+ rbf_ebd : Optional [torch .Tensor ] = None , # n_edge x num_b
892920 ):
893921 """
894922 Parameters
@@ -962,6 +990,25 @@ def forward(
962990 )
963991 )
964992
993+ # handle edge rbf
994+ if self .edge_rbf_dot_self or self .edge_rbf_dot_message :
995+ assert rbf_ebd is not None
996+ assert self .rbf_mlp is not None
997+ edge_rbf = self .rbf_mlp (rbf_ebd )
998+ else :
999+ edge_rbf = None
1000+
1001+ if self .edge_rbf_dot_message :
1002+ assert rbf_ebd is not None
1003+ assert self .rbf_mlp_message is not None
1004+ edge_rbf_node = self .rbf_mlp_message (rbf_ebd )
1005+ else :
1006+ edge_rbf_node = None
1007+
1008+ if self .edge_rbf_dot_self :
1009+ assert edge_rbf is not None
1010+ edge_ebd = edge_ebd * edge_rbf
1011+
9651012 n_update_list : list [torch .Tensor ] = [node_ebd ]
9661013 e_update_list : list [torch .Tensor ] = [edge_ebd ]
9671014 a_update_list : list [torch .Tensor ] = [angle_ebd ]
@@ -1079,6 +1126,9 @@ def forward(
10791126 "node" ,
10801127 )
10811128 ) * sw .unsqueeze (- 1 )
1129+ if self .edge_rbf_dot_message :
1130+ assert edge_rbf_node is not None
1131+ node_edge_update = node_edge_update * edge_rbf_node
10821132 node_edge_update = (
10831133 (torch .sum (node_edge_update , dim = - 2 ) / self .nnei )
10841134 if not self .use_dynamic_sel
@@ -1132,6 +1182,9 @@ def forward(
11321182 "edge" ,
11331183 )
11341184 )
1185+ if self .edge_rbf_dot_message :
1186+ assert edge_rbf is not None
1187+ edge_self_update = edge_self_update * edge_rbf
11351188 e_update_list .append (edge_self_update )
11361189
11371190 # edge attention message
0 commit comments