Skip to content

Commit 879e734

Browse files
committed
pt(feat): add rbf concat mess
1 parent 6ba2399 commit 879e734

5 files changed

Lines changed: 63 additions & 23 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(
5757
edge_attn_use_ln: bool = True,
5858
edge_rbf_dot_self: bool = False,
5959
edge_rbf_dot_message: bool = False,
60+
edge_rbf_cat_message: bool = False,
6061
edge_use_esen_rbf: bool = False,
6162
edge_use_esen_atom_ebd: bool = False,
6263
edge_use_esen_env: bool = False,
@@ -208,6 +209,7 @@ def __init__(
208209
self.angle_self_attention = angle_self_attention
209210
self.angle_self_attention_gate = angle_self_attention_gate
210211
self.rmsnorm_mode = rmsnorm_mode
212+
self.edge_rbf_cat_message = edge_rbf_cat_message
211213
assert (
212214
fix_stat_std == 0.3
213215
), "fix_stat_std is not implemented in this version, please use skip_stat instead."

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def init_subclass_params(sub_data, sub_class):
217217
angle_self_attention=self.repflow_args.angle_self_attention,
218218
angle_self_attention_gate=self.repflow_args.angle_self_attention_gate,
219219
rmsnorm_mode=self.repflow_args.rmsnorm_mode,
220+
edge_rbf_cat_message=self.repflow_args.edge_rbf_cat_message,
220221
exclude_types=exclude_types,
221222
env_protection=env_protection,
222223
precision=precision,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __init__(
9393
angle_self_attention: bool = False,
9494
angle_self_attention_gate: str = "none",
9595
rmsnorm_mode: str = "none",
96+
edge_rbf_cat_message: bool = False,
9697
activation_function: str = "silu",
9798
update_style: str = "res_residual",
9899
update_residual: float = 0.1,
@@ -168,6 +169,7 @@ def __init__(
168169
self.edge_attn_use_ln = edge_attn_use_ln
169170
self.edge_rbf_dot_self = edge_rbf_dot_self
170171
self.edge_rbf_dot_message = edge_rbf_dot_message
172+
self.edge_rbf_cat_message = edge_rbf_cat_message
171173
self.rbf_dim = rbf_dim
172174
self.residual_pref = residual_pref
173175
self.residual_pref += [1.0] * 10
@@ -228,7 +230,11 @@ def __init__(
228230
self.angle_self_attention = angle_self_attention
229231
self.angle_self_attention_gate = angle_self_attention_gate
230232

231-
if self.edge_rbf_dot_self or self.edge_rbf_dot_message:
233+
if (
234+
self.edge_rbf_dot_self
235+
or self.edge_rbf_dot_message
236+
or self.edge_rbf_cat_message
237+
):
232238
self.rbf_mlp = MLPLayer(
233239
rbf_dim,
234240
self.e_dim,
@@ -264,7 +270,11 @@ def __init__(
264270
self.e_residual = []
265271
self.a_residual = []
266272
self.d_residual = []
267-
self.edge_info_dim = self.n_dim * 2 + self.e_dim
273+
self.edge_info_dim = (
274+
self.n_dim * 2 + self.e_dim
275+
if not self.edge_rbf_cat_message
276+
else self.n_dim * 2 + self.e_dim * 2
277+
)
268278

269279
# node self mlp
270280
self.node_self_mlp = MLPLayer(
@@ -1198,7 +1208,11 @@ def forward(
11981208
)
11991209

12001210
# handle edge rbf
1201-
if self.edge_rbf_dot_self or self.edge_rbf_dot_message:
1211+
if (
1212+
self.edge_rbf_dot_self
1213+
or self.edge_rbf_dot_message
1214+
or self.edge_rbf_cat_message
1215+
):
12021216
assert rbf_ebd is not None
12031217
assert self.rbf_mlp is not None
12041218
edge_rbf = self.rbf_mlp(rbf_ebd)
@@ -1272,26 +1286,26 @@ def forward(
12721286
if not self.optim_update:
12731287
if not self.use_dynamic_sel:
12741288
# nb x nloc x nnei x (n_dim * 2 + e_dim)
1275-
edge_info = torch.cat(
1276-
[
1277-
torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]),
1278-
nei_node_ebd,
1279-
edge_ebd,
1280-
],
1281-
dim=-1,
1282-
)
1289+
edge_cat_list = [
1290+
torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]),
1291+
nei_node_ebd,
1292+
edge_ebd,
1293+
]
1294+
if self.edge_rbf_cat_message:
1295+
assert edge_rbf is not None
1296+
edge_cat_list += [edge_rbf]
1297+
edge_info = torch.cat(edge_cat_list, dim=-1)
12831298
else:
12841299
# n_edge x (n_dim * 2 + e_dim)
1285-
edge_info = torch.cat(
1286-
[
1287-
torch.index_select(
1288-
node_ebd.reshape(-1, self.n_dim), 0, n2e_index
1289-
),
1290-
nei_node_ebd,
1291-
edge_ebd,
1292-
],
1293-
dim=-1,
1294-
)
1300+
edge_cat_list = [
1301+
torch.index_select(node_ebd.reshape(-1, self.n_dim), 0, n2e_index),
1302+
nei_node_ebd,
1303+
edge_ebd,
1304+
]
1305+
if self.edge_rbf_cat_message:
1306+
assert edge_rbf is not None
1307+
edge_cat_list += [edge_rbf]
1308+
edge_info = torch.cat(edge_cat_list, dim=-1)
12951309
if self.use_ffn_node_edge_message or self.use_ffn_edge_edge_message:
12961310
assert self.edge_message_ffn1 is not None
12971311
edge_info_ffn = self.act(self.edge_message_ffn1(edge_info))

deepmd/pt/model/descriptor/repflows.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def __init__(
161161
angle_self_attention: bool = False,
162162
angle_self_attention_gate: str = "none",
163163
rmsnorm_mode: str = "none",
164+
edge_rbf_cat_message: bool = False,
164165
seed: Optional[Union[int, list[int]]] = None,
165166
) -> None:
166167
r"""
@@ -314,6 +315,7 @@ def __init__(
314315
self.edge_use_esen_rbf = edge_use_esen_rbf
315316
self.edge_use_esen_atom_ebd = edge_use_esen_atom_ebd
316317
self.edge_use_esen_env = edge_use_esen_env
318+
self.edge_rbf_cat_message = edge_rbf_cat_message
317319
if self.edge_rbf_dot_self or self.edge_rbf_dot_message:
318320
assert self.edge_use_rbf or self.edge_use_concat_rbf, "rbf is not used"
319321
self.edge_embed_input_dim = 1
@@ -333,6 +335,9 @@ def __init__(
333335
elif self.edge_use_rbf:
334336
self.rbf = BesselBasis(self.e_rcut)
335337
self.edge_embed_input_dim = self.rbf.num_basis
338+
elif self.edge_rbf_cat_message:
339+
# edge can use dist itself
340+
self.rbf = BesselBasis(self.e_rcut)
336341
else:
337342
self.rbf = None
338343

@@ -379,6 +384,11 @@ def __init__(
379384
not self.optim_update
380385
), "optim_update must be False when angle_use_node is False"
381386

387+
if self.edge_rbf_cat_message:
388+
assert (
389+
not self.optim_update
390+
), "optim_update must be False when edge_rbf_cat_message is True"
391+
382392
if self.edge_use_esen_atom_ebd:
383393
self.source_embedding = torch.nn.Embedding(self.ntypes, self.e_dim)
384394
self.target_embedding = torch.nn.Embedding(self.ntypes, self.e_dim)
@@ -508,7 +518,9 @@ def __init__(
508518
edge_attn_use_ln=self.edge_attn_use_ln,
509519
edge_rbf_dot_self=self.edge_rbf_dot_self,
510520
edge_rbf_dot_message=self.edge_rbf_dot_message,
511-
rbf_dim=self.edge_embed_input_dim,
521+
rbf_dim=self.edge_embed_input_dim
522+
if not self.edge_rbf_cat_message
523+
else self.rbf.num_basis,
512524
residual_pref=self.residual_pref,
513525
message_use_self_concat=self.message_use_self_concat,
514526
use_slim_message=self.use_slim_message,
@@ -520,6 +532,7 @@ def __init__(
520532
angle_self_attention=self.angle_self_attention,
521533
angle_self_attention_gate=self.angle_self_attention_gate,
522534
rmsnorm_mode=self.rmsnorm_mode,
535+
edge_rbf_cat_message=self.edge_rbf_cat_message,
523536
seed=child_seed(child_seed(seed, 1), ii),
524537
)
525538
)
@@ -934,7 +947,11 @@ def forward(
934947
edge_ebd = self.edge_embd(rbf_input)
935948
elif self.edge_use_dist:
936949
edge_ebd = self.edge_embd(edge_input)
937-
rbf_ebd = None
950+
if not self.edge_rbf_cat_message:
951+
rbf_ebd = None
952+
else:
953+
assert self.rbf is not None
954+
rbf_ebd = self.rbf(edge_input)
938955
elif self.edge_use_concat_rbf:
939956
assert self.rbf is not None
940957
rbf_ebd = torch.cat([dmatrix[..., :1], self.rbf(edge_input)], dim=-1)

deepmd/utils/argcheck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,6 +1749,12 @@ def dpa3_repflow_args():
17491749
optional=True,
17501750
default=False,
17511751
),
1752+
Argument(
1753+
"edge_rbf_cat_message",
1754+
bool,
1755+
optional=True,
1756+
default=False,
1757+
),
17521758
Argument(
17531759
"edge_use_esen_rbf",
17541760
bool,

0 commit comments

Comments
 (0)