1717 get_residual ,
1818)
1919from deepmd .pt .model .network .mlp import (
20- FeedForward ,
2120 MLPLayer ,
2221)
2322from deepmd .pt .model .network .utils import (
@@ -65,7 +64,9 @@ def __init__(
6564 d_rcut : float = 2.8 ,
6665 d_rcut_smth : float = 2.0 ,
6766 use_ffn_node_edge_message : bool = False ,
67+ use_ffn_edge_edge_message : bool = False ,
6868 use_ffn_edge_angle_message : bool = False ,
69+ use_ffn_angle_angle_message : bool = False ,
6970 ffn_hidden_dim : int = 1024 ,
7071 activation_function : str = "silu" ,
7172 update_style : str = "res_residual" ,
@@ -124,9 +125,16 @@ def __init__(
124125 self .d_rcut_smth = d_rcut_smth
125126 self .dynamic_d_sel = (self .d_sel * 4 ) / self .sel_reduce_factor
126127 self .use_ffn_node_edge_message = use_ffn_node_edge_message
128+ self .use_ffn_edge_edge_message = use_ffn_edge_edge_message
127129 self .use_ffn_edge_angle_message = use_ffn_edge_angle_message
130+ self .use_ffn_angle_angle_message = use_ffn_angle_angle_message
128131 self .ffn_hidden_dim = ffn_hidden_dim
129- if self .use_ffn_node_edge_message or self .use_ffn_edge_angle_message :
132+ if (
133+ self .use_ffn_node_edge_message
134+ or self .use_ffn_edge_edge_message
135+ or self .use_ffn_edge_angle_message
136+ or self .use_ffn_angle_angle_message
137+ ):
130138 assert not self .optim_update , "FFN does not support optim update!"
131139
132140 if self .update_dihedral :
@@ -183,20 +191,13 @@ def __init__(
183191 )
184192
185193 # node edge message
186- self .node_edge_linear = (
187- MLPLayer (
188- self .edge_info_dim ,
189- self .n_multi_edge_message * n_dim ,
190- precision = precision ,
191- seed = child_seed (seed , 4 ),
192- )
194+ self .node_edge_linear = MLPLayer (
195+ self .edge_info_dim
193196 if not self .use_ffn_node_edge_message
194- else FeedForward (
195- self .edge_info_dim ,
196- self .n_multi_edge_message * n_dim ,
197- self .ffn_hidden_dim ,
198- activation_function = self .activation_function ,
199- )
197+ else self .ffn_hidden_dim ,
198+ self .n_multi_edge_message * n_dim ,
199+ precision = precision ,
200+ seed = child_seed (seed , 4 ),
200201 )
201202 if self .update_style == "res_residual" :
202203 for head_index in range (self .n_multi_edge_message ):
@@ -212,7 +213,9 @@ def __init__(
212213
213214 # edge self message
214215 self .edge_self_linear = MLPLayer (
215- self .edge_info_dim ,
216+ self .edge_info_dim
217+ if not self .use_ffn_edge_edge_message
218+ else self .ffn_hidden_dim ,
216219 e_dim ,
217220 precision = precision ,
218221 seed = child_seed (seed , 6 ),
@@ -266,20 +269,13 @@ def __init__(
266269 self .a_compress_e_linear = None
267270
268271 # edge angle message
269- self .edge_angle_linear1 = (
270- MLPLayer (
271- self .angle_dim ,
272- self .e_dim ,
273- precision = precision ,
274- seed = child_seed (seed , 10 ),
275- )
272+ self .edge_angle_linear1 = MLPLayer (
273+ self .angle_dim
276274 if not self .use_ffn_edge_angle_message
277- else FeedForward (
278- self .angle_dim ,
279- self .e_dim ,
280- self .ffn_hidden_dim ,
281- activation_function = self .activation_function ,
282- )
275+ else self .ffn_hidden_dim ,
276+ self .e_dim ,
277+ precision = precision ,
278+ seed = child_seed (seed , 10 ),
283279 )
284280 self .edge_angle_linear2 = MLPLayer (
285281 self .e_dim ,
@@ -300,7 +296,9 @@ def __init__(
300296
301297 # angle self message
302298 self .angle_self_linear = MLPLayer (
303- self .angle_dim ,
299+ self .angle_dim
300+ if not self .use_ffn_angle_angle_message
301+ else self .ffn_hidden_dim ,
304302 self .a_dim ,
305303 precision = precision ,
306304 seed = child_seed (seed , 13 ),
@@ -367,6 +365,28 @@ def __init__(
367365 self .angle_dihedral_linear = None
368366 self .dihedral_self_linear = None
369367
368+ if self .use_ffn_node_edge_message or self .use_ffn_edge_edge_message :
369+ self .edge_message_ffn1 = MLPLayer (
370+ self .edge_info_dim ,
371+ self .ffn_hidden_dim ,
372+ precision = precision ,
373+ bias = False ,
374+ seed = child_seed (seed , 19 ),
375+ )
376+ else :
377+ self .edge_message_ffn1 = None
378+
379+ if self .use_ffn_edge_angle_message or self .use_ffn_angle_angle_message :
380+ self .angle_message_ffn1 = MLPLayer (
381+ self .angle_dim ,
382+ self .ffn_hidden_dim ,
383+ precision = precision ,
384+ bias = False ,
385+ seed = child_seed (seed , 20 ),
386+ )
387+ else :
388+ self .angle_message_ffn1 = None
389+
370390 self .n_residual = nn .ParameterList (self .n_residual )
371391 self .e_residual = nn .ParameterList (self .e_residual )
372392 self .a_residual = nn .ParameterList (self .a_residual )
@@ -963,16 +983,28 @@ def forward(
963983 ],
964984 dim = - 1 ,
965985 )
986+ if self .use_ffn_node_edge_message or self .use_ffn_edge_edge_message :
987+ assert self .edge_message_ffn1 is not None
988+ edge_info_ffn = self .act (self .edge_message_ffn1 (edge_info ))
989+ else :
990+ edge_info_ffn = None
966991 else :
967992 edge_info = None
993+ edge_info_ffn = None
968994
969995 # node edge message
970996 # nb x nloc x nnei x (h * n_dim)
971997 if not self .optim_update :
972998 assert edge_info is not None
973- node_edge_update = self .act (
974- self .node_edge_linear (edge_info )
975- ) * sw .unsqueeze (- 1 )
999+ if not self .use_ffn_node_edge_message :
1000+ node_edge_update = self .act (
1001+ self .node_edge_linear (edge_info )
1002+ ) * sw .unsqueeze (- 1 )
1003+ else :
1004+ assert edge_info_ffn is not None
1005+ node_edge_update = self .act (
1006+ self .node_edge_linear (edge_info_ffn )
1007+ ) * sw .unsqueeze (- 1 )
9761008 else :
9771009 node_edge_update = self .act (
9781010 self .optim_edge_update (
@@ -1021,7 +1053,11 @@ def forward(
10211053 # edge self message
10221054 if not self .optim_update :
10231055 assert edge_info is not None
1024- edge_self_update = self .act (self .edge_self_linear (edge_info ))
1056+ if not self .use_ffn_edge_edge_message :
1057+ edge_self_update = self .act (self .edge_self_linear (edge_info ))
1058+ else :
1059+ assert edge_info_ffn is not None
1060+ edge_self_update = self .act (self .edge_self_linear (edge_info_ffn ))
10251061 else :
10261062 edge_self_update = self .act (
10271063 self .optim_edge_update (
@@ -1111,14 +1147,26 @@ def forward(
11111147 # [OR]
11121148 # n_angle x (a + n_dim + e_dim*2) or (a + a/c + a/c)
11131149 angle_info = torch .cat (angle_info_list , dim = - 1 )
1150+ if self .use_ffn_edge_angle_message or self .use_ffn_angle_angle_message :
1151+ assert self .angle_message_ffn1 is not None
1152+ angle_info_ffn = self .act (self .angle_message_ffn1 (angle_info ))
1153+ else :
1154+ angle_info_ffn = None
11141155 else :
11151156 angle_info = None
1157+ angle_info_ffn = None
11161158
11171159 # edge angle message
11181160 # nb x nloc x a_nnei x a_nnei x e_dim [OR] n_angle x e_dim
11191161 if not self .optim_update :
11201162 assert angle_info is not None
1121- edge_angle_update = self .act (self .edge_angle_linear1 (angle_info ))
1163+ if not self .use_ffn_edge_angle_message :
1164+ edge_angle_update = self .act (self .edge_angle_linear1 (angle_info ))
1165+ else :
1166+ assert angle_info_ffn is not None
1167+ edge_angle_update = self .act (
1168+ self .edge_angle_linear1 (angle_info_ffn )
1169+ )
11221170 else :
11231171 edge_angle_update = self .act (
11241172 self .optim_angle_update (
@@ -1203,7 +1251,11 @@ def forward(
12031251 # nb x nloc x a_nnei x a_nnei x dim_a
12041252 if not self .optim_update :
12051253 assert angle_info is not None
1206- angle_self_update = self .act (self .angle_self_linear (angle_info ))
1254+ if not self .use_ffn_angle_angle_message :
1255+ angle_self_update = self .act (self .angle_self_linear (angle_info ))
1256+ else :
1257+ assert angle_info_ffn is not None
1258+ angle_self_update = self .act (self .angle_self_linear (angle_info_ffn ))
12071259 else :
12081260 angle_self_update = self .act (
12091261 self .optim_angle_update (
0 commit comments