@@ -92,6 +92,7 @@ def __init__(
9292 angle_use_node : bool = True ,
9393 angle_self_attention : bool = False ,
9494 angle_self_attention_gate : str = "none" ,
95+ rmsnorm_mode : str = "none" ,
9596 activation_function : str = "silu" ,
9697 update_style : str = "res_residual" ,
9798 update_residual : float = 0.1 ,
@@ -189,6 +190,40 @@ def __init__(
189190 else :
190191 self .node_rmsnorm = None
191192
193+ # add rms norm debug for each component, can be removed if not necessary
194+ self .rmsnorm_mode = rmsnorm_mode
195+ self .rmsnorm_mod_list = self .rmsnorm_mode .split (":" )
196+ # mode: ['NEM', 'ESM', 'EAM', 'ASM', 'E']
197+ # node edge message
198+ if "NEM" in self .rmsnorm_mod_list :
199+ self .NEM_rmsnorm = RMSNorm (self .n_dim , precision = precision , trainable = True )
200+ else :
201+ self .NEM_rmsnorm = None
202+
203+ # edge self message
204+ if "ESM" in self .rmsnorm_mod_list :
205+ self .ESM_rmsnorm = RMSNorm (self .e_dim , precision = precision , trainable = True )
206+ else :
207+ self .ESM_rmsnorm = None
208+
209+ # edge angle message
210+ if "EAM" in self .rmsnorm_mod_list :
211+ self .EAM_rmsnorm = RMSNorm (self .e_dim , precision = precision , trainable = True )
212+ else :
213+ self .EAM_rmsnorm = None
214+
215+ # angle self message
216+ if "ASM" in self .rmsnorm_mod_list :
217+ self .ASM_rmsnorm = RMSNorm (self .a_dim , precision = precision , trainable = True )
218+ else :
219+ self .ASM_rmsnorm = None
220+
221+ # edge self
222+ if "E" in self .rmsnorm_mod_list :
223+ self .edge_rmsnorm = RMSNorm (self .e_dim , precision = precision , trainable = True )
224+ else :
225+ self .edge_rmsnorm = None
226+
192227 self .angle_use_node = angle_use_node
193228 self .angle_self_attention = angle_self_attention
194229 self .angle_self_attention_gate = angle_self_attention_gate
@@ -1320,6 +1355,10 @@ def forward(
13201355 )
13211356 )
13221357
1358+ if "NEM" in self .rmsnorm_mod_list :
1359+ assert self .NEM_rmsnorm is not None
1360+ node_edge_update = self .NEM_rmsnorm (node_edge_update )
1361+
13231362 if self .n_multi_edge_message > 1 :
13241363 # nb x nloc x h x n_dim
13251364 node_edge_update_mul_head = node_edge_update .view (
@@ -1372,6 +1411,11 @@ def forward(
13721411 if self .edge_rbf_dot_message :
13731412 assert edge_rbf is not None
13741413 edge_self_update = edge_self_update * edge_rbf
1414+
1415+ if "ESM" in self .rmsnorm_mod_list :
1416+ assert self .ESM_rmsnorm is not None
1417+ edge_self_update = self .ESM_rmsnorm (edge_self_update )
1418+
13751419 e_update_list .append (edge_self_update )
13761420
13771421 # edge attention message
@@ -1561,11 +1605,15 @@ def forward(
15611605 )
15621606 if not self .use_slim_message :
15631607 assert self .edge_angle_linear2 is not None
1564- e_update_list . append (
1565- self .act ( self . edge_angle_linear2 (padding_edge_angle_update ) )
1608+ padding_edge_angle_update = self . act (
1609+ self .edge_angle_linear2 (padding_edge_angle_update )
15661610 )
1567- else :
1568- e_update_list .append (padding_edge_angle_update )
1611+
1612+ if "EAM" in self .rmsnorm_mod_list :
1613+ assert self .EAM_rmsnorm is not None
1614+ padding_edge_angle_update = self .EAM_rmsnorm (padding_edge_angle_update )
1615+
1616+ e_update_list .append (padding_edge_angle_update )
15691617 # update edge_ebd
15701618 e_updated = self .list_update (e_update_list , "edge" )
15711619
@@ -1600,6 +1648,10 @@ def forward(
16001648 "angle" ,
16011649 )
16021650 )
1651+ if "ASM" in self .rmsnorm_mod_list :
1652+ assert self .ASM_rmsnorm is not None
1653+ angle_self_update = self .ASM_rmsnorm (angle_self_update )
1654+
16031655 a_update_list .append (angle_self_update )
16041656
16051657 if self .angle_self_attention :
@@ -1828,6 +1880,10 @@ def list_update_res_residual(
18281880 if update_name == "node" and self .node_use_rmsnorm :
18291881 assert self .node_rmsnorm is not None
18301882 uu = self .node_rmsnorm (uu )
1883+
1884+ if update_name == "edge" and "E" in self .rmsnorm_mod_list :
1885+ assert self .edge_rmsnorm is not None
1886+ uu = self .edge_rmsnorm (uu )
18311887 return uu
18321888
18331889 @torch .jit .export
0 commit comments