Skip to content

Commit 7fd4801

Browse files
committed
add DPModifier
1 parent f87b4ce commit 7fd4801

2 files changed

Lines changed: 101 additions & 38 deletions

File tree

deepmd/pt/modifier/dp_modifier.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
3+
import torch
4+
5+
from deepmd.pt.model.model import (
6+
BaseModel,
7+
)
8+
from deepmd.pt.modifier.base_modifier import (
9+
BaseModifier,
10+
)
11+
from deepmd.pt.utils import (
12+
env,
13+
)
14+
from deepmd.pt.utils.serialization import (
15+
serialize_from_file,
16+
)
17+
18+
19+
class DPModifier(BaseModifier):
    """Data modifier backed by a DeepPotential model.

    Wraps a trained DP model — supplied either as an in-memory
    ``torch.nn.Module`` or as a saved model file — and exposes it through
    the ``BaseModifier`` interface. The wrapped model is moved to
    ``env.DEVICE``, put in eval mode, and compiled with TorchScript for
    inference.
    """

    def __init__(
        self,
        dp_model: torch.nn.Module | None = None,
        dp_model_file_name: str | None = None,
        use_cache: bool = True,
        **kwargs,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        dp_model : torch.nn.Module, optional
            An already-constructed model instance. Mutually exclusive with
            ``dp_model_file_name``.
        dp_model_file_name : str, optional
            Path of a serialized model file to load. Mutually exclusive
            with ``dp_model``.
        use_cache : bool
            Passed through to ``BaseModifier``; whether results are cached.
        **kwargs
            Ignored; accepted for forward compatibility with subclasses.

        Raises
        ------
        AttributeError
            If neither or both of ``dp_model`` and ``dp_model_file_name``
            are specified.
        """
        super().__init__(use_cache=use_cache)

        # Exactly one source of the model must be provided.
        if dp_model_file_name is None and dp_model is None:
            raise AttributeError(
                "`dp_model` or `dp_model_file_name` should be specified."
            )
        if dp_model_file_name is not None and dp_model is not None:
            raise AttributeError(
                "`dp_model` and `dp_model_file_name` cannot be used simultaneously."
            )

        if dp_model is not None:
            self._model = dp_model.to(env.DEVICE)
        if dp_model_file_name is not None:
            data = serialize_from_file(dp_model_file_name)
            # Only "standard" models carry the fitting type used to pick
            # the concrete model class below.
            assert data["model"]["type"] == "standard"
            self._model = (
                BaseModel.get_class_by_type(data["model"]["fitting"]["type"])
                .deserialize(data["model"])
                .to(env.DEVICE)
            )
        self._model.eval()
        # Use a TorchScript-compiled copy of the model for inference.
        self.model = torch.jit.script(self._model)

    def serialize(self) -> dict:
        """Serialize the modifier.

        Returns
        -------
        dict
            The serialized data; the wrapped model is stored under the
            ``"dp_model"`` key.
        """
        dd = BaseModifier.serialize(self)
        dd.update(
            {
                "dp_model": self._model.serialize(),
            }
        )
        return dd

    @classmethod
    def get_modifier(cls, modifier_params: dict) -> "DPModifier":
        """Get the modifier by the parameters.

        By default, all the parameters are directly passed to the
        constructor. The input-file key ``model_name`` is translated into
        the constructor argument ``dp_model_file_name``. Override this
        method if different handling is needed.

        Parameters
        ----------
        modifier_params : dict
            The modifier parameters.

        Returns
        -------
        DPModifier
            The constructed modifier.
        """
        modifier_params = modifier_params.copy()
        modifier_params.pop("type", None)
        # Map the user-facing "model_name" key to the constructor argument.
        model_name = modifier_params.pop("model_name", None)
        if model_name is not None:
            modifier_params["dp_model"] = None
            modifier_params["dp_model_file_name"] = model_name
        modifier = cls(**modifier_params)
        return modifier

source/tests/pt/modifier/test_data_modifier.py

Lines changed: 9 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,8 @@
4242
from deepmd.pt.modifier.base_modifier import (
4343
BaseModifier,
4444
)
45-
from deepmd.pt.utils import (
46-
env,
47-
)
48-
from deepmd.pt.utils.serialization import (
49-
serialize_from_file,
45+
from deepmd.pt.modifier.dp_modifier import (
46+
DPModifier,
5047
)
5148
from deepmd.pt.utils.utils import (
5249
to_numpy_array,
@@ -188,44 +185,19 @@ def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> N
188185

189186

190187
@BaseModifier.register("scaling_tester")
191-
class ModifierScalingTester(BaseModifier):
192-
def __new__(
193-
cls,
194-
*args: tuple,
195-
model: str | None = None,
196-
model_name: str | None = None,
197-
**kwargs: dict,
198-
) -> "ModifierScalingTester":
199-
return super().__new__(cls, model_name if model_name is not None else model)
200-
188+
class ModifierScalingTester(DPModifier):
201189
def __init__(
202190
self,
203-
model: torch.nn.Module | None = None,
204-
model_name: str | None = None,
191+
dp_model: torch.nn.Module | None = None,
192+
dp_model_file_name: str | None = None,
205193
sfactor: float = 1.0,
206194
use_cache: bool = True,
207195
) -> None:
208196
"""Initialize a test modifier that applies scaled model predictions using a frozen model."""
209-
super().__init__(use_cache)
197+
super().__init__(dp_model, dp_model_file_name, use_cache)
210198
self.modifier_type = "scaling_tester"
211199
self.sfactor = sfactor
212200

213-
if model_name is None and model is None:
214-
raise AttributeError("`model_name` or `model` should be specified.")
215-
if model_name is not None and model is not None:
216-
raise AttributeError(
217-
"`model_name` and `model` cannot be used simultaneously."
218-
)
219-
220-
if model is not None:
221-
self._model = model.to(env.DEVICE)
222-
if model_name is not None:
223-
data = serialize_from_file(model_name)
224-
self._model = EnergyModel.deserialize(data["model"]).to(env.DEVICE)
225-
226-
# use jit model for inference
227-
self.model = torch.jit.script(self._model)
228-
229201
def serialize(self) -> dict:
230202
"""Serialize the modifier.
231203
@@ -234,10 +206,9 @@ def serialize(self) -> dict:
234206
dict
235207
The serialized data
236208
"""
237-
dd = BaseModifier.serialize(self)
209+
dd = super().serialize()
238210
dd.update(
239211
{
240-
"model": self._model.serialize(),
241212
"sfactor": self.sfactor,
242213
}
243214
)
@@ -249,8 +220,8 @@ def deserialize(cls, data: dict) -> "ModifierScalingTester":
249220
data.pop("@class", None)
250221
data.pop("type", None)
251222
data.pop("@version", None)
252-
model_obj = EnergyModel.deserialize(data.pop("model"))
253-
data["model"] = model_obj
223+
model_obj = EnergyModel.deserialize(data.pop("dp_model"))
224+
data["dp_model"] = model_obj
254225
return cls(**data)
255226

256227
def forward(

0 commit comments

Comments
 (0)