5757from .les_model import (
5858 LESEnergyModel ,
5959)
60+ from .les_vmap_model import (
61+ LESVmapModel ,
62+ )
6063from .make_hessian_model import (
6164 make_hessian_model ,
6265)
7578from .sog_model import (
7679 SOGEnergyModel ,
7780)
81+ from .sog_vmap_model import (
82+ SOGVmapModel ,
83+ )
7884from .spin_model import (
7985 SpinEnergyModel ,
8086 SpinModel ,
@@ -298,6 +304,62 @@ def get_standard_model(model_params: dict) -> BaseModel:
298304 return model
299305
300306
307+ def _get_lr_vmap_model (
308+ model_params : dict ,
309+ modelcls : type [BaseModel ],
310+ expected_fitting_type : str ,
311+ ) -> BaseModel :
312+ model_params_old = model_params
313+ model_params = copy .deepcopy (model_params )
314+ ntypes = len (model_params ["type_map" ])
315+ descriptor , fitting , fitting_net_type = _get_standard_model_components (
316+ model_params , ntypes
317+ )
318+ if fitting_net_type != expected_fitting_type :
319+ raise RuntimeError (
320+ f"{ modelcls .__name__ } requires fitting_net.type='{ expected_fitting_type } ', "
321+ f"got '{ fitting_net_type } '."
322+ )
323+
324+ atom_exclude_types = model_params .get ("atom_exclude_types" , [])
325+ pair_exclude_types = model_params .get ("pair_exclude_types" , [])
326+ preset_out_bias = model_params .get ("preset_out_bias" )
327+ preset_out_bias = _convert_preset_out_bias_to_array (
328+ preset_out_bias , model_params ["type_map" ]
329+ )
330+ data_stat_protect = model_params .get ("data_stat_protect" , 1e-2 )
331+
332+ model = modelcls (
333+ descriptor = descriptor ,
334+ fitting = fitting ,
335+ type_map = model_params ["type_map" ],
336+ atom_exclude_types = atom_exclude_types ,
337+ pair_exclude_types = pair_exclude_types ,
338+ preset_out_bias = preset_out_bias ,
339+ data_stat_protect = data_stat_protect ,
340+ )
341+ if model_params .get ("hessian_mode" ):
342+ model .enable_hessian ()
343+ model .model_def_script = json .dumps (model_params_old )
344+ return model
345+
346+
347+ def get_sog_vmap_model (model_params : dict ) -> BaseModel :
348+ return _get_lr_vmap_model (
349+ model_params ,
350+ modelcls = SOGVmapModel ,
351+ expected_fitting_type = "sog_energy" ,
352+ )
353+
354+
355+ def get_les_vmap_model (model_params : dict ) -> BaseModel :
356+ return _get_lr_vmap_model (
357+ model_params ,
358+ modelcls = LESVmapModel ,
359+ expected_fitting_type = "les_energy" ,
360+ )
361+
362+
301363def get_model (model_params : dict ) -> Any :
302364 model_type = model_params .get ("type" , "standard" )
303365 if model_type == "standard" :
@@ -307,6 +369,10 @@ def get_model(model_params: dict) -> Any:
307369 return get_zbl_model (model_params )
308370 else :
309371 return get_standard_model (model_params )
372+ elif model_type == "sog_vmap" :
373+ return get_sog_vmap_model (model_params )
374+ elif model_type == "les_vmap" :
375+ return get_les_vmap_model (model_params )
310376 elif model_type == "linear_ener" :
311377 return get_linear_model (model_params )
312378 else :
@@ -322,9 +388,11 @@ def get_model(model_params: dict) -> Any:
322388 "EnergyModel" ,
323389 "FrozenModel" ,
324390 "LESEnergyModel" ,
391+ "LESVmapModel" ,
325392 "LinearEnergyModel" ,
326393 "PolarModel" ,
327394 "SOGEnergyModel" ,
395+ "SOGVmapModel" ,
328396 "SpinEnergyModel" ,
329397 "SpinModel" ,
330398 "get_model" ,
0 commit comments