Skip to content

Commit 85c8ad5

Browse files
refine more code
1 parent 3effbdb commit 85c8ad5

File tree

4 files changed

+4
-15
lines changed

4 files changed

+4
-15
lines changed

deepmd/pd/model/descriptor/dpa2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
TypeEmbedNetConsistent,
2626
)
2727
from deepmd.pd.utils import (
28+
decomp,
2829
env,
2930
)
3031
from deepmd.pd.utils.env import (
@@ -799,7 +800,8 @@ def forward(
799800
g1 = g1 + self.tebd_transform(g1_inp)
800801
# mapping g1
801802
if comm_dict is None or len(comm_dict) == 0:
802-
assert mapping.numel() > 0
803+
if paddle.in_dynamic_mode():
804+
assert decomp.numel(mapping) > 0
803805
mapping_ext = (
804806
mapping.reshape([nframes, nall])
805807
.unsqueeze(-1)

deepmd/pd/model/descriptor/repflows.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -630,12 +630,6 @@ def forward(
630630
)
631631

632632
assert len(comm_dict) >= 6
633-
# assert "send_list" in comm_dict
634-
# assert "send_proc" in comm_dict
635-
# assert "recv_proc" in comm_dict
636-
# assert "send_num" in comm_dict
637-
# assert "recv_num" in comm_dict
638-
# assert "communicator" in comm_dict
639633
ret = paddle_ops_deepmd_border_op(
640634
comm_dict[0],
641635
comm_dict[1],

deepmd/pd/model/descriptor/repformer_layer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ def _make_nei_g1(
110110
# index: nb x (nloc x nnei) x ng1
111111
index = nlist.reshape([nb, nloc * nnei]).unsqueeze(-1).expand([-1, -1, ng1])
112112
# gg1 : nb x (nloc x nnei) x ng1
113-
# print(g1_ext.shape, index.shape)
114113
gg1 = paddle.take_along_axis(g1_ext, indices=index, axis=1, broadcast=False)
115114
# gg1 : nb x nloc x nnei x ng1
116115
gg1 = gg1.reshape([nb, nloc, nnei, ng1])

deepmd/pd/model/descriptor/repformers.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -510,13 +510,7 @@ def forward(
510510
mix_g1.squeeze(0), (0, 0, 0, real_n_padding), value=0.0
511511
)
512512

513-
# assert "send_list" in comm_dict
514-
# assert "send_proc" in comm_dict
515-
# assert "recv_proc" in comm_dict
516-
# assert "send_num" in comm_dict
517-
# assert "recv_num" in comm_dict
518-
# assert "communicator" in comm_dict
519-
# print(f"g1.shape = ", g1.shape)
513+
assert len(comm_dict) >= 6
520514
ret = paddle_ops_deepmd_border_op(
521515
comm_dict[0],
522516
comm_dict[1],

0 commit comments

Comments
 (0)