Skip to content

Commit e333b8c

Browse files
committed
feat: add dot rbf
1 parent e165765 commit e333b8c

5 files changed

Lines changed: 87 additions & 4 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def __init__(
5454
edge_attn_hidden: int = 32,
5555
edge_attn_head: int = 4,
5656
edge_attn_use_ln: bool = True,
57+
edge_rbf_dot_self: bool = False,
58+
edge_rbf_dot_message: bool = False,
5759
) -> None:
5860
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
5961
@@ -157,6 +159,8 @@ def __init__(
157159
self.edge_attn_hidden = edge_attn_hidden
158160
self.edge_attn_head = edge_attn_head
159161
self.edge_attn_use_ln = edge_attn_use_ln
162+
self.edge_rbf_dot_self = edge_rbf_dot_self
163+
self.edge_rbf_dot_message = edge_rbf_dot_message
160164

161165
def __getitem__(self, key):
162166
if hasattr(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ def init_subclass_params(sub_data, sub_class):
189189
edge_attn_hidden=self.repflow_args.edge_attn_hidden,
190190
edge_attn_head=self.repflow_args.edge_attn_head,
191191
edge_attn_use_ln=self.repflow_args.edge_attn_use_ln,
192+
edge_rbf_dot_self=self.repflow_args.edge_rbf_dot_self,
193+
edge_rbf_dot_message=self.repflow_args.edge_rbf_dot_message,
192194
exclude_types=exclude_types,
193195
env_protection=env_protection,
194196
precision=precision,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def __init__(
7777
edge_attn_hidden: int = 32,
7878
edge_attn_head: int = 4,
7979
edge_attn_use_ln: bool = True,
80+
edge_rbf_dot_self: bool = False,
81+
edge_rbf_dot_message: bool = False,
82+
rbf_dim: int = 8,
8083
activation_function: str = "silu",
8184
update_style: str = "res_residual",
8285
update_residual: float = 0.1,
@@ -152,6 +155,30 @@ def __init__(
152155
self.edge_attn_hidden = edge_attn_hidden
153156
self.edge_attn_head = edge_attn_head
154157
self.edge_attn_use_ln = edge_attn_use_ln
158+
self.edge_rbf_dot_self = edge_rbf_dot_self
159+
self.edge_rbf_dot_message = edge_rbf_dot_message
160+
self.rbf_dim = rbf_dim
161+
162+
if self.edge_rbf_dot_self or self.edge_rbf_dot_message:
163+
self.rbf_mlp = MLPLayer(
164+
rbf_dim,
165+
self.e_dim,
166+
precision=precision,
167+
seed=child_seed(seed, 30),
168+
)
169+
else:
170+
self.rbf_mlp = None
171+
172+
if self.edge_rbf_dot_message:
173+
self.rbf_mlp_message = MLPLayer(
174+
rbf_dim,
175+
self.n_dim,
176+
precision=precision,
177+
seed=child_seed(seed, 31),
178+
)
179+
else:
180+
self.rbf_mlp_message = None
181+
155182
if self.edge_use_attn:
156183
assert (
157184
not self.use_dynamic_sel
@@ -889,6 +916,7 @@ def forward(
889916
dihedral_index: Optional[torch.Tensor] = None, # n_dihedral x 2
890917
dihedral_ebd: Optional[torch.Tensor] = None, # n_dihedral x d_dim
891918
d_sw: Optional[torch.Tensor] = None, # n_dihedral
919+
rbf_ebd: Optional[torch.Tensor] = None, # n_edge x num_b
892920
):
893921
"""
894922
Parameters
@@ -962,6 +990,25 @@ def forward(
962990
)
963991
)
964992

993+
# handle edge rbf
994+
if self.edge_rbf_dot_self or self.edge_rbf_dot_message:
995+
assert rbf_ebd is not None
996+
assert self.rbf_mlp is not None
997+
edge_rbf = self.rbf_mlp(rbf_ebd)
998+
else:
999+
edge_rbf = None
1000+
1001+
if self.edge_rbf_dot_message:
1002+
assert rbf_ebd is not None
1003+
assert self.rbf_mlp_message is not None
1004+
edge_rbf_node = self.rbf_mlp_message(rbf_ebd)
1005+
else:
1006+
edge_rbf_node = None
1007+
1008+
if self.edge_rbf_dot_self:
1009+
assert edge_rbf is not None
1010+
edge_ebd = edge_ebd * edge_rbf
1011+
9651012
n_update_list: list[torch.Tensor] = [node_ebd]
9661013
e_update_list: list[torch.Tensor] = [edge_ebd]
9671014
a_update_list: list[torch.Tensor] = [angle_ebd]
@@ -1079,6 +1126,9 @@ def forward(
10791126
"node",
10801127
)
10811128
) * sw.unsqueeze(-1)
1129+
if self.edge_rbf_dot_message:
1130+
assert edge_rbf_node is not None
1131+
node_edge_update = node_edge_update * edge_rbf_node
10821132
node_edge_update = (
10831133
(torch.sum(node_edge_update, dim=-2) / self.nnei)
10841134
if not self.use_dynamic_sel
@@ -1132,6 +1182,9 @@ def forward(
11321182
"edge",
11331183
)
11341184
)
1185+
if self.edge_rbf_dot_message:
1186+
assert edge_rbf is not None
1187+
edge_self_update = edge_self_update * edge_rbf
11351188
e_update_list.append(edge_self_update)
11361189

11371190
# edge attention message

deepmd/pt/model/descriptor/repflows.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ def __init__(
130130
edge_attn_hidden: int = 32,
131131
edge_attn_head: int = 4,
132132
edge_attn_use_ln: bool = True,
133+
edge_rbf_dot_self: bool = False,
134+
edge_rbf_dot_message: bool = False,
133135
optim_update: bool = True,
134136
seed: Optional[Union[int, list[int]]] = None,
135137
) -> None:
@@ -275,6 +277,10 @@ def __init__(
275277
self.edge_attn_hidden = edge_attn_hidden
276278
self.edge_attn_head = edge_attn_head
277279
self.edge_attn_use_ln = edge_attn_use_ln
280+
self.edge_rbf_dot_self = edge_rbf_dot_self
281+
self.edge_rbf_dot_message = edge_rbf_dot_message
282+
if self.edge_rbf_dot_self or self.edge_rbf_dot_message:
283+
assert self.edge_use_rbf or self.edge_use_concat_rbf, "rbf is not used"
278284
self.edge_embed_input_dim = 1
279285
if self.edge_use_concat_rbf:
280286
self.rbf = BesselBasis(self.e_rcut)
@@ -370,6 +376,9 @@ def __init__(
370376
edge_attn_hidden=self.edge_attn_hidden,
371377
edge_attn_head=self.edge_attn_head,
372378
edge_attn_use_ln=self.edge_attn_use_ln,
379+
edge_rbf_dot_self=self.edge_rbf_dot_self,
380+
edge_rbf_dot_message=self.edge_rbf_dot_message,
381+
rbf_dim=self.edge_embed_input_dim,
373382
seed=child_seed(child_seed(seed, 1), ii),
374383
)
375384
)
@@ -687,15 +696,17 @@ def forward(
687696
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
688697
if self.edge_use_dist:
689698
edge_ebd = self.edge_embd(edge_input)
699+
rbf_ebd = None
690700
elif self.edge_use_concat_rbf:
691701
assert self.rbf is not None
692-
edge_ebd = self.edge_embd(
693-
torch.cat([dmatrix[..., :1], self.rbf(edge_input)], dim=-1)
694-
)
702+
rbf_ebd = torch.cat([dmatrix[..., :1], self.rbf(edge_input)], dim=-1)
703+
edge_ebd = self.edge_embd(rbf_ebd)
695704
elif self.edge_use_rbf:
696705
assert self.rbf is not None
697-
edge_ebd = self.edge_embd(self.rbf(edge_input))
706+
rbf_ebd = self.rbf(edge_input)
707+
edge_ebd = self.edge_embd(rbf_ebd)
698708
else:
709+
rbf_ebd = None
699710
edge_ebd = self.act(self.edge_embd(edge_input))
700711

701712
# nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim
@@ -794,6 +805,7 @@ def forward(
794805
dihedral_index=dihedral_index,
795806
dihedral_ebd=dihedral_ebd,
796807
d_sw=d_sw,
808+
rbf_ebd=rbf_ebd,
797809
)
798810

799811
# nb x nloc x 3 x e_dim

deepmd/utils/argcheck.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1724,6 +1724,18 @@ def dpa3_repflow_args():
17241724
optional=True,
17251725
default=True,
17261726
),
1727+
Argument(
1728+
"edge_rbf_dot_self",
1729+
bool,
1730+
optional=True,
1731+
default=False,
1732+
),
1733+
Argument(
1734+
"edge_rbf_dot_message",
1735+
bool,
1736+
optional=True,
1737+
default=False,
1738+
),
17271739
]
17281740

17291741

0 commit comments

Comments
 (0)