Skip to content

Commit ad23558

Browse files
committed
perf: gather node embedding before matmul
1 parent 43e0288 commit ad23558

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -455,10 +455,10 @@ def optim_edge_update(
455455

456456
# nf * nloc * node/edge_dim
457457
sub_node_update = torch.matmul(node_ebd, node)
458-
# nf * nall * node/edge_dim
459-
sub_node_ext_update = torch.matmul(node_ebd_ext, node_ext)
460458
# nf * nloc * nnei * node/edge_dim
461-
sub_node_ext_update = _make_nei_g1(sub_node_ext_update, nlist)
459+
gathered_node_ebd_ext = _make_nei_g1(node_ebd_ext, nlist)
460+
# nf * nloc * nnei * node/edge_dim
461+
sub_node_ext_update = torch.matmul(gathered_node_ebd_ext, node_ext)
462462
# nf * nloc * nnei * node/edge_dim
463463
sub_edge_update = torch.matmul(edge_ebd, edge)
464464

0 commit comments

Comments
 (0)