Skip to content

Commit 67837f4

Browse files
committed
feat:
1.add SOG and LES vector-map models 2.correct bias_atom_q in lr_fitting
1 parent eb328b5 commit 67837f4

5 files changed

Lines changed: 880 additions & 27 deletions

File tree

deepmd/pt/model/model/__init__.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@
5757
from .les_model import (
5858
LESEnergyModel,
5959
)
60+
from .les_vmap_model import (
61+
LESVmapModel,
62+
)
6063
from .make_hessian_model import (
6164
make_hessian_model,
6265
)
@@ -75,6 +78,9 @@
7578
from .sog_model import (
7679
SOGEnergyModel,
7780
)
81+
from .sog_vmap_model import (
82+
SOGVmapModel,
83+
)
7884
from .spin_model import (
7985
SpinEnergyModel,
8086
SpinModel,
@@ -298,6 +304,62 @@ def get_standard_model(model_params: dict) -> BaseModel:
298304
return model
299305

300306

307+
def _get_lr_vmap_model(
308+
model_params: dict,
309+
modelcls: type[BaseModel],
310+
expected_fitting_type: str,
311+
) -> BaseModel:
312+
model_params_old = model_params
313+
model_params = copy.deepcopy(model_params)
314+
ntypes = len(model_params["type_map"])
315+
descriptor, fitting, fitting_net_type = _get_standard_model_components(
316+
model_params, ntypes
317+
)
318+
if fitting_net_type != expected_fitting_type:
319+
raise RuntimeError(
320+
f"{modelcls.__name__} requires fitting_net.type='{expected_fitting_type}', "
321+
f"got '{fitting_net_type}'."
322+
)
323+
324+
atom_exclude_types = model_params.get("atom_exclude_types", [])
325+
pair_exclude_types = model_params.get("pair_exclude_types", [])
326+
preset_out_bias = model_params.get("preset_out_bias")
327+
preset_out_bias = _convert_preset_out_bias_to_array(
328+
preset_out_bias, model_params["type_map"]
329+
)
330+
data_stat_protect = model_params.get("data_stat_protect", 1e-2)
331+
332+
model = modelcls(
333+
descriptor=descriptor,
334+
fitting=fitting,
335+
type_map=model_params["type_map"],
336+
atom_exclude_types=atom_exclude_types,
337+
pair_exclude_types=pair_exclude_types,
338+
preset_out_bias=preset_out_bias,
339+
data_stat_protect=data_stat_protect,
340+
)
341+
if model_params.get("hessian_mode"):
342+
model.enable_hessian()
343+
model.model_def_script = json.dumps(model_params_old)
344+
return model
345+
346+
347+
def get_sog_vmap_model(model_params: dict) -> BaseModel:
348+
return _get_lr_vmap_model(
349+
model_params,
350+
modelcls=SOGVmapModel,
351+
expected_fitting_type="sog_energy",
352+
)
353+
354+
355+
def get_les_vmap_model(model_params: dict) -> BaseModel:
356+
return _get_lr_vmap_model(
357+
model_params,
358+
modelcls=LESVmapModel,
359+
expected_fitting_type="les_energy",
360+
)
361+
362+
301363
def get_model(model_params: dict) -> Any:
302364
model_type = model_params.get("type", "standard")
303365
if model_type == "standard":
@@ -307,6 +369,10 @@ def get_model(model_params: dict) -> Any:
307369
return get_zbl_model(model_params)
308370
else:
309371
return get_standard_model(model_params)
372+
elif model_type == "sog_vmap":
373+
return get_sog_vmap_model(model_params)
374+
elif model_type == "les_vmap":
375+
return get_les_vmap_model(model_params)
310376
elif model_type == "linear_ener":
311377
return get_linear_model(model_params)
312378
else:
@@ -322,9 +388,11 @@ def get_model(model_params: dict) -> Any:
322388
"EnergyModel",
323389
"FrozenModel",
324390
"LESEnergyModel",
391+
"LESVmapModel",
325392
"LinearEnergyModel",
326393
"PolarModel",
327394
"SOGEnergyModel",
395+
"SOGVmapModel",
328396
"SpinEnergyModel",
329397
"SpinModel",
330398
"get_model",

0 commit comments

Comments
 (0)