Skip to content

Commit ee9c29b

Browse files
committed
split message
1 parent da1dc99 commit ee9c29b

5 files changed

Lines changed: 112 additions & 36 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def __init__(
4242
d_rcut: float = 2.8,
4343
d_rcut_smth: float = 2.0,
4444
use_ffn_node_edge_message: bool = False,
45+
use_ffn_edge_edge_message: bool = False,
4546
use_ffn_edge_angle_message: bool = False,
47+
use_ffn_angle_angle_message: bool = False,
4648
ffn_hidden_dim: int = 1024,
4749
) -> None:
4850
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
@@ -135,7 +137,9 @@ def __init__(
135137
self.d_rcut = d_rcut
136138
self.d_rcut_smth = d_rcut_smth
137139
self.use_ffn_node_edge_message = use_ffn_node_edge_message
140+
self.use_ffn_edge_edge_message = use_ffn_edge_edge_message
138141
self.use_ffn_edge_angle_message = use_ffn_edge_angle_message
142+
self.use_ffn_angle_angle_message = use_ffn_angle_angle_message
139143
self.ffn_hidden_dim = ffn_hidden_dim
140144

141145
def __getitem__(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,9 @@ def init_subclass_params(sub_data, sub_class):
177177
d_rcut=self.repflow_args.d_rcut,
178178
d_rcut_smth=self.repflow_args.d_rcut_smth,
179179
use_ffn_node_edge_message=self.repflow_args.use_ffn_node_edge_message,
180+
use_ffn_edge_edge_message=self.repflow_args.use_ffn_edge_edge_message,
180181
use_ffn_edge_angle_message=self.repflow_args.use_ffn_edge_angle_message,
182+
use_ffn_angle_angle_message=self.repflow_args.use_ffn_angle_angle_message,
181183
ffn_hidden_dim=self.repflow_args.ffn_hidden_dim,
182184
exclude_types=exclude_types,
183185
env_protection=env_protection,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 88 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
get_residual,
1818
)
1919
from deepmd.pt.model.network.mlp import (
20-
FeedForward,
2120
MLPLayer,
2221
)
2322
from deepmd.pt.model.network.utils import (
@@ -65,7 +64,9 @@ def __init__(
6564
d_rcut: float = 2.8,
6665
d_rcut_smth: float = 2.0,
6766
use_ffn_node_edge_message: bool = False,
67+
use_ffn_edge_edge_message: bool = False,
6868
use_ffn_edge_angle_message: bool = False,
69+
use_ffn_angle_angle_message: bool = False,
6970
ffn_hidden_dim: int = 1024,
7071
activation_function: str = "silu",
7172
update_style: str = "res_residual",
@@ -124,9 +125,16 @@ def __init__(
124125
self.d_rcut_smth = d_rcut_smth
125126
self.dynamic_d_sel = (self.d_sel * 4) / self.sel_reduce_factor
126127
self.use_ffn_node_edge_message = use_ffn_node_edge_message
128+
self.use_ffn_edge_edge_message = use_ffn_edge_edge_message
127129
self.use_ffn_edge_angle_message = use_ffn_edge_angle_message
130+
self.use_ffn_angle_angle_message = use_ffn_angle_angle_message
128131
self.ffn_hidden_dim = ffn_hidden_dim
129-
if self.use_ffn_node_edge_message or self.use_ffn_edge_angle_message:
132+
if (
133+
self.use_ffn_node_edge_message
134+
or self.use_ffn_edge_edge_message
135+
or self.use_ffn_edge_angle_message
136+
or self.use_ffn_angle_angle_message
137+
):
130138
assert not self.optim_update, "FFN does not support optim update!"
131139

132140
if self.update_dihedral:
@@ -183,20 +191,13 @@ def __init__(
183191
)
184192

185193
# node edge message
186-
self.node_edge_linear = (
187-
MLPLayer(
188-
self.edge_info_dim,
189-
self.n_multi_edge_message * n_dim,
190-
precision=precision,
191-
seed=child_seed(seed, 4),
192-
)
194+
self.node_edge_linear = MLPLayer(
195+
self.edge_info_dim
193196
if not self.use_ffn_node_edge_message
194-
else FeedForward(
195-
self.edge_info_dim,
196-
self.n_multi_edge_message * n_dim,
197-
self.ffn_hidden_dim,
198-
activation_function=self.activation_function,
199-
)
197+
else self.ffn_hidden_dim,
198+
self.n_multi_edge_message * n_dim,
199+
precision=precision,
200+
seed=child_seed(seed, 4),
200201
)
201202
if self.update_style == "res_residual":
202203
for head_index in range(self.n_multi_edge_message):
@@ -212,7 +213,9 @@ def __init__(
212213

213214
# edge self message
214215
self.edge_self_linear = MLPLayer(
215-
self.edge_info_dim,
216+
self.edge_info_dim
217+
if not self.use_ffn_edge_edge_message
218+
else self.ffn_hidden_dim,
216219
e_dim,
217220
precision=precision,
218221
seed=child_seed(seed, 6),
@@ -266,20 +269,13 @@ def __init__(
266269
self.a_compress_e_linear = None
267270

268271
# edge angle message
269-
self.edge_angle_linear1 = (
270-
MLPLayer(
271-
self.angle_dim,
272-
self.e_dim,
273-
precision=precision,
274-
seed=child_seed(seed, 10),
275-
)
272+
self.edge_angle_linear1 = MLPLayer(
273+
self.angle_dim
276274
if not self.use_ffn_edge_angle_message
277-
else FeedForward(
278-
self.angle_dim,
279-
self.e_dim,
280-
self.ffn_hidden_dim,
281-
activation_function=self.activation_function,
282-
)
275+
else self.ffn_hidden_dim,
276+
self.e_dim,
277+
precision=precision,
278+
seed=child_seed(seed, 10),
283279
)
284280
self.edge_angle_linear2 = MLPLayer(
285281
self.e_dim,
@@ -300,7 +296,9 @@ def __init__(
300296

301297
# angle self message
302298
self.angle_self_linear = MLPLayer(
303-
self.angle_dim,
299+
self.angle_dim
300+
if not self.use_ffn_angle_angle_message
301+
else self.ffn_hidden_dim,
304302
self.a_dim,
305303
precision=precision,
306304
seed=child_seed(seed, 13),
@@ -367,6 +365,28 @@ def __init__(
367365
self.angle_dihedral_linear = None
368366
self.dihedral_self_linear = None
369367

368+
if self.use_ffn_node_edge_message or self.use_ffn_edge_edge_message:
369+
self.edge_message_ffn1 = MLPLayer(
370+
self.edge_info_dim,
371+
self.ffn_hidden_dim,
372+
precision=precision,
373+
bias=False,
374+
seed=child_seed(seed, 19),
375+
)
376+
else:
377+
self.edge_message_ffn1 = None
378+
379+
if self.use_ffn_edge_angle_message or self.use_ffn_angle_angle_message:
380+
self.angle_message_ffn1 = MLPLayer(
381+
self.angle_dim,
382+
self.ffn_hidden_dim,
383+
precision=precision,
384+
bias=False,
385+
seed=child_seed(seed, 20),
386+
)
387+
else:
388+
self.angle_message_ffn1 = None
389+
370390
self.n_residual = nn.ParameterList(self.n_residual)
371391
self.e_residual = nn.ParameterList(self.e_residual)
372392
self.a_residual = nn.ParameterList(self.a_residual)
@@ -963,16 +983,28 @@ def forward(
963983
],
964984
dim=-1,
965985
)
986+
if self.use_ffn_node_edge_message or self.use_ffn_edge_edge_message:
987+
assert self.edge_message_ffn1 is not None
988+
edge_info_ffn = self.act(self.edge_message_ffn1(edge_info))
989+
else:
990+
edge_info_ffn = None
966991
else:
967992
edge_info = None
993+
edge_info_ffn = None
968994

969995
# node edge message
970996
# nb x nloc x nnei x (h * n_dim)
971997
if not self.optim_update:
972998
assert edge_info is not None
973-
node_edge_update = self.act(
974-
self.node_edge_linear(edge_info)
975-
) * sw.unsqueeze(-1)
999+
if not self.use_ffn_node_edge_message:
1000+
node_edge_update = self.act(
1001+
self.node_edge_linear(edge_info)
1002+
) * sw.unsqueeze(-1)
1003+
else:
1004+
assert edge_info_ffn is not None
1005+
node_edge_update = self.act(
1006+
self.node_edge_linear(edge_info_ffn)
1007+
) * sw.unsqueeze(-1)
9761008
else:
9771009
node_edge_update = self.act(
9781010
self.optim_edge_update(
@@ -1021,7 +1053,11 @@ def forward(
10211053
# edge self message
10221054
if not self.optim_update:
10231055
assert edge_info is not None
1024-
edge_self_update = self.act(self.edge_self_linear(edge_info))
1056+
if not self.use_ffn_edge_edge_message:
1057+
edge_self_update = self.act(self.edge_self_linear(edge_info))
1058+
else:
1059+
assert edge_info_ffn is not None
1060+
edge_self_update = self.act(self.edge_self_linear(edge_info_ffn))
10251061
else:
10261062
edge_self_update = self.act(
10271063
self.optim_edge_update(
@@ -1111,14 +1147,26 @@ def forward(
11111147
# [OR]
11121148
# n_angle x (a + n_dim + e_dim*2) or (a + a/c + a/c)
11131149
angle_info = torch.cat(angle_info_list, dim=-1)
1150+
if self.use_ffn_edge_angle_message or self.use_ffn_angle_angle_message:
1151+
assert self.angle_message_ffn1 is not None
1152+
angle_info_ffn = self.act(self.angle_message_ffn1(angle_info))
1153+
else:
1154+
angle_info_ffn = None
11141155
else:
11151156
angle_info = None
1157+
angle_info_ffn = None
11161158

11171159
# edge angle message
11181160
# nb x nloc x a_nnei x a_nnei x e_dim [OR] n_angle x e_dim
11191161
if not self.optim_update:
11201162
assert angle_info is not None
1121-
edge_angle_update = self.act(self.edge_angle_linear1(angle_info))
1163+
if not self.use_ffn_edge_angle_message:
1164+
edge_angle_update = self.act(self.edge_angle_linear1(angle_info))
1165+
else:
1166+
assert angle_info_ffn is not None
1167+
edge_angle_update = self.act(
1168+
self.edge_angle_linear1(angle_info_ffn)
1169+
)
11221170
else:
11231171
edge_angle_update = self.act(
11241172
self.optim_angle_update(
@@ -1203,7 +1251,11 @@ def forward(
12031251
# nb x nloc x a_nnei x a_nnei x dim_a
12041252
if not self.optim_update:
12051253
assert angle_info is not None
1206-
angle_self_update = self.act(self.angle_self_linear(angle_info))
1254+
if not self.use_ffn_angle_angle_message:
1255+
angle_self_update = self.act(self.angle_self_linear(angle_info))
1256+
else:
1257+
assert angle_info_ffn is not None
1258+
angle_self_update = self.act(self.angle_self_linear(angle_info_ffn))
12071259
else:
12081260
angle_self_update = self.act(
12091261
self.optim_angle_update(

deepmd/pt/model/descriptor/repflows.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def __init__(
117117
d_rcut: float = 2.8,
118118
d_rcut_smth: float = 2.0,
119119
use_ffn_node_edge_message: bool = False,
120+
use_ffn_edge_edge_message: bool = False,
120121
use_ffn_edge_angle_message: bool = False,
122+
use_ffn_angle_angle_message: bool = False,
121123
ffn_hidden_dim: int = 1024,
122124
optim_update: bool = True,
123125
seed: Optional[Union[int, list[int]]] = None,
@@ -252,7 +254,9 @@ def __init__(
252254
self.d_rcut = d_rcut
253255
self.d_rcut_smth = d_rcut_smth
254256
self.use_ffn_node_edge_message = use_ffn_node_edge_message
257+
self.use_ffn_edge_edge_message = use_ffn_edge_edge_message
255258
self.use_ffn_edge_angle_message = use_ffn_edge_angle_message
259+
self.use_ffn_angle_angle_message = use_ffn_angle_angle_message
256260
self.ffn_hidden_dim = ffn_hidden_dim
257261

258262
self.n_dim = n_dim
@@ -328,7 +332,9 @@ def __init__(
328332
d_rcut=self.d_rcut,
329333
d_rcut_smth=self.d_rcut_smth,
330334
use_ffn_node_edge_message=self.use_ffn_node_edge_message,
335+
use_ffn_edge_edge_message=self.use_ffn_edge_edge_message,
331336
use_ffn_edge_angle_message=self.use_ffn_edge_angle_message,
337+
use_ffn_angle_angle_message=self.use_ffn_angle_angle_message,
332338
ffn_hidden_dim=self.ffn_hidden_dim,
333339
seed=child_seed(child_seed(seed, 1), ii),
334340
)

deepmd/utils/argcheck.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1652,12 +1652,24 @@ def dpa3_repflow_args():
16521652
optional=True,
16531653
default=False,
16541654
),
1655+
Argument(
1656+
"use_ffn_edge_edge_message",
1657+
bool,
1658+
optional=True,
1659+
default=False,
1660+
),
16551661
Argument(
16561662
"use_ffn_edge_angle_message",
16571663
bool,
16581664
optional=True,
16591665
default=False,
16601666
),
1667+
Argument(
1668+
"use_ffn_angle_angle_message",
1669+
bool,
1670+
optional=True,
1671+
default=False,
1672+
),
16611673
Argument(
16621674
"ffn_hidden_dim",
16631675
int,

0 commit comments

Comments
 (0)