@@ -87,6 +87,7 @@ def __init__(
8787 use_slim_message : bool = False ,
8888 use_gated_mlp : bool = False ,
8989 gated_mlp_norm : str = "none" ,
90+ only_angle_gated_mlp : bool = False ,
9091 node_use_rmsnorm : bool = False ,
9192 angle_use_node : bool = True ,
9293 activation_function : str = "silu" ,
@@ -177,6 +178,7 @@ def __init__(
177178
178179 self .use_gated_mlp = use_gated_mlp
179180 self .gated_mlp_norm = gated_mlp_norm
181+ self .only_angle_gated_mlp = only_angle_gated_mlp
180182 if self .use_gated_mlp :
181183 assert not self .optim_update , "Gated MLP does not support optim update!"
182184 self .node_use_rmsnorm = node_use_rmsnorm
@@ -265,7 +267,7 @@ def __init__(
265267 residual_idx += 1
266268
267269 # node edge message
268- if not self .use_gated_mlp :
270+ if not self .use_gated_mlp or self . only_angle_gated_mlp :
269271 self .node_edge_linear = MLPLayer (
270272 self .edge_info_dim
271273 if not self .use_ffn_node_edge_message
@@ -309,7 +311,7 @@ def __init__(
309311 residual_idx += 1
310312
311313 # edge self message
312- if not self .use_gated_mlp :
314+ if not self .use_gated_mlp or self . only_angle_gated_mlp :
313315 self .edge_self_linear = MLPLayer (
314316 self .edge_info_dim
315317 if not self .use_ffn_edge_edge_message
@@ -1248,7 +1250,7 @@ def forward(
12481250 if not self .optim_update :
12491251 assert edge_info is not None
12501252 if not self .use_ffn_node_edge_message :
1251- if not self .use_gated_mlp :
1253+ if not self .use_gated_mlp or self . only_angle_gated_mlp :
12521254 node_edge_update = self .act (
12531255 self .node_edge_linear (edge_info )
12541256 ) * sw .unsqueeze (- 1 )
@@ -1320,7 +1322,7 @@ def forward(
13201322 if not self .optim_update :
13211323 assert edge_info is not None
13221324 if not self .use_ffn_edge_edge_message :
1323- if not self .use_gated_mlp :
1325+ if not self .use_gated_mlp or self . only_angle_gated_mlp :
13241326 edge_self_update = self .act (self .edge_self_linear (edge_info ))
13251327 else :
13261328 edge_self_update = self .edge_self_linear (edge_info )
0 commit comments