@@ -96,6 +96,7 @@ def __init__(
9696 edge_rbf_cat_message : bool = False ,
9797 edge_message_use_dropout : bool = False ,
9898 angle_message_use_dropout : bool = False ,
99+ EN_use_NGA : bool = False ,
99100 dropout_rate : float = 0.1 ,
100101 activation_function : str = "silu" ,
101102 update_style : str = "res_residual" ,
@@ -146,6 +147,9 @@ def __init__(
146147 self .sel_reduce_factor = sel_reduce_factor
147148 self .dynamic_e_sel = self .nnei / self .sel_reduce_factor
148149 self .dynamic_a_sel = self .a_sel / self .sel_reduce_factor
150+ self .EN_use_NGA = EN_use_NGA
151+ if self .EN_use_NGA :
152+ assert not self .use_dynamic_sel , "NGA does not support dynamic selection!"
149153
150154 self .update_dihedral = update_dihedral
151155 self .d_dim = d_dim
@@ -375,6 +379,35 @@ def __init__(
375379 )
376380 residual_idx += 1
377381
382+ # node edge NGA
383+ if self .EN_use_NGA :
384+ self .edge_nga_mlp = MLPLayer (
385+ e_dim ,
386+ n_dim ,
387+ precision = precision ,
388+ seed = child_seed (seed , 20 ),
389+ )
390+ self .node_nga_mlp = MLPLayer (
391+ 2 * n_dim ,
392+ n_dim ,
393+ precision = precision ,
394+ seed = child_seed (seed , 21 ),
395+ )
396+ if self .update_style == "res_residual" :
397+ self .n_residual .append (
398+ get_residual (
399+ n_dim ,
400+ self .update_residual * self .residual_pref [residual_idx ],
401+ self .update_residual_init ,
402+ precision = precision ,
403+ seed = child_seed (seed , 22 ),
404+ )
405+ )
406+ residual_idx += 1
407+ else :
408+ self .edge_nga_mlp = None
409+ self .node_nga_mlp = None
410+
378411 # edge self message
379412 if not self .use_gated_mlp or self .only_angle_gated_mlp :
380413 self .edge_self_linear = MLPLayer (
@@ -1330,6 +1363,7 @@ def forward(
13301363 else :
13311364 edge_info = None
13321365 edge_info_ffn = None
1366+ edge_cat_list = None
13331367
13341368 # edge message use dropout
13351369 if self .edge_message_use_dropout :
@@ -1411,6 +1445,27 @@ def forward(
14111445 )
14121446 )
14131447 n_update_list .append (node_edge_update )
1448+
1449+ if self .EN_use_NGA :
1450+ assert self .node_nga_mlp is not None
1451+ assert self .edge_nga_mlp is not None
1452+ assert edge_cat_list is not None
1453+ # nb, nloc, nnei, n_dim
1454+ attention_weights_nga_i = self .edge_nga_mlp (edge_cat_list [2 ])
1455+ attention_weights_nga_i = (attention_weights_nga_i + 20.0 ) * sw .unsqueeze (
1456+ - 1
1457+ ) - 20.0
1458+ attention_weights_nga_i = torch .softmax (attention_weights_nga_i , dim = - 2 )
1459+ # nb, nloc, nnei, n_dim
1460+ attention_value_nga = self .node_nga_mlp (
1461+ torch .cat (edge_cat_list [:2 ], dim = - 1 )
1462+ )
1463+
1464+ # updated value
1465+ # nb, nloc, n_dim
1466+ update_node_nga = (attention_weights_nga_i * attention_value_nga ).sum (- 2 )
1467+ n_update_list .append (update_node_nga )
1468+
14141469 # update node_ebd
14151470 n_updated = self .list_update (n_update_list , "node" )
14161471
0 commit comments