Skip to content

Commit 1079027

Browse files
committed
feat(pt): add use_res_gnn
1 parent d70730b commit 1079027

4 files changed

Lines changed: 32 additions & 0 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def __init__(
6868
force_embedding_on_edge: bool = False,
6969
use_gated_mlp: bool = False,
7070
gated_mlp_norm: str = "none",
71+
use_res_gnn: bool = False,
72+
res_gnn_layer: int = 6,
7173
) -> None:
7274
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
7375
@@ -185,6 +187,8 @@ def __init__(
185187
self.force_embedding_on_edge = force_embedding_on_edge
186188
self.use_gated_mlp = use_gated_mlp
187189
self.gated_mlp_norm = gated_mlp_norm
190+
self.use_res_gnn = use_res_gnn
191+
self.res_gnn_layer = res_gnn_layer
188192

189193
def __getitem__(self, key):
190194
if hasattr(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ def init_subclass_params(sub_data, sub_class):
205205
force_embedding_on_edge=self.repflow_args.force_embedding_on_edge,
206206
use_gated_mlp=self.repflow_args.use_gated_mlp,
207207
gated_mlp_norm=self.repflow_args.gated_mlp_norm,
208+
use_res_gnn=self.repflow_args.use_res_gnn,
209+
res_gnn_layer=self.repflow_args.res_gnn_layer,
208210
use_loc_mapping=use_loc_mapping,
209211
exclude_types=exclude_types,
210212
env_protection=env_protection,

deepmd/pt/model/descriptor/repflows.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ def __init__(
148148
force_embedding_on_edge: bool = False,
149149
use_gated_mlp: bool = False,
150150
gated_mlp_norm: str = "none",
151+
use_res_gnn: bool = False,
152+
res_gnn_layer: int = 6,
151153
use_loc_mapping: bool = True,
152154
optim_update: bool = True,
153155
seed: Optional[Union[int, list[int]]] = None,
@@ -337,6 +339,12 @@ def __init__(
337339
self.use_loc_mapping = use_loc_mapping
338340
self.use_gated_mlp = use_gated_mlp
339341
self.gated_mlp_norm = gated_mlp_norm
342+
self.use_res_gnn = use_res_gnn
343+
self.res_gnn_layer = res_gnn_layer
344+
if self.use_res_gnn:
345+
assert (
346+
self.nlayers % self.res_gnn_layer == 0
347+
), "nlayers must be divisible by res_gnn_layer"
340348
assert not (
341349
self.message_use_self_concat and self.use_slim_message
342350
), "only one of message_use_self_concat and use_slim_message can be True"
@@ -953,6 +961,7 @@ def forward(
953961
mapping = (
954962
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.n_dim)
955963
)
964+
res_node_list = []
956965
for idx, ll in enumerate(self.layers):
957966
# node_ebd: nb x nloc x n_dim
958967
# node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parrallel_mode
@@ -1042,6 +1051,11 @@ def forward(
10421051
d_sw=d_sw,
10431052
rbf_ebd=rbf_ebd,
10441053
)
1054+
if self.use_res_gnn and (idx + 1) % self.res_gnn_layer == 0:
1055+
res_node_list.append(node_ebd.unsqueeze(-1))
1056+
1057+
if self.use_res_gnn:
1058+
node_ebd = torch.concat(res_node_list, dim=-1).mean(dim=-1)
10451059

10461060
if self.use_combined_output:
10471061
concat_list = [node_ebd]

deepmd/utils/argcheck.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1821,6 +1821,18 @@ def dpa3_repflow_args():
18211821
optional=True,
18221822
default="none",
18231823
),
1824+
Argument(
1825+
"use_res_gnn",
1826+
bool,
1827+
optional=True,
1828+
default=False,
1829+
),
1830+
Argument(
1831+
"res_gnn_layer",
1832+
int,
1833+
optional=True,
1834+
default=6,
1835+
),
18241836
]
18251837

18261838

0 commit comments

Comments
 (0)