@@ -97,6 +97,7 @@ def __init__(
9797 edge_message_use_dropout : bool = False ,
9898 angle_message_use_dropout : bool = False ,
9999 EN_use_NGA : bool = False ,
100+ AE_use_NGA : bool = False ,
100101 dropout_rate : float = 0.1 ,
101102 activation_function : str = "silu" ,
102103 update_style : str = "res_residual" ,
@@ -148,7 +149,8 @@ def __init__(
148149 self .dynamic_e_sel = self .nnei / self .sel_reduce_factor
149150 self .dynamic_a_sel = self .a_sel / self .sel_reduce_factor
150151 self .EN_use_NGA = EN_use_NGA
151- if self .EN_use_NGA :
152+ self .AE_use_NGA = AE_use_NGA
153+ if self .EN_use_NGA or self .AE_use_NGA :
152154 assert (
153155 not self .use_dynamic_sel and not self .optim_update
154156 ), "NGA does not support dynamic selection or optim update!"
@@ -575,6 +577,42 @@ def __init__(
575577 )
576578 residual_idx += 1
577579
580+ # edge angle NGA
581+ if self .AE_use_NGA :
582+ self .AE_angle_nga_mlp = MLPLayer (
583+ a_dim ,
584+ e_dim ,
585+ precision = precision ,
586+ seed = child_seed (seed , 24 ),
587+ )
588+ self .AE_edge_nga_mlp = MLPLayer (
589+ 2 * self .e_a_compress_dim ,
590+ e_dim ,
591+ precision = precision ,
592+ seed = child_seed (seed , 25 ),
593+ )
594+ self .AE_edge_nga_mlp_out = MLPLayer (
595+ e_dim ,
596+ e_dim ,
597+ precision = precision ,
598+ seed = child_seed (seed , 26 ),
599+ )
600+ if self .update_style == "res_residual" :
601+ self .e_residual .append (
602+ get_residual (
603+ e_dim ,
604+ self .update_residual * self .residual_pref [residual_idx ],
605+ self .update_residual_init ,
606+ precision = precision ,
607+ seed = child_seed (seed , 27 ),
608+ )
609+ )
610+ residual_idx += 1
611+ else :
612+ self .AE_angle_nga_mlp = None
613+ self .AE_edge_nga_mlp = None
614+ self .AE_edge_nga_mlp_out = None
615+
578616 # angle self message
579617 if not self .use_gated_mlp :
580618 self .angle_self_linear = MLPLayer (
@@ -1455,6 +1493,7 @@ def forward(
14551493 )
14561494 n_update_list .append (node_edge_update )
14571495
1496+ # node edge nga
14581497 if self .EN_use_NGA :
14591498 assert self .node_nga_mlp is not None
14601499 assert self .edge_nga_mlp is not None
@@ -1613,6 +1652,7 @@ def forward(
16131652 else :
16141653 angle_info = None
16151654 angle_info_ffn = None
1655+ angle_info_list = None
16161656
16171657 # angle message use dropout
16181658 if self .angle_message_use_dropout :
@@ -1725,6 +1765,44 @@ def forward(
17251765 padding_edge_angle_update = self .EAM_rmsnorm (padding_edge_angle_update )
17261766
17271767 e_update_list .append (padding_edge_angle_update )
1768+
1769+ # edge angle NGA
1770+ if self .AE_use_NGA :
1771+ assert self .AE_angle_nga_mlp is not None
1772+ assert self .AE_edge_nga_mlp is not None
1773+ assert self .AE_edge_nga_mlp_out is not None
1774+ assert angle_info_list is not None
1775+
1776+ # nb, nloc, a_nnei, a_nnei, e_dim
1777+ attention_weights_nga_i = self .AE_angle_nga_mlp (angle_info_list [0 ])
1778+ attention_weights_nga_i = (
1779+ attention_weights_nga_i + 20.0
1780+ ) * a_sw .unsqueeze (- 1 ).unsqueeze (- 1 ) - 20.0
1781+ attention_weights_nga_i = torch .softmax (attention_weights_nga_i , dim = - 2 )
1782+ # nb, nloc, a_nnei, a_nnei, e_dim
1783+ attention_value_nga = self .act (self .AE_edge_nga_mlp (angle_info_list [2 ]))
1784+
1785+ # updated value
1786+ # nb, nloc, a_nnei, e_dim
1787+ reduce_edge_nga = (attention_weights_nga_i * attention_value_nga ).sum (
1788+ - 2
1789+ )
1790+ # nb x nloc x nnei x e_dim
1791+ padding_edge_nga = torch .concat (
1792+ [
1793+ reduce_edge_nga ,
1794+ torch .zeros (
1795+ [nb , nloc , self .nnei - self .a_sel , self .e_dim ],
1796+ dtype = edge_ebd .dtype ,
1797+ device = edge_ebd .device ,
1798+ ),
1799+ ],
1800+ dim = 2 ,
1801+ )
1802+ # nb x nloc x nnei x e_dim
1803+ update_edge_nga = self .act (self .AE_edge_nga_mlp_out (padding_edge_nga ))
1804+ e_update_list .append (update_edge_nga )
1805+
17281806 # update edge_ebd
17291807 e_updated = self .list_update (e_update_list , "edge" )
17301808
0 commit comments