@@ -88,6 +88,7 @@ def __init__(
8888 use_gated_mlp : bool = False ,
8989 gated_mlp_norm : str = "none" ,
9090 node_use_rmsnorm : bool = False ,
91+ angle_use_node : bool = True ,
9192 activation_function : str = "silu" ,
9293 update_style : str = "res_residual" ,
9394 update_residual : float = 0.1 ,
@@ -184,6 +185,8 @@ def __init__(
184185 else :
185186 self .node_rmsnorm = None
186187
188+ self .angle_use_node = angle_use_node
189+
187190 if self .edge_rbf_dot_self or self .edge_rbf_dot_message :
188191 self .rbf_mlp = MLPLayer (
189192 rbf_dim ,
@@ -380,16 +383,22 @@ def __init__(
380383 self .angle_dim = self .a_dim
381384 if self .a_compress_rate == 0 :
382385 # angle + node + edge * 2
383- self .angle_dim += self .n_dim + 2 * self .e_dim
386+ self .angle_dim += (
387+ self .n_dim + 2 * self .e_dim
388+ if self .angle_use_node
389+ else 2 * self .e_dim
390+ )
384391 self .a_compress_n_linear = None
385392 self .a_compress_e_linear = None
386393 self .e_a_compress_dim = e_dim
387394 self .n_a_compress_dim = n_dim
388395 else :
389396 # angle + a_dim/c + a_dim/2c * 2 * e_rate
390- self .angle_dim += (1 + self .a_compress_e_rate ) * (
391- self .a_dim // self .a_compress_rate
392- )
397+ self .angle_dim += (
398+ (1 + self .a_compress_e_rate )
399+ if self .angle_use_node
400+ else self .a_compress_e_rate
401+ ) * (self .a_dim // self .a_compress_rate )
393402 self .e_a_compress_dim = (
394403 self .a_dim // (2 * self .a_compress_rate ) * self .a_compress_e_rate
395404 )
@@ -1383,20 +1392,6 @@ def forward(
13831392 a_nlist_mask .unsqueeze (- 1 ), edge_ebd_for_angle , 0.0
13841393 )
13851394 if not self .optim_update :
1386- # nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim
1387- node_for_angle_info = (
1388- torch .tile (
1389- node_ebd_for_angle .unsqueeze (2 ).unsqueeze (2 ),
1390- (1 , 1 , self .a_sel , self .a_sel , 1 ),
1391- )
1392- if not self .use_dynamic_sel
1393- else torch .index_select (
1394- node_ebd_for_angle .reshape (- 1 , self .n_a_compress_dim ),
1395- 0 ,
1396- n2a_index ,
1397- )
1398- )
1399-
14001395 # nb x nloc x (a_nnei) x a_nnei x e_dim [OR] n_angle x e_dim
14011396 edge_for_angle_k = (
14021397 torch .tile (
@@ -1418,7 +1413,21 @@ def forward(
14181413 [edge_for_angle_k , edge_for_angle_j ], dim = - 1
14191414 )
14201415 angle_info_list = [angle_ebd ]
1421- angle_info_list .append (node_for_angle_info )
1416+ if self .angle_use_node :
1417+ # nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim
1418+ node_for_angle_info = (
1419+ torch .tile (
1420+ node_ebd_for_angle .unsqueeze (2 ).unsqueeze (2 ),
1421+ (1 , 1 , self .a_sel , self .a_sel , 1 ),
1422+ )
1423+ if not self .use_dynamic_sel
1424+ else torch .index_select (
1425+ node_ebd_for_angle .reshape (- 1 , self .n_a_compress_dim ),
1426+ 0 ,
1427+ n2a_index ,
1428+ )
1429+ )
1430+ angle_info_list .append (node_for_angle_info )
14221431 angle_info_list .append (edge_for_angle_info )
14231432 # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c)
14241433 # [OR]
0 commit comments