4242from deepmd .pt .modifier .base_modifier import (
4343 BaseModifier ,
4444)
45- from deepmd .pt .utils import (
46- env ,
47- )
48- from deepmd .pt .utils .serialization import (
49- serialize_from_file ,
45+ from deepmd .pt .modifier .dp_modifier import (
46+ DPModifier ,
5047)
5148from deepmd .pt .utils .utils import (
5249 to_numpy_array ,
@@ -188,44 +185,19 @@ def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> N
188185
189186
190187@BaseModifier .register ("scaling_tester" )
191- class ModifierScalingTester (BaseModifier ):
192- def __new__ (
193- cls ,
194- * args : tuple ,
195- model : str | None = None ,
196- model_name : str | None = None ,
197- ** kwargs : dict ,
198- ) -> "ModifierScalingTester" :
199- return super ().__new__ (cls , model_name if model_name is not None else model )
200-
188+ class ModifierScalingTester (DPModifier ):
201189 def __init__ (
202190 self ,
203- model : torch .nn .Module | None = None ,
204- model_name : str | None = None ,
191+ dp_model : torch .nn .Module | None = None ,
192+ dp_model_file_name : str | None = None ,
205193 sfactor : float = 1.0 ,
206194 use_cache : bool = True ,
207195 ) -> None :
208196 """Initialize a test modifier that applies scaled model predictions using a frozen model."""
209- super ().__init__ (use_cache )
197+ super ().__init__ (dp_model , dp_model_file_name , use_cache )
210198 self .modifier_type = "scaling_tester"
211199 self .sfactor = sfactor
212200
213- if model_name is None and model is None :
214- raise AttributeError ("`model_name` or `model` should be specified." )
215- if model_name is not None and model is not None :
216- raise AttributeError (
217- "`model_name` and `model` cannot be used simultaneously."
218- )
219-
220- if model is not None :
221- self ._model = model .to (env .DEVICE )
222- if model_name is not None :
223- data = serialize_from_file (model_name )
224- self ._model = EnergyModel .deserialize (data ["model" ]).to (env .DEVICE )
225-
226- # use jit model for inference
227- self .model = torch .jit .script (self ._model )
228-
229201 def serialize (self ) -> dict :
230202 """Serialize the modifier.
231203
@@ -234,10 +206,9 @@ def serialize(self) -> dict:
234206 dict
235207 The serialized data
236208 """
237- dd = BaseModifier .serialize (self )
209+ dd = super () .serialize ()
238210 dd .update (
239211 {
240- "model" : self ._model .serialize (),
241212 "sfactor" : self .sfactor ,
242213 }
243214 )
@@ -249,8 +220,8 @@ def deserialize(cls, data: dict) -> "ModifierScalingTester":
249220 data .pop ("@class" , None )
250221 data .pop ("type" , None )
251222 data .pop ("@version" , None )
252- model_obj = EnergyModel .deserialize (data .pop ("model " ))
253- data ["model " ] = model_obj
223+ model_obj = EnergyModel .deserialize (data .pop ("dp_model " ))
224+ data ["dp_model " ] = model_obj
254225 return cls (** data )
255226
256227 def forward (
0 commit comments