@@ -395,6 +395,12 @@ def __init__(
395395 precision = precision ,
396396 seed = child_seed (seed , 21 ),
397397 )
398+ self .node_nga_mlp_out = MLPLayer (
399+ n_dim ,
400+ n_dim ,
401+ precision = precision ,
402+ seed = child_seed (seed , 22 ),
403+ )
398404 if self .update_style == "res_residual" :
399405 self .n_residual .append (
400406 get_residual (
@@ -409,6 +415,7 @@ def __init__(
409415 else :
410416 self .edge_nga_mlp = None
411417 self .node_nga_mlp = None
418+ self .node_nga_mlp_out = None
412419
413420 # edge self message
414421 if not self .use_gated_mlp or self .only_angle_gated_mlp :
@@ -1451,6 +1458,7 @@ def forward(
14511458 if self .EN_use_NGA :
14521459 assert self .node_nga_mlp is not None
14531460 assert self .edge_nga_mlp is not None
1461+ assert self .node_nga_mlp_out is not None
14541462 assert edge_cat_list is not None
14551463 # nb, nloc, nnei, n_dim
14561464 attention_weights_nga_i = self .edge_nga_mlp (edge_cat_list [2 ])
@@ -1459,13 +1467,17 @@ def forward(
14591467 ) - 20.0
14601468 attention_weights_nga_i = torch .softmax (attention_weights_nga_i , dim = - 2 )
14611469 # nb, nloc, nnei, n_dim
1462- attention_value_nga = self .node_nga_mlp (
1463- torch .cat (edge_cat_list [:2 ], dim = - 1 )
1470+ attention_value_nga = self .act (
1471+ self . node_nga_mlp ( torch .cat (edge_cat_list [:2 ], dim = - 1 ) )
14641472 )
14651473
14661474 # updated value
14671475 # nb, nloc, n_dim
1468- update_node_nga = (attention_weights_nga_i * attention_value_nga ).sum (- 2 )
1476+ update_node_nga = self .act (
1477+ self .node_nga_mlp_out (
1478+ (attention_weights_nga_i * attention_value_nga ).sum (- 2 )
1479+ )
1480+ )
14691481 n_update_list .append (update_node_nga )
14701482
14711483 # update node_ebd
0 commit comments