Skip to content

Commit 5a3f647

Browse files
fix
1 parent 89743ff commit 5a3f647

File tree

8 files changed

+16
-11
lines changed

8 files changed

+16
-11
lines changed

deepmd/pd/entrypoints/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -368,9 +368,9 @@ def freeze(
368368
model.forward = paddle.jit.to_static(
369369
model.forward,
370370
input_spec=[
371-
InputSpec([1, -1, 3], dtype="float64", name="coord"), # coord
372-
InputSpec([1, -1], dtype="int64", name="atype"), # atype
373-
InputSpec([1, 9], dtype="float64", name="box"), # box
371+
InputSpec([-1, -1, 3], dtype="float64", name="coord"), # coord
372+
InputSpec([-1, -1], dtype="int64", name="atype"), # atype
373+
InputSpec([-1, 9], dtype="float64", name="box"), # box
374374
None, # fparam
375375
None, # aparam
376376
# InputSpec([], dtype="bool", name="do_atomic_virial"), # do_atomic_virial

deepmd/pd/model/descriptor/repflows.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,8 @@ def forward(
663663
place=paddle.CPUPlace(),
664664
), # should be int of c++, placed on cpu
665665
)
666+
if not paddle.in_dynamic_mode():
667+
ret = paddle.assign(ret)
666668
node_ebd_ext = ret.unsqueeze(0)
667669
if has_spin:
668670
node_ebd_real_ext, node_ebd_virtual_ext = paddle.split(

deepmd/pd/model/descriptor/repformers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,8 @@ def forward(
548548
place=paddle.CPUPlace(),
549549
), # should be int of c++, placed on cpu
550550
)
551+
if not paddle.in_dynamic_mode():
552+
ret = paddle.assign(ret)
551553
g1_ext = ret.unsqueeze(0)
552554
if has_spin:
553555
g1_real_ext, g1_virtual_ext = paddle.split(

deepmd/pd/utils/serialization.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
6969
model.forward,
7070
full_graph=True,
7171
input_spec=[
72-
InputSpec([1, -1, 3], dtype="float64", name="coord"),
73-
InputSpec([1, -1], dtype="int64", name="atype"),
74-
InputSpec([1, 9], dtype="float64", name="box"),
72+
InputSpec([-1, -1, 3], dtype="float64", name="coord"),
73+
InputSpec([-1, -1], dtype="int64", name="atype"),
74+
InputSpec([-1, 9], dtype="float64", name="box"),
7575
None,
7676
None,
7777
True,
@@ -88,9 +88,9 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
8888
model.forward_lower,
8989
full_graph=True,
9090
input_spec=[
91-
InputSpec([1, -1, 3], dtype="float64", name="extended_coord"),
92-
InputSpec([1, -1], dtype="int32", name="extended_atype"),
93-
InputSpec([1, -1, -1], dtype="int32", name="nlist"),
91+
InputSpec([-1, -1, 3], dtype="float64", name="extended_coord"),
92+
InputSpec([-1, -1], dtype="int32", name="extended_atype"),
93+
InputSpec([-1, -1, -1], dtype="int32", name="nlist"),
9494
None,
9595
None,
9696
None,
@@ -101,4 +101,5 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
101101
paddle.jit.save(
102102
model,
103103
model_file.split(".json")[0],
104+
skip_prune_program=True,
104105
)

source/tests/infer/deeppot_sea.forward_lower.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.
420 Bytes
Binary file not shown.

source/tests/infer/deeppot_sea.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.
420 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)