Skip to content

Commit 93f5580

Browse files
authored
test(universal): replace Cartesian product with curated matrices for descriptor test (#5459)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Tests** * Centralized descriptor parameter combinations into shared, reusable constants used across atomic/PT and model test suites to reduce repetition. * Added curated energy-model-specific parameter variants and a small helper to create fixed-argument variants for integration tests. * Removed an outdated descriptor helper and updated parameterized tests to consume the new centralized parameter collections. <!-- review_stack_entry_start --> [![Review Change Stack](https://storage.googleapis.com/coderabbit_public_assets/review-stack-in-coderabbit-ui.svg)](https://app.coderabbit.ai/change-stack/deepmodeling/deepmd-kit/pull/5459?utm_source=github_walkthrough&utm_medium=github&utm_campaign=change_stack) <!-- review_stack_entry_end --> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 2087416 commit 93f5580

5 files changed

Lines changed: 542 additions & 377 deletions

File tree

source/tests/universal/dpmodel/atomc_model/test_atomic_model.py

Lines changed: 67 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,18 @@
4141
)
4242
from ...dpmodel.descriptor.test_descriptor import (
4343
DescriptorParamDPA1,
44-
DescriptorParamDPA1List,
44+
DescriptorParamDPA1EnergyModelList,
4545
DescriptorParamDPA2,
46-
DescriptorParamDPA2List,
46+
DescriptorParamDPA2EnergyModelList,
4747
DescriptorParamHybrid,
4848
DescriptorParamHybridMixed,
4949
DescriptorParamHybridMixedTTebd,
5050
DescriptorParamSeA,
51-
DescriptorParamSeAList,
51+
DescriptorParamSeAEnergyModelList,
5252
DescriptorParamSeR,
53-
DescriptorParamSeRList,
53+
DescriptorParamSeREnergyModelList,
5454
DescriptorParamSeT,
55-
DescriptorParamSeTList,
55+
DescriptorParamSeTEnergyModelList,
5656
)
5757
from ...dpmodel.model.test_model import (
5858
skip_model_tests,
@@ -81,29 +81,64 @@ def make_sel_type_from_atom_exclude_types(type_map, atom_exclude_types):
8181
return sel_type.tolist()
8282

8383

84+
ENERGY_DESCRIPTOR_PARAMS = (
85+
*[(param_func, DescrptSeA) for param_func in DescriptorParamSeAEnergyModelList],
86+
*[(param_func, DescrptSeR) for param_func in DescriptorParamSeREnergyModelList],
87+
*[(param_func, DescrptSeT) for param_func in DescriptorParamSeTEnergyModelList],
88+
*[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1EnergyModelList],
89+
*[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2EnergyModelList],
90+
(DescriptorParamHybrid, DescrptHybrid),
91+
(DescriptorParamHybridMixed, DescrptHybrid),
92+
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
93+
)
94+
95+
DEFAULT_DESCRIPTOR_PARAMS = (
96+
(DescriptorParamSeA, DescrptSeA),
97+
(DescriptorParamSeR, DescrptSeR),
98+
(DescriptorParamSeT, DescrptSeT),
99+
(DescriptorParamDPA1, DescrptDPA1),
100+
(DescriptorParamDPA2, DescrptDPA2),
101+
)
102+
103+
DEFAULT_DESCRIPTOR_PARAMS_WITH_HYBRID = (
104+
*DEFAULT_DESCRIPTOR_PARAMS,
105+
(DescriptorParamHybrid, DescrptHybrid),
106+
(DescriptorParamHybridMixed, DescrptHybrid),
107+
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
108+
)
109+
110+
DEFAULT_VEC_DESCRIPTOR_PARAMS = (
111+
(DescriptorParamSeA, DescrptSeA),
112+
(DescriptorParamDPA1, DescrptDPA1),
113+
(DescriptorParamDPA2, DescrptDPA2),
114+
)
115+
116+
DEFAULT_VEC_DESCRIPTOR_PARAMS_WITH_HYBRID = (
117+
*DEFAULT_VEC_DESCRIPTOR_PARAMS,
118+
(DescriptorParamHybrid, DescrptHybrid),
119+
(DescriptorParamHybridMixed, DescrptHybrid),
120+
)
121+
122+
DEFAULT_DPA12_DESCRIPTOR_PARAMS = (
123+
(DescriptorParamDPA1, DescrptDPA1),
124+
(DescriptorParamDPA2, DescrptDPA2),
125+
)
126+
127+
DEFAULT_DPA12_DESCRIPTOR_PARAMS_WITH_HYBRID = (
128+
*DEFAULT_DPA12_DESCRIPTOR_PARAMS,
129+
(DescriptorParamHybridMixed, DescrptHybrid),
130+
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
131+
)
132+
133+
84134
@parameterized(
85135
des_parameterized=(
86-
(
87-
*[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList],
88-
*[(param_func, DescrptSeR) for param_func in DescriptorParamSeRList],
89-
*[(param_func, DescrptSeT) for param_func in DescriptorParamSeTList],
90-
*[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List],
91-
*[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List],
92-
(DescriptorParamHybrid, DescrptHybrid),
93-
(DescriptorParamHybridMixed, DescrptHybrid),
94-
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
95-
), # descrpt_class_param & class
136+
ENERGY_DESCRIPTOR_PARAMS, # descrpt_class_param & class
96137
((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class
97138
([], [0]), # atom_exclude_types
98139
),
99140
fit_parameterized=(
100-
(
101-
(DescriptorParamSeA, DescrptSeA),
102-
(DescriptorParamSeR, DescrptSeR),
103-
(DescriptorParamSeT, DescrptSeT),
104-
(DescriptorParamDPA1, DescrptDPA1),
105-
(DescriptorParamDPA2, DescrptDPA2),
106-
), # descrpt_class_param & class
141+
DEFAULT_DESCRIPTOR_PARAMS, # descrpt_class_param & class
107142
(
108143
*[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList],
109144
), # fitting_class_param & class
@@ -158,27 +193,12 @@ def test_sel_type_from_atom_exclude_types(self):
158193

159194
@parameterized(
160195
des_parameterized=(
161-
(
162-
*[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList],
163-
*[(param_func, DescrptSeR) for param_func in DescriptorParamSeRList],
164-
*[(param_func, DescrptSeT) for param_func in DescriptorParamSeTList],
165-
*[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List],
166-
*[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List],
167-
(DescriptorParamHybrid, DescrptHybrid),
168-
(DescriptorParamHybridMixed, DescrptHybrid),
169-
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
170-
), # descrpt_class_param & class
196+
DEFAULT_DESCRIPTOR_PARAMS_WITH_HYBRID, # descrpt_class_param & class
171197
((FittingParamDos, DOSFittingNet),), # fitting_class_param & class
172198
([], [0]), # atom_exclude_types
173199
),
174200
fit_parameterized=(
175-
(
176-
(DescriptorParamSeA, DescrptSeA),
177-
(DescriptorParamSeR, DescrptSeR),
178-
(DescriptorParamSeT, DescrptSeT),
179-
(DescriptorParamDPA1, DescrptDPA1),
180-
(DescriptorParamDPA2, DescrptDPA2),
181-
), # descrpt_class_param & class
201+
DEFAULT_DESCRIPTOR_PARAMS, # descrpt_class_param & class
182202
(
183203
*[(param_func, DOSFittingNet) for param_func in FittingParamDosList],
184204
), # fitting_class_param & class
@@ -233,22 +253,12 @@ def test_sel_type_from_atom_exclude_types(self):
233253

234254
@parameterized(
235255
des_parameterized=(
236-
(
237-
*[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList],
238-
*[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List],
239-
*[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List],
240-
(DescriptorParamHybrid, DescrptHybrid),
241-
(DescriptorParamHybridMixed, DescrptHybrid),
242-
), # descrpt_class_param & class
256+
DEFAULT_VEC_DESCRIPTOR_PARAMS_WITH_HYBRID, # descrpt_class_param & class
243257
((FittingParamDipole, DipoleFitting),), # fitting_class_param & class
244258
([], [0]), # atom_exclude_types
245259
),
246260
fit_parameterized=(
247-
(
248-
(DescriptorParamSeA, DescrptSeA),
249-
(DescriptorParamDPA1, DescrptDPA1),
250-
(DescriptorParamDPA2, DescrptDPA2),
251-
), # descrpt_class_param & class
261+
DEFAULT_VEC_DESCRIPTOR_PARAMS, # descrpt_class_param & class
252262
(
253263
*[(param_func, DipoleFitting) for param_func in FittingParamDipoleList],
254264
), # fitting_class_param & class
@@ -304,22 +314,12 @@ def test_sel_type_from_atom_exclude_types(self):
304314

305315
@parameterized(
306316
des_parameterized=(
307-
(
308-
*[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList],
309-
*[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List],
310-
*[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List],
311-
(DescriptorParamHybrid, DescrptHybrid),
312-
(DescriptorParamHybridMixed, DescrptHybrid),
313-
), # descrpt_class_param & class
317+
DEFAULT_VEC_DESCRIPTOR_PARAMS_WITH_HYBRID, # descrpt_class_param & class
314318
((FittingParamPolar, PolarFitting),), # fitting_class_param & class
315319
([], [0]), # atom_exclude_types
316320
),
317321
fit_parameterized=(
318-
(
319-
(DescriptorParamSeA, DescrptSeA),
320-
(DescriptorParamDPA1, DescrptDPA1),
321-
(DescriptorParamDPA2, DescrptDPA2),
322-
), # descrpt_class_param & class
322+
DEFAULT_VEC_DESCRIPTOR_PARAMS, # descrpt_class_param & class
323323
(
324324
*[(param_func, PolarFitting) for param_func in FittingParamPolarList],
325325
), # fitting_class_param & class
@@ -375,19 +375,11 @@ def test_sel_type_from_atom_exclude_types(self):
375375

376376
@parameterized(
377377
des_parameterized=(
378-
(
379-
*[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List],
380-
*[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List],
381-
(DescriptorParamHybridMixed, DescrptHybrid),
382-
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
383-
), # descrpt_class_param & class
378+
DEFAULT_DPA12_DESCRIPTOR_PARAMS_WITH_HYBRID, # descrpt_class_param & class
384379
((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class
385380
),
386381
fit_parameterized=(
387-
(
388-
(DescriptorParamDPA1, DescrptDPA1),
389-
(DescriptorParamDPA2, DescrptDPA2),
390-
), # descrpt_class_param & class
382+
DEFAULT_DPA12_DESCRIPTOR_PARAMS, # descrpt_class_param & class
391383
(
392384
*[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList],
393385
), # fitting_class_param & class
@@ -449,21 +441,15 @@ def setUpClass(cls) -> None:
449441
@parameterized(
450442
des_parameterized=(
451443
(
452-
*[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList],
453-
*[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List],
454-
*[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List],
444+
*DEFAULT_VEC_DESCRIPTOR_PARAMS,
455445
(DescriptorParamHybridMixed, DescrptHybrid),
456446
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
457447
), # descrpt_class_param & class
458448
((FittingParamProperty, PropertyFittingNet),), # fitting_class_param & class
459449
([], [0]), # atom_exclude_types
460450
),
461451
fit_parameterized=(
462-
(
463-
(DescriptorParamSeA, DescrptSeA),
464-
(DescriptorParamDPA1, DescrptDPA1),
465-
(DescriptorParamDPA2, DescrptDPA2),
466-
), # descrpt_class_param & class
452+
DEFAULT_VEC_DESCRIPTOR_PARAMS, # descrpt_class_param & class
467453
(
468454
*[
469455
(param_func, PropertyFittingNet)

0 commit comments

Comments
 (0)