@@ -58,6 +58,7 @@ def __init__(
5858 use_dynamic_sel : bool = False ,
5959 sel_reduce_factor : float = 10.0 ,
6060 smooth_edge_update : bool = False ,
61+ update_use_layernorm : bool = False ,
6162 activation_function : str = "silu" ,
6263 update_style : str = "res_residual" ,
6364 update_residual : float = 0.1 ,
@@ -97,6 +98,7 @@ def __init__(
9798 self .update_style = update_style
9899 self .update_residual = update_residual
99100 self .update_residual_init = update_residual_init
101+ self .update_use_layernorm = update_use_layernorm
100102 self .a_compress_e_rate = a_compress_e_rate
101103 self .a_compress_use_split = a_compress_use_split
102104 self .precision = precision
@@ -203,6 +205,17 @@ def __init__(
203205 )
204206 )
205207
208+ if self .update_use_layernorm :
209+ self .node_layernorm = torch .nn .LayerNorm (self .n_dim )
210+ self .edge_layernorm = torch .nn .LayerNorm (self .e_dim )
211+ self .angle_layernorm = (
212+ torch .nn .LayerNorm (self .a_dim ) if self .update_angle else None
213+ )
214+ else :
215+ self .node_layernorm = None
216+ self .edge_layernorm = None
217+ self .angle_layernorm = None
218+
206219 if self .update_angle :
207220 self .angle_dim = self .a_dim
208221 if self .a_compress_rate == 0 :
@@ -1133,6 +1146,14 @@ def forward(
11331146
11341147 # update angle_ebd
11351148 a_updated = self .list_update (a_update_list , "angle" )
1149+ if self .update_use_layernorm :
1150+ assert self .node_layernorm is not None
1151+ n_updated = self .node_layernorm (n_updated )
1152+ assert self .edge_layernorm is not None
1153+ e_updated = self .edge_layernorm (e_updated )
1154+ if self .update_angle :
1155+ assert self .angle_layernorm is not None
1156+ a_updated = self .angle_layernorm (a_updated )
11361157 return n_updated , e_updated , a_updated
11371158
11381159 @torch .jit .export
0 commit comments