Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ repos:
- id: trailing-whitespace
exclude: "^.+\\.pbtxt$"
- id: end-of-file-fixer
exclude: "^.+\\.pbtxt$|deeppot_sea.*\\.json$"
exclude: "^.+\\.pbtxt$|deeppot_sea.*|deeppot_dpa.*\\.json$"
- id: check-yaml
- id: check-json
- id: check-added-large-files
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def freeze(
None, # fparam
None, # aparam
# InputSpec([], dtype="bool", name="do_atomic_virial"), # do_atomic_virial
False, # do_atomic_virial
False, # do_atomic_virial, NOTE: set to True if needed
],
full_graph=True,
)
Expand All @@ -396,7 +396,7 @@ def freeze(
None, # fparam
None, # aparam
# InputSpec([], dtype="bool", name="do_atomic_virial"), # do_atomic_virial
False, # do_atomic_virial
False, # do_atomic_virial, NOTE: set to True if needed
(
InputSpec([-1], "int64", name="send_list"),
InputSpec([-1], "int32", name="send_proc"),
Expand Down
40 changes: 24 additions & 16 deletions deepmd/pd/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,15 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
"""
model.forward = paddle.jit.to_static(
model.forward,
full_graph=True,
input_spec=[
InputSpec([-1, -1, 3], dtype="float64", name="coord"),
InputSpec([-1, -1], dtype="int64", name="atype"),
InputSpec([-1, 9], dtype="float64", name="box"),
None,
None,
True,
InputSpec([-1, -1, 3], dtype="float64", name="coord"), # coord
InputSpec([-1, -1], dtype="int64", name="atype"), # atype
InputSpec([-1, 9], dtype="float64", name="box"), # box
None, # fparam
None, # aparam
True, # do_atomic_virial
],
full_graph=True,
)
""" example output shape and dtype of forward_lower
fetch_name_0: atom_energy [1, 192, 1] paddle.float64
Expand All @@ -86,17 +86,25 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
"""
model.forward_lower = paddle.jit.to_static(
model.forward_lower,
full_graph=True,
input_spec=[
InputSpec([-1, -1, 3], dtype="float64", name="coord"),
InputSpec([-1, -1], dtype="int32", name="atype"),
InputSpec([-1, -1, -1], dtype="int32", name="nlist"),
None,
None,
None,
True,
None,
InputSpec([-1, -1, 3], dtype="float64", name="coord"), # extended_coord
InputSpec([-1, -1], dtype="int32", name="atype"), # extended_atype
InputSpec([-1, -1, -1], dtype="int32", name="nlist"), # nlist
InputSpec([-1, -1], dtype="int64", name="mapping"), # mapping
None, # fparam
None, # aparam
True, # do_atomic_virial
(
InputSpec([-1], "int64", name="send_list"),
InputSpec([-1], "int32", name="send_proc"),
InputSpec([-1], "int32", name="recv_proc"),
InputSpec([-1], "int32", name="send_num"),
InputSpec([-1], "int32", name="recv_num"),
InputSpec([-1], "int64", name="communicator"),
# InputSpec([1], "int64", name="has_spin"),
), # comm_dict
],
full_graph=True,
)
paddle.jit.save(
model,
Expand Down
Loading
Loading