Skip to content

Commit 38fb926

Browse files
refine more code
1 parent 3effbdb commit 38fb926

File tree

4 files changed

+30
-38
lines changed

4 files changed

+30
-38
lines changed

deepmd/pd/model/descriptor/dpa2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
TypeEmbedNetConsistent,
2626
)
2727
from deepmd.pd.utils import (
28+
decomp,
2829
env,
2930
)
3031
from deepmd.pd.utils.env import (
@@ -799,7 +800,8 @@ def forward(
799800
g1 = g1 + self.tebd_transform(g1_inp)
800801
# mapping g1
801802
if comm_dict is None or len(comm_dict) == 0:
802-
assert mapping.numel() > 0
803+
if paddle.in_dynamic_mode():
804+
assert decomp.numel(mapping) > 0
803805
mapping_ext = (
804806
mapping.reshape([nframes, nall])
805807
.unsqueeze(-1)

deepmd/pd/model/descriptor/repflows.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
Union,
66
)
77

8-
import numpy as np
98
import paddle
109

1110
from deepmd.dpmodel.utils.seed import (
@@ -525,7 +524,7 @@ def forward(
525524
# nf x nloc x a_nnei x a_nnei
526525
# 1 - 1e-6 for paddle.acos stability
527526
cosine_ij = paddle.matmul(normalized_diff_i, normalized_diff_j) * (1 - 1e-6)
528-
angle_input = cosine_ij.unsqueeze(-1) / (np.pi**0.5)
527+
angle_input = cosine_ij.unsqueeze(-1) / (paddle.pi**0.5)
529528

530529
if not parallel_mode and self.use_loc_mapping:
531530
if paddle.in_dynamic_mode():
@@ -597,19 +596,21 @@ def forward(
597596
has_spin = len(comm_dict) >= 7
598597
if not has_spin:
599598
n_padding = nall - nloc
600-
# node_ebd = paddle.nn.functional.pad(
601-
# node_ebd.squeeze(0), [0, 0, 0, n_padding], value=0.0
602-
# )
599+
if paddle.in_dynamic_mode():
600+
node_ebd = paddle.nn.functional.pad(
601+
node_ebd.squeeze(0), [0, 0, 0, n_padding], value=0.0
602+
)
603+
else:
604+
_fill_shape = node_ebd.shape[1:]
605+
_fill_shape[1] = n_padding
606+
node_ebd = paddle.concat(
607+
[
608+
node_ebd.squeeze(0),
609+
paddle.zeros(_fill_shape, dtype=node_ebd.dtype),
610+
],
611+
axis=1,
612+
)
603613
# [nframes, nloc, tebd_dim]
604-
_shapes = node_ebd.shape[1:]
605-
_shapes[1] = n_padding
606-
node_ebd = paddle.concat(
607-
[
608-
node_ebd.squeeze(0),
609-
paddle.zeros(_shapes, dtype=node_ebd.dtype),
610-
],
611-
axis=1,
612-
)
613614
real_nloc = nloc
614615
real_nall = nall
615616
else:
@@ -630,12 +631,6 @@ def forward(
630631
)
631632

632633
assert len(comm_dict) >= 6
633-
# assert "send_list" in comm_dict
634-
# assert "send_proc" in comm_dict
635-
# assert "recv_proc" in comm_dict
636-
# assert "send_num" in comm_dict
637-
# assert "recv_num" in comm_dict
638-
# assert "communicator" in comm_dict
639634
ret = paddle_ops_deepmd_border_op(
640635
comm_dict[0],
641636
comm_dict[1],

deepmd/pd/model/descriptor/repformer_layer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ def _make_nei_g1(
110110
# index: nb x (nloc x nnei) x ng1
111111
index = nlist.reshape([nb, nloc * nnei]).unsqueeze(-1).expand([-1, -1, ng1])
112112
# gg1 : nb x (nloc x nnei) x ng1
113-
# print(g1_ext.shape, index.shape)
114113
gg1 = paddle.take_along_axis(g1_ext, indices=index, axis=1, broadcast=False)
115114
# gg1 : nb x nloc x nnei x ng1
116115
gg1 = gg1.reshape([nb, nloc, nnei, ng1])

deepmd/pd/model/descriptor/repformers.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -484,15 +484,17 @@ def forward(
484484
has_spin = len(comm_dict) >= 7
485485
if not has_spin:
486486
n_padding = nall - nloc
487-
# g1 = paddle.nn.functional.pad(
488-
# g1.squeeze(0), (0, 0, 0, n_padding), value=0.0
489-
# )
490-
_shapes = g1.shape[1:]
491-
_shapes[1] = n_padding
492-
g1 = paddle.concat(
493-
[g1.squeeze(0), paddle.zeros(_shapes, dtype=g1.dtype)],
494-
axis=1,
495-
)
487+
if paddle.in_dynamic_mode():
488+
g1 = paddle.nn.functional.pad(
489+
g1.squeeze(0), (0, 0, 0, n_padding), value=0.0
490+
)
491+
else:
492+
_fill_shape = g1.shape[1:]
493+
_fill_shape[1] = n_padding
494+
g1 = paddle.concat(
495+
[g1.squeeze(0), paddle.zeros(_fill_shape, dtype=g1.dtype)],
496+
axis=1,
497+
)
496498
real_nloc = nloc
497499
real_nall = nall
498500
else:
@@ -510,13 +512,7 @@ def forward(
510512
mix_g1.squeeze(0), (0, 0, 0, real_n_padding), value=0.0
511513
)
512514

513-
# assert "send_list" in comm_dict
514-
# assert "send_proc" in comm_dict
515-
# assert "recv_proc" in comm_dict
516-
# assert "send_num" in comm_dict
517-
# assert "recv_num" in comm_dict
518-
# assert "communicator" in comm_dict
519-
# print(f"g1.shape = ", g1.shape)
515+
assert len(comm_dict) >= 6
520516
ret = paddle_ops_deepmd_border_op(
521517
comm_dict[0],
522518
comm_dict[1],

0 commit comments

Comments
 (0)