forked from deepmodeling/deepmd-kit
-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathserialization.py
More file actions
113 lines (103 loc) · 3.62 KB
/
serialization.py
File metadata and controls
113 lines (103 loc) · 3.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# SPDX-License-Identifier: LGPL-3.0-or-later
import paddle
from deepmd.pd.model.model.model import (
BaseModel,
)
def serialize_from_file(model_file: str) -> dict:
    """Serialize a Paddle model file into a dictionary.

    This operation is not available for the Paddle backend; calling it
    always raises.

    Parameters
    ----------
    model_file : str
        Path of the model file that would be serialized.

    Returns
    -------
    dict
        The serialized model data (never actually returned, see Raises).

    Raises
    ------
    NotImplementedError
        Unconditionally, because the Paddle backend does not support this yet.
    """
    message = "Paddle do not support jit.export yet."
    raise NotImplementedError(message)
def deserialize_to_file(model_file: str, data: dict) -> None:
    """Deserialize the dictionary to a Paddle static-graph model file.

    The model is rebuilt with ``BaseModel.deserialize(data["model"])``.
    Its ``forward`` and ``forward_lower`` methods are converted with
    ``paddle.jit.to_static``. The result is written with ``paddle.jit.save``,
    using ``model_file`` without its trailing ``.json`` suffix as the path
    prefix.

    Note: this function modifies process-wide Paddle state. It enables
    primitive operators and sets several ``FLAGS_*`` values. It does not
    restore them afterwards.

    Parameters
    ----------
    model_file : str
        The model file to be saved. Must end with ``.json``.
    data : dict
        The dictionary to be deserialized. ``data["model"]`` is passed to
        ``BaseModel.deserialize``. If ``data["@variables"]`` contains
        ``"min_nbor_dist"``, that value is attached to the model as the
        buffer ``buffer_min_nbor_dist``.

    Raises
    ------
    ValueError
        If ``model_file`` does not end with ``.json``. The check runs before
        any global Paddle state is changed.
    """
    suffix = ".json"
    # Validate first so a rejected path does not leave global Paddle state
    # modified.
    if not model_file.endswith(suffix):
        raise ValueError("Paddle backend only supports converting .json file")
    # paddle.jit.save expects a path prefix. Strip only the trailing suffix:
    # ``split(".json")[0]`` would truncate at the first ".json" anywhere in the
    # path (e.g. "a.json.d/model.json" -> "a").
    path_prefix = model_file[: -len(suffix)]

    paddle.framework.core._set_prim_all_enabled(True)

    model: paddle.nn.Layer = BaseModel.deserialize(data["model"])
    model.eval()

    # Attach min_nbor_dist as a model buffer; the JIT conversion expects it in
    # this form rather than as a plain Python attribute.
    variables = data.get("@variables", {})
    if "min_nbor_dist" in variables:
        model.register_buffer(
            "buffer_min_nbor_dist",
            paddle.to_tensor(float(variables["min_nbor_dist"])),
        )

    # Process-global flags (not restored after saving).
    paddle.set_flags(
        {
            "FLAGS_save_cf_stack_op": 1,
            "FLAGS_prim_enable_dynamic": 1,
            "FLAGS_enable_pir_api": 1,
        }
    )

    from paddle.static import (
        InputSpec,
    )

    # Example output shapes and dtypes of forward:
    #   atom_energy: fetch_name_0 (1, 6, 1) float64
    #   atom_virial: fetch_name_1 (1, 6, 1, 9) float64
    #   energy:      fetch_name_2 (1, 1) float64
    #   force:       fetch_name_3 (1, 6, 3) float64
    #   mask:        fetch_name_4 (1, 6) int32
    #   virial:      fetch_name_5 (1, 9) float64
    model.forward = paddle.jit.to_static(
        model.forward,
        input_spec=[
            InputSpec([-1, -1, 3], dtype="float64", name="coord"),  # coord
            InputSpec([-1, -1], dtype="int64", name="atype"),  # atype
            InputSpec([-1, 9], dtype="float64", name="box"),  # box
            None,  # fparam
            None,  # aparam
            True,  # do_atomic_virial
        ],
        full_graph=True,
    )

    # Example output shapes and dtypes of forward_lower:
    #   fetch_name_0: atom_energy     [1, 192, 1] paddle.float64
    #   fetch_name_1: energy          [1, 1] paddle.float64
    #   fetch_name_2: extended_force  [1, 5184, 3] paddle.float64
    #   fetch_name_3: extended_virial [1, 5184, 1, 9] paddle.float64
    #   fetch_name_4: virial          [1, 9] paddle.float64
    model.forward_lower = paddle.jit.to_static(
        model.forward_lower,
        input_spec=[
            InputSpec([-1, -1, 3], dtype="float64", name="coord"),  # extended_coord
            InputSpec([-1, -1], dtype="int32", name="atype"),  # extended_atype
            InputSpec([-1, -1, -1], dtype="int32", name="nlist"),  # nlist
            InputSpec([-1, -1], dtype="int64", name="mapping"),  # mapping
            None,  # fparam
            None,  # aparam
            True,  # do_atomic_virial
            (
                InputSpec([-1], "int64", name="send_list"),
                InputSpec([-1], "int32", name="send_proc"),
                InputSpec([-1], "int32", name="recv_proc"),
                InputSpec([-1], "int32", name="send_num"),
                InputSpec([-1], "int32", name="recv_num"),
                InputSpec([-1], "int64", name="communicator"),
                # InputSpec([1], "int64", name="has_spin"),
            ),  # comm_dict
        ],
        full_graph=True,
    )

    paddle.jit.save(
        model,
        path_prefix,
        skip_prune_program=True,
    )