Skip to content

Commit b97321d

Browse files
iProzdcaic99
authored andcommitted
feat(pt): add use_loc_mapping
1 parent e9f5058 commit b97321d

7 files changed

Lines changed: 615 additions & 8 deletions

File tree

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,48 @@ def init_subclass_params(sub_data, sub_class):
150150
fix_stat_std=self.repflow_args.fix_stat_std,
151151
optim_update=self.repflow_args.optim_update,
152152
smooth_edge_update=self.repflow_args.smooth_edge_update,
153+
angle_multi_freq=self.repflow_args.angle_multi_freq,
154+
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
155+
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
156+
use_env_envelope=self.repflow_args.use_env_envelope,
157+
use_new_sw=self.repflow_args.use_new_sw,
158+
update_dihedral=self.repflow_args.update_dihedral,
159+
d_dim=self.repflow_args.d_dim,
160+
d_sel=self.repflow_args.d_sel,
161+
d_rcut=self.repflow_args.d_rcut,
162+
d_rcut_smth=self.repflow_args.d_rcut_smth,
163+
use_ffn_node_edge_message=self.repflow_args.use_ffn_node_edge_message,
164+
use_ffn_edge_edge_message=self.repflow_args.use_ffn_edge_edge_message,
165+
use_ffn_edge_angle_message=self.repflow_args.use_ffn_edge_angle_message,
166+
use_ffn_angle_angle_message=self.repflow_args.use_ffn_angle_angle_message,
167+
ffn_hidden_dim=self.repflow_args.ffn_hidden_dim,
168+
edge_use_concat_rbf=self.repflow_args.edge_use_concat_rbf,
169+
edge_use_rbf=self.repflow_args.edge_use_rbf,
170+
edge_use_dist=self.repflow_args.edge_use_dist,
171+
embed_use_bias=self.repflow_args.embed_use_bias,
172+
edge_use_attn=self.repflow_args.edge_use_attn,
173+
edge_attn_hidden=self.repflow_args.edge_attn_hidden,
174+
edge_attn_head=self.repflow_args.edge_attn_head,
175+
edge_attn_use_ln=self.repflow_args.edge_attn_use_ln,
176+
edge_rbf_dot_self=self.repflow_args.edge_rbf_dot_self,
177+
edge_rbf_dot_message=self.repflow_args.edge_rbf_dot_message,
178+
edge_use_esen_rbf=self.repflow_args.edge_use_esen_rbf,
179+
edge_use_esen_atom_ebd=self.repflow_args.edge_use_esen_atom_ebd,
180+
edge_use_esen_env=self.repflow_args.edge_use_esen_env,
181+
residual_pref=self.repflow_args.residual_pref,
182+
tebd_use_act=self.repflow_args.tebd_use_act,
183+
message_use_self_concat=self.repflow_args.message_use_self_concat,
184+
use_slim_message=self.repflow_args.use_slim_message,
185+
use_combined_output=self.repflow_args.use_combined_output,
186+
use_loc_mapping=use_loc_mapping,
153187
exclude_types=exclude_types,
154188
env_protection=env_protection,
155189
precision=precision,
156190
seed=child_seed(seed, 1),
157191
)
158192

159193
self.use_econf_tebd = use_econf_tebd
194+
self.use_loc_mapping = use_loc_mapping
160195
self.use_tebd_bias = use_tebd_bias
161196
self.type_map = type_map
162197
self.tebd_dim = self.repflow_args.n_dim
@@ -466,12 +501,16 @@ def forward(
466501
The smooth switch function. shape: nf x nloc x nnei
467502
468503
"""
504+
parrallel_mode = comm_dict is not None
469505
# cast the input to internal precsion
470506
extended_coord = extended_coord.to(dtype=self.prec)
471507
nframes, nloc, nnei = nlist.shape
472508
nall = extended_coord.view(nframes, -1).shape[1] // 3
473509

474-
node_ebd_ext = self.type_embedding(extended_atype)
510+
if not parrallel_mode and self.use_loc_mapping:
511+
node_ebd_ext = self.type_embedding(extended_atype[:, :nloc])
512+
else:
513+
node_ebd_ext = self.type_embedding(extended_atype)
475514
node_ebd_inp = node_ebd_ext[:, :nloc, :]
476515
# repflows
477516
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def optim_edge_update(
469469

470470
def forward(
471471
self,
472-
node_ebd_ext: torch.Tensor, # nf x nall x n_dim
472+
node_ebd_ext: torch.Tensor, # nf x nall x n_dim [OR] nf x nloc x n_dim when not parrallel_mode
473473
edge_ebd: torch.Tensor, # nf x nloc x nnei x e_dim
474474
h2: torch.Tensor, # nf x nloc x nnei x 3
475475
angle_ebd: torch.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim

deepmd/pt/model/descriptor/repflows.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -377,9 +377,9 @@ def forward(
377377
mapping: Optional[torch.Tensor] = None,
378378
comm_dict: Optional[dict[str, torch.Tensor]] = None,
379379
):
380-
if comm_dict is None:
380+
parrallel_mode = comm_dict is not None
381+
if not parrallel_mode:
381382
assert mapping is not None
382-
assert extended_atype_embd is not None
383383
nframes, nloc, nnei = nlist.shape
384384
nall = extended_coord.view(nframes, -1).shape[1] // 3
385385
atype = extended_atype[:, :nloc]
@@ -459,18 +459,23 @@ def forward(
459459
# if the a neighbor is real or not is indicated by nlist_mask
460460
nlist[nlist == -1] = 0
461461
# nb x nall x n_dim
462-
if comm_dict is None:
462+
if not parrallel_mode:
463463
assert mapping is not None
464464
mapping = (
465465
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.n_dim)
466466
)
467467
for idx, ll in enumerate(self.layers):
468468
# node_ebd: nb x nloc x n_dim
469-
# node_ebd_ext: nb x nall x n_dim
470-
if comm_dict is None:
469+
# node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parrallel_mode
470+
if not parrallel_mode:
471471
assert mapping is not None
472-
node_ebd_ext = torch.gather(node_ebd, 1, mapping)
472+
node_ebd_ext = (
473+
torch.gather(node_ebd, 1, mapping)
474+
if not self.use_loc_mapping
475+
else node_ebd
476+
)
473477
else:
478+
assert comm_dict is not None
474479
has_spin = "has_spin" in comm_dict
475480
if not has_spin:
476481
n_padding = nall - nloc

0 commit comments

Comments
 (0)