Skip to content

Commit 563b2b2

Browse files
committed
feat(pt): add only_angle_gated_mlp
1 parent 6f5e063 commit 563b2b2

5 files changed

Lines changed: 18 additions & 4 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def __init__(
6868
force_embedding_on_edge: bool = False,
6969
use_gated_mlp: bool = False,
7070
gated_mlp_norm: str = "none",
71+
only_angle_gated_mlp: bool = False,
7172
use_res_gnn: bool = False,
7273
res_gnn_layer: int = 6,
7374
node_use_rmsnorm: bool = False,
@@ -199,6 +200,7 @@ def __init__(
199200
self.rk_order = rk_order
200201
self.rk_update_diff_layer = rk_update_diff_layer
201202
self.angle_use_node = angle_use_node
203+
self.only_angle_gated_mlp = only_angle_gated_mlp
202204

203205
def __getitem__(self, key):
204206
if hasattr(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def init_subclass_params(sub_data, sub_class):
205205
force_embedding_on_edge=self.repflow_args.force_embedding_on_edge,
206206
use_gated_mlp=self.repflow_args.use_gated_mlp,
207207
gated_mlp_norm=self.repflow_args.gated_mlp_norm,
208+
only_angle_gated_mlp=self.repflow_args.only_angle_gated_mlp,
208209
node_use_rmsnorm=self.repflow_args.node_use_rmsnorm,
209210
use_res_gnn=self.repflow_args.use_res_gnn,
210211
res_gnn_layer=self.repflow_args.res_gnn_layer,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(
8787
use_slim_message: bool = False,
8888
use_gated_mlp: bool = False,
8989
gated_mlp_norm: str = "none",
90+
only_angle_gated_mlp: bool = False,
9091
node_use_rmsnorm: bool = False,
9192
angle_use_node: bool = True,
9293
activation_function: str = "silu",
@@ -177,6 +178,7 @@ def __init__(
177178

178179
self.use_gated_mlp = use_gated_mlp
179180
self.gated_mlp_norm = gated_mlp_norm
181+
self.only_angle_gated_mlp = only_angle_gated_mlp
180182
if self.use_gated_mlp:
181183
assert not self.optim_update, "Gated MLP does not support optim update!"
182184
self.node_use_rmsnorm = node_use_rmsnorm
@@ -265,7 +267,7 @@ def __init__(
265267
residual_idx += 1
266268

267269
# node edge message
268-
if not self.use_gated_mlp:
270+
if not self.use_gated_mlp or self.only_angle_gated_mlp:
269271
self.node_edge_linear = MLPLayer(
270272
self.edge_info_dim
271273
if not self.use_ffn_node_edge_message
@@ -309,7 +311,7 @@ def __init__(
309311
residual_idx += 1
310312

311313
# edge self message
312-
if not self.use_gated_mlp:
314+
if not self.use_gated_mlp or self.only_angle_gated_mlp:
313315
self.edge_self_linear = MLPLayer(
314316
self.edge_info_dim
315317
if not self.use_ffn_edge_edge_message
@@ -1248,7 +1250,7 @@ def forward(
12481250
if not self.optim_update:
12491251
assert edge_info is not None
12501252
if not self.use_ffn_node_edge_message:
1251-
if not self.use_gated_mlp:
1253+
if not self.use_gated_mlp or self.only_angle_gated_mlp:
12521254
node_edge_update = self.act(
12531255
self.node_edge_linear(edge_info)
12541256
) * sw.unsqueeze(-1)
@@ -1320,7 +1322,7 @@ def forward(
13201322
if not self.optim_update:
13211323
assert edge_info is not None
13221324
if not self.use_ffn_edge_edge_message:
1323-
if not self.use_gated_mlp:
1325+
if not self.use_gated_mlp or self.only_angle_gated_mlp:
13241326
edge_self_update = self.act(self.edge_self_linear(edge_info))
13251327
else:
13261328
edge_self_update = self.edge_self_linear(edge_info)

deepmd/pt/model/descriptor/repflows.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def __init__(
148148
force_embedding_on_edge: bool = False,
149149
use_gated_mlp: bool = False,
150150
gated_mlp_norm: str = "none",
151+
only_angle_gated_mlp: bool = False,
151152
use_res_gnn: bool = False,
152153
res_gnn_layer: int = 6,
153154
node_use_rmsnorm: bool = False,
@@ -356,6 +357,7 @@ def __init__(
356357
assert self.rk_order == 4, "rk_order must be 4 for now"
357358
self.use_gated_mlp = use_gated_mlp
358359
self.gated_mlp_norm = gated_mlp_norm
360+
self.only_angle_gated_mlp = only_angle_gated_mlp
359361
self.use_res_gnn = use_res_gnn
360362
self.res_gnn_layer = res_gnn_layer
361363
if self.use_res_gnn:
@@ -496,6 +498,7 @@ def __init__(
496498
use_slim_message=self.use_slim_message,
497499
use_gated_mlp=self.use_gated_mlp,
498500
gated_mlp_norm=self.gated_mlp_norm,
501+
only_angle_gated_mlp=self.only_angle_gated_mlp,
499502
node_use_rmsnorm=self.node_use_rmsnorm,
500503
angle_use_node=self.angle_use_node,
501504
seed=child_seed(child_seed(seed, 1), ii),

deepmd/utils/argcheck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1821,6 +1821,12 @@ def dpa3_repflow_args():
18211821
optional=True,
18221822
default="none",
18231823
),
1824+
Argument(
1825+
"only_angle_gated_mlp",
1826+
bool,
1827+
optional=True,
1828+
default=False,
1829+
),
18241830
Argument(
18251831
"use_res_gnn",
18261832
bool,

0 commit comments

Comments
 (0)