Skip to content

Commit 6f5e063

Browse files
committed
feat(pt): add angle_use_node
1 parent 91b9e68 commit 6f5e063

5 files changed

Lines changed: 45 additions & 19 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def __init__(
7474
use_rk_update: bool = False,
7575
rk_order: int = 4,
7676
rk_update_diff_layer: bool = False,
77+
angle_use_node: bool = True,
7778
) -> None:
7879
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
7980
@@ -197,6 +198,7 @@ def __init__(
197198
self.use_rk_update = use_rk_update
198199
self.rk_order = rk_order
199200
self.rk_update_diff_layer = rk_update_diff_layer
201+
self.angle_use_node = angle_use_node
200202

201203
def __getitem__(self, key):
202204
if hasattr(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def init_subclass_params(sub_data, sub_class):
211211
use_rk_update=self.repflow_args.use_rk_update,
212212
rk_order=self.repflow_args.rk_order,
213213
rk_update_diff_layer=self.repflow_args.rk_update_diff_layer,
214+
angle_use_node=self.repflow_args.angle_use_node,
214215
use_loc_mapping=use_loc_mapping,
215216
exclude_types=exclude_types,
216217
env_protection=env_protection,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(
8888
use_gated_mlp: bool = False,
8989
gated_mlp_norm: str = "none",
9090
node_use_rmsnorm: bool = False,
91+
angle_use_node: bool = True,
9192
activation_function: str = "silu",
9293
update_style: str = "res_residual",
9394
update_residual: float = 0.1,
@@ -184,6 +185,8 @@ def __init__(
184185
else:
185186
self.node_rmsnorm = None
186187

188+
self.angle_use_node = angle_use_node
189+
187190
if self.edge_rbf_dot_self or self.edge_rbf_dot_message:
188191
self.rbf_mlp = MLPLayer(
189192
rbf_dim,
@@ -380,16 +383,22 @@ def __init__(
380383
self.angle_dim = self.a_dim
381384
if self.a_compress_rate == 0:
382385
# angle + node + edge * 2
383-
self.angle_dim += self.n_dim + 2 * self.e_dim
386+
self.angle_dim += (
387+
self.n_dim + 2 * self.e_dim
388+
if self.angle_use_node
389+
else 2 * self.e_dim
390+
)
384391
self.a_compress_n_linear = None
385392
self.a_compress_e_linear = None
386393
self.e_a_compress_dim = e_dim
387394
self.n_a_compress_dim = n_dim
388395
else:
389396
# angle + a_dim/c + a_dim/2c * 2 * e_rate
390-
self.angle_dim += (1 + self.a_compress_e_rate) * (
391-
self.a_dim // self.a_compress_rate
392-
)
397+
self.angle_dim += (
398+
(1 + self.a_compress_e_rate)
399+
if self.angle_use_node
400+
else self.a_compress_e_rate
401+
) * (self.a_dim // self.a_compress_rate)
393402
self.e_a_compress_dim = (
394403
self.a_dim // (2 * self.a_compress_rate) * self.a_compress_e_rate
395404
)
@@ -1383,20 +1392,6 @@ def forward(
13831392
a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0
13841393
)
13851394
if not self.optim_update:
1386-
# nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim
1387-
node_for_angle_info = (
1388-
torch.tile(
1389-
node_ebd_for_angle.unsqueeze(2).unsqueeze(2),
1390-
(1, 1, self.a_sel, self.a_sel, 1),
1391-
)
1392-
if not self.use_dynamic_sel
1393-
else torch.index_select(
1394-
node_ebd_for_angle.reshape(-1, self.n_a_compress_dim),
1395-
0,
1396-
n2a_index,
1397-
)
1398-
)
1399-
14001395
# nb x nloc x (a_nnei) x a_nnei x e_dim [OR] n_angle x e_dim
14011396
edge_for_angle_k = (
14021397
torch.tile(
@@ -1418,7 +1413,21 @@ def forward(
14181413
[edge_for_angle_k, edge_for_angle_j], dim=-1
14191414
)
14201415
angle_info_list = [angle_ebd]
1421-
angle_info_list.append(node_for_angle_info)
1416+
if self.angle_use_node:
1417+
# nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim
1418+
node_for_angle_info = (
1419+
torch.tile(
1420+
node_ebd_for_angle.unsqueeze(2).unsqueeze(2),
1421+
(1, 1, self.a_sel, self.a_sel, 1),
1422+
)
1423+
if not self.use_dynamic_sel
1424+
else torch.index_select(
1425+
node_ebd_for_angle.reshape(-1, self.n_a_compress_dim),
1426+
0,
1427+
n2a_index,
1428+
)
1429+
)
1430+
angle_info_list.append(node_for_angle_info)
14221431
angle_info_list.append(edge_for_angle_info)
14231432
# nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c)
14241433
# [OR]

deepmd/pt/model/descriptor/repflows.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def __init__(
155155
use_rk_update: bool = False,
156156
rk_order: int = 4,
157157
rk_update_diff_layer: bool = False,
158+
angle_use_node: bool = True,
158159
optim_update: bool = True,
159160
seed: Optional[Union[int, list[int]]] = None,
160161
) -> None:
@@ -367,6 +368,12 @@ def __init__(
367368

368369
self.node_use_rmsnorm = node_use_rmsnorm
369370

371+
self.angle_use_node = angle_use_node
372+
if not self.angle_use_node:
373+
assert (
374+
not self.optim_update
375+
), "optim_update must be False when angle_use_node is False"
376+
370377
if self.edge_use_esen_atom_ebd:
371378
self.source_embedding = torch.nn.Embedding(self.ntypes, self.e_dim)
372379
self.target_embedding = torch.nn.Embedding(self.ntypes, self.e_dim)
@@ -490,6 +497,7 @@ def __init__(
490497
use_gated_mlp=self.use_gated_mlp,
491498
gated_mlp_norm=self.gated_mlp_norm,
492499
node_use_rmsnorm=self.node_use_rmsnorm,
500+
angle_use_node=self.angle_use_node,
493501
seed=child_seed(child_seed(seed, 1), ii),
494502
)
495503
)

deepmd/utils/argcheck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1857,6 +1857,12 @@ def dpa3_repflow_args():
18571857
optional=True,
18581858
default=False,
18591859
),
1860+
Argument(
1861+
"angle_use_node",
1862+
bool,
1863+
optional=True,
1864+
default=True,
1865+
),
18601866
]
18611867

18621868

0 commit comments

Comments
 (0)