Skip to content

Commit ca142a3

Browse files
committed
use unsqueeze since slice overhead is larger
1 parent de84e1a commit ca142a3

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,9 @@ def _cal_hg_dynamic(
370370
# n_edge x e_dim
371371
flat_edge_ebd = flat_edge_ebd * flat_sw.unsqueeze(-1)
372372
# n_edge x 3 x e_dim
373-
flat_h2g2 = (flat_h2.unsqueeze(-1) * flat_edge_ebd.unsqueeze(-2)).reshape(
374-
-1, 3 * e_dim
375-
)
373+
flat_h2g2 = (
374+
flat_h2.unsqueeze(-1) * flat_edge_ebd.unsqueeze(-2)
375+
).reshape(-1, 3 * e_dim)
376376
# nf x nloc x 3 x e_dim
377377
h2g2 = (
378378
aggregate(flat_h2g2, owner, average=False, num_owner=num_owner).reshape(
@@ -1028,7 +1028,9 @@ def forward(
10281028
if not self.use_dynamic_sel:
10291029
# nb x nloc x a_nnei x a_nnei x e_dim
10301030
weighted_edge_angle_update = (
1031-
a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update
1031+
a_sw.unsqueeze(-1).unsqueeze(-1)
1032+
* a_sw.unsqueeze(-2).unsqueeze(-1)
1033+
* edge_angle_update
10321034
)
10331035
# nb x nloc x a_nnei x e_dim
10341036
reduced_edge_angle_update = torch.sum(

0 commit comments

Comments
 (0)