Skip to content

Commit e76ddf2

Browse files
committed
fix(pd): correct tensor reshaping and indexing
1 parent f8f01cb commit e76ddf2

2 files changed

Lines changed: 10 additions & 7 deletions

File tree

deepmd/pd/model/descriptor/repflow_layer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ def optim_angle_update_dynamic(
586586
sub_node_update = paddle.matmul(node_ebd, sub_node)
587587
# n_angle * angle_dim
588588
sub_node_update = paddle.index_select(
589-
sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]), n2a_index, 0
589+
sub_node_update.reshape((nf * nloc, sub_node_update.shape[-1])), n2a_index, 0
590590
)
591591

592592
# n_edge * angle_dim
@@ -666,7 +666,7 @@ def optim_edge_update_dynamic(
666666
sub_node_update = paddle.matmul(node_ebd, node)
667667
# n_edge * node/edge_dim
668668
sub_node_update = paddle.index_select(
669-
sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]),
669+
sub_node_update.reshape((nf * nloc, sub_node_update.shape[-1])),
670670
n2e_index,
671671
0,
672672
)
@@ -675,7 +675,7 @@ def optim_edge_update_dynamic(
675675
sub_node_ext_update = paddle.matmul(node_ebd_ext, node_ext)
676676
# n_edge * node/edge_dim
677677
sub_node_ext_update = paddle.index_select(
678-
sub_node_ext_update.reshape(nf * nall, sub_node_update.shape[-1]),
678+
sub_node_ext_update.reshape((nf * nall, sub_node_update.shape[-1])),
679679
n_ext2e_index,
680680
0,
681681
)
@@ -746,7 +746,7 @@ def forward(
746746
a_updated : nf x nloc x a_nnei x a_nnei x a_dim
747747
Updated angle embedding.
748748
"""
749-
nb, nloc, nnei, _ = edge_ebd.shape
749+
nb, nloc, nnei = nlist.shape
750750
nall = node_ebd_ext.shape[1]
751751
node_ebd = node_ebd_ext[:, :nloc, :]
752752
n_edge = int(nlist_mask.sum().item())
@@ -896,7 +896,7 @@ def forward(
896896
n2e_index,
897897
average=False,
898898
num_owner=nb * nloc,
899-
).reshape(nb, nloc, node_edge_update.shape[-1])
899+
).reshape((nb, nloc, node_edge_update.shape[-1]))
900900
/ self.dynamic_e_sel
901901
)
902902
)

deepmd/pd/model/network/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def aggregate(
3030
output: [num_owner, feature_dim]
3131
"""
3232
bin_count = paddle.bincount(owners)
33-
bin_count = bin_count.where(bin_count != 0, paddle.ones_like(bin_count))
33+
bin_count = paddle.where(bin_count != 0, bin_count, paddle.ones_like(bin_count))
3434

3535
if (num_owner is not None) and (bin_count.shape[0] != num_owner):
3636
difference = num_owner - bin_count.shape[0]
@@ -51,6 +51,7 @@ def get_graph_index(
5151
nlist_mask: paddle.Tensor,
5252
a_nlist_mask: paddle.Tensor,
5353
nall: int,
54+
use_loc_mapping: bool = True,
5455
):
5556
"""
5657
Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`.
@@ -100,7 +101,9 @@ def get_graph_index(
100101
n2e_index = n2e_index[nlist_mask] # graph node index, atom_graph[:, 0]
101102

102103
# node_ext(j) to edge(ij) index_select
103-
frame_shift = paddle.arange(0, nf, dtype=nlist.dtype) * nall
104+
frame_shift = paddle.arange(0, nf, dtype=nlist.dtype) * (
105+
nall if not use_loc_mapping else nloc
106+
)
104107
shifted_nlist = nlist + frame_shift[:, None, None]
105108
# n_edge
106109
n_ext2e_index = shifted_nlist[nlist_mask] # graph neighbor index, atom_graph[:, 1]

0 commit comments

Comments
 (0)