Skip to content

Commit 09564f3

Browse files
committed
extract nei_node_ebd
1 parent b543cc8 commit 09564f3

1 file changed

Lines changed: 5 additions & 10 deletions

File tree

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -435,9 +435,8 @@ def optim_angle_update(
435435
def optim_edge_update(
436436
self,
437437
node_ebd: torch.Tensor,
438-
node_ebd_ext: torch.Tensor,
438+
nei_node_ebd: torch.Tensor,
439439
edge_ebd: torch.Tensor,
440-
nlist: torch.Tensor,
441440
feat: str = "node",
442441
) -> torch.Tensor:
443442
if feat == "node":
@@ -455,10 +454,8 @@ def optim_edge_update(
455454

456455
# nf * nloc * node/edge_dim
457456
sub_node_update = torch.matmul(node_ebd, node)
458-
# nf * nloc * nnei * node/edge_dim
459-
gathered_node_ebd_ext = _make_nei_g1(node_ebd_ext, nlist)
460-
# nf * nloc * nnei * node/edge_dim
461-
sub_node_ext_update = torch.matmul(gathered_node_ebd_ext, node_ext)
457+
# nf * nloc * node/edge_dim
458+
sub_node_ext_update = torch.matmul(nei_node_ebd, node_ext)
462459
# nf * nloc * nnei * node/edge_dim
463460
sub_edge_update = torch.matmul(edge_ebd, edge)
464461

@@ -577,9 +574,8 @@ def forward(
577574
node_edge_update = self.act(
578575
self.optim_edge_update(
579576
node_ebd,
580-
node_ebd_ext,
577+
nei_node_ebd,
581578
edge_ebd,
582-
nlist,
583579
"node",
584580
)
585581
) * sw.unsqueeze(-1)
@@ -605,9 +601,8 @@ def forward(
605601
edge_self_update = self.act(
606602
self.optim_edge_update(
607603
node_ebd,
608-
node_ebd_ext,
604+
nei_node_ebd,
609605
edge_ebd,
610-
nlist,
611606
"edge",
612607
)
613608
)

0 commit comments

Comments
 (0)