Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 67 additions & 81 deletions source/tests/universal/dpmodel/atomc_model/test_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,18 @@
)
from ...dpmodel.descriptor.test_descriptor import (
DescriptorParamDPA1,
DescriptorParamDPA1List,
DescriptorParamDPA1EnergyModelList,
DescriptorParamDPA2,
DescriptorParamDPA2List,
DescriptorParamDPA2EnergyModelList,
DescriptorParamHybrid,
DescriptorParamHybridMixed,
DescriptorParamHybridMixedTTebd,
DescriptorParamSeA,
DescriptorParamSeAList,
DescriptorParamSeAEnergyModelList,
DescriptorParamSeR,
DescriptorParamSeRList,
DescriptorParamSeREnergyModelList,
DescriptorParamSeT,
DescriptorParamSeTList,
DescriptorParamSeTEnergyModelList,
)
from ...dpmodel.model.test_model import (
skip_model_tests,
Expand Down Expand Up @@ -81,29 +81,64 @@ def make_sel_type_from_atom_exclude_types(type_map, atom_exclude_types):
return sel_type.tolist()


ENERGY_DESCRIPTOR_PARAMS = (
*[(param_func, DescrptSeA) for param_func in DescriptorParamSeAEnergyModelList],
*[(param_func, DescrptSeR) for param_func in DescriptorParamSeREnergyModelList],
*[(param_func, DescrptSeT) for param_func in DescriptorParamSeTEnergyModelList],
*[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1EnergyModelList],
*[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2EnergyModelList],
(DescriptorParamHybrid, DescrptHybrid),
(DescriptorParamHybridMixed, DescrptHybrid),
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
)

DEFAULT_DESCRIPTOR_PARAMS = (
(DescriptorParamSeA, DescrptSeA),
(DescriptorParamSeR, DescrptSeR),
(DescriptorParamSeT, DescrptSeT),
(DescriptorParamDPA1, DescrptDPA1),
(DescriptorParamDPA2, DescrptDPA2),
)

DEFAULT_DESCRIPTOR_PARAMS_WITH_HYBRID = (
*DEFAULT_DESCRIPTOR_PARAMS,
(DescriptorParamHybrid, DescrptHybrid),
(DescriptorParamHybridMixed, DescrptHybrid),
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
)

DEFAULT_VEC_DESCRIPTOR_PARAMS = (
(DescriptorParamSeA, DescrptSeA),
(DescriptorParamDPA1, DescrptDPA1),
(DescriptorParamDPA2, DescrptDPA2),
)

DEFAULT_VEC_DESCRIPTOR_PARAMS_WITH_HYBRID = (
*DEFAULT_VEC_DESCRIPTOR_PARAMS,
(DescriptorParamHybrid, DescrptHybrid),
(DescriptorParamHybridMixed, DescrptHybrid),
)

DEFAULT_DPA12_DESCRIPTOR_PARAMS = (
(DescriptorParamDPA1, DescrptDPA1),
(DescriptorParamDPA2, DescrptDPA2),
)

DEFAULT_DPA12_DESCRIPTOR_PARAMS_WITH_HYBRID = (
*DEFAULT_DPA12_DESCRIPTOR_PARAMS,
(DescriptorParamHybridMixed, DescrptHybrid),
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
)


@parameterized(
des_parameterized=(
(
*[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList],
*[(param_func, DescrptSeR) for param_func in DescriptorParamSeRList],
*[(param_func, DescrptSeT) for param_func in DescriptorParamSeTList],
*[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List],
*[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List],
(DescriptorParamHybrid, DescrptHybrid),
(DescriptorParamHybridMixed, DescrptHybrid),
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
), # descrpt_class_param & class
ENERGY_DESCRIPTOR_PARAMS, # descrpt_class_param & class
((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class
([], [0]), # atom_exclude_types
),
fit_parameterized=(
(
(DescriptorParamSeA, DescrptSeA),
(DescriptorParamSeR, DescrptSeR),
(DescriptorParamSeT, DescrptSeT),
(DescriptorParamDPA1, DescrptDPA1),
(DescriptorParamDPA2, DescrptDPA2),
), # descrpt_class_param & class
DEFAULT_DESCRIPTOR_PARAMS, # descrpt_class_param & class
(
*[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList],
), # fitting_class_param & class
Expand Down Expand Up @@ -158,27 +193,12 @@ def test_sel_type_from_atom_exclude_types(self):

@parameterized(
des_parameterized=(
(
*[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList],
*[(param_func, DescrptSeR) for param_func in DescriptorParamSeRList],
*[(param_func, DescrptSeT) for param_func in DescriptorParamSeTList],
*[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List],
*[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List],
(DescriptorParamHybrid, DescrptHybrid),
(DescriptorParamHybridMixed, DescrptHybrid),
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
), # descrpt_class_param & class
DEFAULT_DESCRIPTOR_PARAMS_WITH_HYBRID, # descrpt_class_param & class
((FittingParamDos, DOSFittingNet),), # fitting_class_param & class
([], [0]), # atom_exclude_types
),
fit_parameterized=(
(
(DescriptorParamSeA, DescrptSeA),
(DescriptorParamSeR, DescrptSeR),
(DescriptorParamSeT, DescrptSeT),
(DescriptorParamDPA1, DescrptDPA1),
(DescriptorParamDPA2, DescrptDPA2),
), # descrpt_class_param & class
DEFAULT_DESCRIPTOR_PARAMS, # descrpt_class_param & class
(
*[(param_func, DOSFittingNet) for param_func in FittingParamDosList],
), # fitting_class_param & class
Expand Down Expand Up @@ -233,22 +253,12 @@ def test_sel_type_from_atom_exclude_types(self):

@parameterized(
des_parameterized=(
(
*[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList],
*[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List],
*[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List],
(DescriptorParamHybrid, DescrptHybrid),
(DescriptorParamHybridMixed, DescrptHybrid),
), # descrpt_class_param & class
DEFAULT_VEC_DESCRIPTOR_PARAMS_WITH_HYBRID, # descrpt_class_param & class
((FittingParamDipole, DipoleFitting),), # fitting_class_param & class
([], [0]), # atom_exclude_types
),
fit_parameterized=(
(
(DescriptorParamSeA, DescrptSeA),
(DescriptorParamDPA1, DescrptDPA1),
(DescriptorParamDPA2, DescrptDPA2),
), # descrpt_class_param & class
DEFAULT_VEC_DESCRIPTOR_PARAMS, # descrpt_class_param & class
(
*[(param_func, DipoleFitting) for param_func in FittingParamDipoleList],
), # fitting_class_param & class
Expand Down Expand Up @@ -304,22 +314,12 @@ def test_sel_type_from_atom_exclude_types(self):

@parameterized(
des_parameterized=(
(
*[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList],
*[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List],
*[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List],
(DescriptorParamHybrid, DescrptHybrid),
(DescriptorParamHybridMixed, DescrptHybrid),
), # descrpt_class_param & class
DEFAULT_VEC_DESCRIPTOR_PARAMS_WITH_HYBRID, # descrpt_class_param & class
((FittingParamPolar, PolarFitting),), # fitting_class_param & class
([], [0]), # atom_exclude_types
),
fit_parameterized=(
(
(DescriptorParamSeA, DescrptSeA),
(DescriptorParamDPA1, DescrptDPA1),
(DescriptorParamDPA2, DescrptDPA2),
), # descrpt_class_param & class
DEFAULT_VEC_DESCRIPTOR_PARAMS, # descrpt_class_param & class
(
*[(param_func, PolarFitting) for param_func in FittingParamPolarList],
), # fitting_class_param & class
Expand Down Expand Up @@ -375,19 +375,11 @@ def test_sel_type_from_atom_exclude_types(self):

@parameterized(
des_parameterized=(
(
*[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List],
*[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List],
(DescriptorParamHybridMixed, DescrptHybrid),
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
), # descrpt_class_param & class
DEFAULT_DPA12_DESCRIPTOR_PARAMS_WITH_HYBRID, # descrpt_class_param & class
((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class
),
fit_parameterized=(
(
(DescriptorParamDPA1, DescrptDPA1),
(DescriptorParamDPA2, DescrptDPA2),
), # descrpt_class_param & class
DEFAULT_DPA12_DESCRIPTOR_PARAMS, # descrpt_class_param & class
(
*[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList],
), # fitting_class_param & class
Expand Down Expand Up @@ -449,21 +441,15 @@ def setUpClass(cls) -> None:
@parameterized(
des_parameterized=(
(
*[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList],
*[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List],
*[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List],
*DEFAULT_VEC_DESCRIPTOR_PARAMS,
(DescriptorParamHybridMixed, DescrptHybrid),
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
), # descrpt_class_param & class
((FittingParamProperty, PropertyFittingNet),), # fitting_class_param & class
([], [0]), # atom_exclude_types
),
fit_parameterized=(
(
(DescriptorParamSeA, DescrptSeA),
(DescriptorParamDPA1, DescrptDPA1),
(DescriptorParamDPA2, DescrptDPA2),
), # descrpt_class_param & class
DEFAULT_VEC_DESCRIPTOR_PARAMS, # descrpt_class_param & class
(
*[
(param_func, PropertyFittingNet)
Expand Down
Loading
Loading