Skip to content

Commit e208538

Browse files
iProzdcaic99
authored andcommitted
feat(pt): add use_loc_mapping
1 parent 75b175b commit e208538

7 files changed

Lines changed: 304 additions & 16 deletions

File tree

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,15 @@ def init_subclass_params(sub_data, sub_class):
153153
use_exp_switch=self.repflow_args.use_exp_switch,
154154
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
155155
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
156+
use_loc_mapping=use_loc_mapping,
156157
exclude_types=exclude_types,
157158
env_protection=env_protection,
158159
precision=precision,
159160
seed=child_seed(seed, 1),
160161
)
161162

162163
self.use_econf_tebd = use_econf_tebd
164+
self.use_loc_mapping = use_loc_mapping
163165
self.use_tebd_bias = use_tebd_bias
164166
self.type_map = type_map
165167
self.tebd_dim = self.repflow_args.n_dim
@@ -469,12 +471,16 @@ def forward(
469471
The smooth switch function. shape: nf x nloc x nnei
470472
471473
"""
474+
parrallel_mode = comm_dict is not None
472475
# cast the input to internal precsion
473476
extended_coord = extended_coord.to(dtype=self.prec)
474477
nframes, nloc, nnei = nlist.shape
475478
nall = extended_coord.view(nframes, -1).shape[1] // 3
476479

477-
node_ebd_ext = self.type_embedding(extended_atype)
480+
if not parrallel_mode and self.use_loc_mapping:
481+
node_ebd_ext = self.type_embedding(extended_atype[:, :nloc])
482+
else:
483+
node_ebd_ext = self.type_embedding(extended_atype)
478484
node_ebd_inp = node_ebd_ext[:, :nloc, :]
479485
# repflows
480486
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ def optim_edge_update_dynamic(
684684

685685
def forward(
686686
self,
687-
node_ebd_ext: torch.Tensor, # nf x nall x n_dim
687+
node_ebd_ext: torch.Tensor, # nf x nall x n_dim [OR] nf x nloc x n_dim when not parrallel_mode
688688
edge_ebd: torch.Tensor, # nf x nloc x nnei x e_dim
689689
h2: torch.Tensor, # nf x nloc x nnei x 3
690690
angle_ebd: torch.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim

deepmd/pt/model/descriptor/repflows.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def __init__(
209209
use_exp_switch: bool = False,
210210
use_dynamic_sel: bool = False,
211211
sel_reduce_factor: float = 10.0,
212+
use_loc_mapping: bool = True,
212213
optim_update: bool = True,
213214
seed: Optional[Union[int, list[int]]] = None,
214215
) -> None:
@@ -416,9 +417,9 @@ def forward(
416417
mapping: Optional[torch.Tensor] = None,
417418
comm_dict: Optional[dict[str, torch.Tensor]] = None,
418419
):
419-
if comm_dict is None:
420+
parrallel_mode = comm_dict is not None
421+
if not parrallel_mode:
420422
assert mapping is not None
421-
assert extended_atype_embd is not None
422423
nframes, nloc, nnei = nlist.shape
423424
nall = extended_coord.view(nframes, -1).shape[1] // 3
424425
atype = extended_atype[:, :nloc]
@@ -470,12 +471,9 @@ def forward(
470471

471472
# get node embedding
472473
# [nframes, nloc, tebd_dim]
473-
if comm_dict is None:
474-
assert isinstance(extended_atype_embd, torch.Tensor) # for jit
475-
atype_embd = extended_atype_embd[:, :nloc, :]
476-
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
477-
else:
478-
atype_embd = extended_atype_embd
474+
assert extended_atype_embd is not None
475+
atype_embd = extended_atype_embd[:, :nloc, :]
476+
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
479477
assert isinstance(atype_embd, torch.Tensor) # for jit
480478
node_ebd = self.act(atype_embd)
481479
n_dim = node_ebd.shape[-1]
@@ -494,10 +492,19 @@ def forward(
494492
cosine_ij = torch.matmul(normalized_diff_i, normalized_diff_j) * (1 - 1e-6)
495493
angle_input = cosine_ij.unsqueeze(-1) / (torch.pi**0.5)
496494

495+
if not parrallel_mode and self.use_loc_mapping:
496+
assert mapping is not None
497+
# convert nlist from nall to nloc index
498+
nlist = torch.gather(
499+
mapping,
500+
1,
501+
index=nlist.reshape(nframes, -1),
502+
).reshape(nlist.shape)
497503
if self.use_dynamic_sel:
498504
# get graph index
499505
edge_index, angle_index = get_graph_index(
500-
nlist, nlist_mask, a_nlist_mask, nall
506+
nlist, nlist_mask, a_nlist_mask, nall,
507+
use_loc_mapping=self.use_loc_mapping,
501508
)
502509
# flat all the tensors
503510
# n_edge x 1
@@ -524,18 +531,23 @@ def forward(
524531
angle_ebd = self.angle_embd(angle_input)
525532

526533
# nb x nall x n_dim
527-
if comm_dict is None:
534+
if not parrallel_mode:
528535
assert mapping is not None
529536
mapping = (
530537
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.n_dim)
531538
)
532539
for idx, ll in enumerate(self.layers):
533540
# node_ebd: nb x nloc x n_dim
534-
# node_ebd_ext: nb x nall x n_dim
535-
if comm_dict is None:
541+
# node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parrallel_mode
542+
if not parrallel_mode:
536543
assert mapping is not None
537-
node_ebd_ext = torch.gather(node_ebd, 1, mapping)
544+
node_ebd_ext = (
545+
torch.gather(node_ebd, 1, mapping)
546+
if not self.use_loc_mapping
547+
else node_ebd
548+
)
538549
else:
550+
assert comm_dict is not None
539551
has_spin = "has_spin" in comm_dict
540552
if not has_spin:
541553
n_padding = nall - nloc

deepmd/pt/model/network/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def get_graph_index(
5151
nlist_mask: torch.Tensor,
5252
a_nlist_mask: torch.Tensor,
5353
nall: int,
54+
use_loc_mapping: bool = True,
5455
):
5556
"""
5657
Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`.
@@ -100,7 +101,9 @@ def get_graph_index(
100101
n2e_index = n2e_index[nlist_mask] # graph node index, atom_graph[:, 0]
101102

102103
# node_ext(j) to edge(ij) index_select
103-
frame_shift = torch.arange(0, nf, dtype=nlist.dtype, device=nlist.device) * nall
104+
frame_shift = torch.arange(0, nf, dtype=nlist.dtype, device=nlist.device) * (
105+
nall if not use_loc_mapping else nloc
106+
)
104107
shifted_nlist = nlist + frame_shift[:, None, None]
105108
# n_edge
106109
n_ext2e_index = shifted_nlist[nlist_mask] # graph neighbor index, atom_graph[:, 1]

deepmd/utils/argcheck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,6 +1421,12 @@ def descrpt_dpa3_args():
14211421
default=False,
14221422
doc=doc_use_tebd_bias,
14231423
),
1424+
Argument(
1425+
"use_loc_mapping",
1426+
bool,
1427+
optional=True,
1428+
default=True,
1429+
),
14241430
]
14251431

14261432

0 commit comments

Comments
 (0)