Skip to content

Commit 1888cd2

Browse files
committed
remove node_ebd_ext
1 parent 09564f3 commit 1888cd2

2 files changed

Lines changed: 14 additions & 12 deletions

File tree

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,8 @@ def optim_edge_update(
466466

467467
def forward(
468468
self,
469-
node_ebd_ext: torch.Tensor, # nf x nall x n_dim
469+
node_ebd: torch.Tensor, # nf x nloc x n_dim
470+
node_ebd_ext: Optional[torch.Tensor], # nf x nall x n_dim
470471
edge_ebd: torch.Tensor, # nf x nloc x nnei x e_dim
471472
h2: torch.Tensor, # nf x nloc x nnei x 3
472473
angle_ebd: torch.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim
@@ -511,8 +512,6 @@ def forward(
511512
Updated angle embedding.
512513
"""
513514
nb, nloc, nnei, _ = edge_ebd.shape
514-
nall = node_ebd_ext.shape[1]
515-
node_ebd = node_ebd_ext[:, :nloc, :]
516515
assert (nb, nloc) == node_ebd.shape[:2]
517516
assert (nb, nloc, nnei) == h2.shape[:3]
518517
del a_nlist # may be used in the future
@@ -524,8 +523,10 @@ def forward(
524523
# node self mlp
525524
node_self_mlp = self.act(self.node_self_mlp(node_ebd))
526525
n_update_list.append(node_self_mlp)
527-
528-
nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist)
526+
if node_ebd_ext is not None:
527+
nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist)
528+
else:
529+
nei_node_ebd = _make_nei_g1(node_ebd, nlist)
529530

530531
# node sym (grrg + drrd)
531532
node_sym_list: list[torch.Tensor] = []

deepmd/pt/model/descriptor/repflows.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -461,16 +461,16 @@ def forward(
461461
# nb x nall x n_dim
462462
if comm_dict is None:
463463
assert mapping is not None
464-
mapping = (
465-
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.n_dim)
466-
)
464+
node_ebd_ext = None
465+
nlist = torch.gather(
466+
mapping,
467+
1,
468+
nlist.reshape(nframes, -1),
469+
).reshape(nlist.shape)
467470
for idx, ll in enumerate(self.layers):
468471
# node_ebd: nb x nloc x n_dim
469472
# node_ebd_ext: nb x nall x n_dim
470-
if comm_dict is None:
471-
assert mapping is not None
472-
node_ebd_ext = torch.gather(node_ebd, 1, mapping)
473-
else:
473+
if comm_dict is not None:
474474
has_spin = "has_spin" in comm_dict
475475
if not has_spin:
476476
n_padding = nall - nloc
@@ -528,6 +528,7 @@ def forward(
528528
node_ebd_real_ext, node_ebd_virtual_ext, real_nloc
529529
)
530530
node_ebd, edge_ebd, angle_ebd = ll.forward(
531+
node_ebd,
531532
node_ebd_ext,
532533
edge_ebd,
533534
h2,

0 commit comments

Comments
 (0)