Skip to content

Commit 03d2141

Browse files
committed
add ut
1 parent 9b1a711 commit 03d2141

1 file changed

Lines changed: 15 additions & 1 deletion

File tree

source/tests/consistent/descriptor/test_dpa3.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
"n_multi_edge_message",
8282
"precision",
8383
"add_chg_spin_ebd",
84+
"default_chg_spin",
8485
)
8586

8687

@@ -100,6 +101,7 @@
100101
"n_multi_edge_message": 1,
101102
"precision": "float64",
102103
"add_chg_spin_ebd": False,
104+
"default_chg_spin": None,
103105
}
104106

105107

@@ -123,6 +125,7 @@ def dpa3_case(**overrides: Any) -> tuple:
123125
dpa3_case(exclude_types=[[0, 1]]),
124126
dpa3_case(use_loc_mapping=False),
125127
dpa3_case(add_chg_spin_ebd=True),
128+
dpa3_case(add_chg_spin_ebd=True, default_chg_spin=[5.0, 1.0]),
126129
# Repflow compression branches.
127130
dpa3_case(a_compress_rate=1),
128131
dpa3_case(a_compress_e_rate=2),
@@ -161,6 +164,7 @@ def dpa3_descriptor_api_case(**overrides: Any) -> tuple:
161164
dpa3_descriptor_api_case(use_loc_mapping=False),
162165
dpa3_descriptor_api_case(fix_stat_std=0.0),
163166
dpa3_descriptor_api_case(add_chg_spin_ebd=True),
167+
dpa3_descriptor_api_case(add_chg_spin_ebd=True, default_chg_spin=[5.0, 1.0]),
164168
# Repflow compression branches.
165169
dpa3_descriptor_api_case(a_compress_rate=1),
166170
dpa3_descriptor_api_case(a_compress_e_rate=2),
@@ -205,6 +209,7 @@ def data(self) -> dict:
205209
n_multi_edge_message,
206210
precision,
207211
add_chg_spin_ebd,
212+
default_chg_spin,
208213
) = self.param
209214
return {
210215
"ntypes": self.ntypes,
@@ -246,6 +251,7 @@ def data(self) -> dict:
246251
"use_loc_mapping": use_loc_mapping,
247252
"trainable": False,
248253
"add_chg_spin_ebd": add_chg_spin_ebd,
254+
"default_chg_spin": default_chg_spin,
249255
}
250256

251257
@property
@@ -266,6 +272,7 @@ def skip_pt(self) -> bool:
266272
_n_multi_edge_message,
267273
_precision,
268274
_add_chg_spin_ebd,
275+
_default_chg_spin,
269276
) = self.param
270277
return CommonTest.skip_pt
271278

@@ -287,6 +294,7 @@ def skip_pd(self) -> bool:
287294
_n_multi_edge_message,
288295
_precision,
289296
add_chg_spin_ebd,
297+
_default_chg_spin,
290298
) = self.param
291299
return True if add_chg_spin_ebd else CommonTest.skip_pd
292300

@@ -308,6 +316,7 @@ def skip_dp(self) -> bool:
308316
_n_multi_edge_message,
309317
_precision,
310318
_add_chg_spin_ebd,
319+
_default_chg_spin,
311320
) = self.param
312321
return CommonTest.skip_dp
313322

@@ -329,6 +338,7 @@ def skip_tf(self) -> bool:
329338
_n_multi_edge_message,
330339
_precision,
331340
_add_chg_spin_ebd,
341+
_default_chg_spin,
332342
) = self.param
333343
return True
334344

@@ -394,8 +404,8 @@ def setUp(self) -> None:
394404
_n_multi_edge_message,
395405
_precision,
396406
add_chg_spin_ebd,
407+
_default_chg_spin,
397408
) = self.param
398-
# charge_spin for charge=5, spin=1 when add_chg_spin_ebd is True
399409
self.charge_spin = (
400410
np.array([[5, 1]], dtype=GLOBAL_NP_FLOAT_PRECISION)
401411
if add_chg_spin_ebd
@@ -500,6 +510,7 @@ def rtol(self) -> float:
500510
_n_multi_edge_message,
501511
precision,
502512
_add_chg_spin_ebd,
513+
_default_chg_spin,
503514
) = self.param
504515
if precision == "float64":
505516
return 1e-10
@@ -527,6 +538,7 @@ def atol(self) -> float:
527538
_n_multi_edge_message,
528539
precision,
529540
_add_chg_spin_ebd,
541+
_default_chg_spin,
530542
) = self.param
531543
if precision == "float64":
532544
return 1e-6 # need to fix in the future, see issue https://github.com/deepmodeling/deepmd-kit/issues/3786
@@ -563,6 +575,7 @@ def data(self) -> dict:
563575
n_multi_edge_message,
564576
precision,
565577
add_chg_spin_ebd,
578+
default_chg_spin,
566579
) = self.param
567580
return {
568581
"ntypes": self.ntypes,
@@ -604,4 +617,5 @@ def data(self) -> dict:
604617
"use_loc_mapping": use_loc_mapping,
605618
"trainable": False,
606619
"add_chg_spin_ebd": add_chg_spin_ebd,
620+
"default_chg_spin": default_chg_spin,
607621
}

0 commit comments

Comments
 (0)