8181 "n_multi_edge_message" ,
8282 "precision" ,
8383 "add_chg_spin_ebd" ,
84+ "default_chg_spin" ,
8485)
8586
8687
100101 "n_multi_edge_message" : 1 ,
101102 "precision" : "float64" ,
102103 "add_chg_spin_ebd" : False ,
104+ "default_chg_spin" : None ,
103105}
104106
105107
@@ -123,6 +125,7 @@ def dpa3_case(**overrides: Any) -> tuple:
123125 dpa3_case (exclude_types = [[0 , 1 ]]),
124126 dpa3_case (use_loc_mapping = False ),
125127 dpa3_case (add_chg_spin_ebd = True ),
128+ dpa3_case (add_chg_spin_ebd = True , default_chg_spin = [5.0 , 1.0 ]),
126129 # Repflow compression branches.
127130 dpa3_case (a_compress_rate = 1 ),
128131 dpa3_case (a_compress_e_rate = 2 ),
@@ -161,6 +164,7 @@ def dpa3_descriptor_api_case(**overrides: Any) -> tuple:
161164 dpa3_descriptor_api_case (use_loc_mapping = False ),
162165 dpa3_descriptor_api_case (fix_stat_std = 0.0 ),
163166 dpa3_descriptor_api_case (add_chg_spin_ebd = True ),
167+ dpa3_descriptor_api_case (add_chg_spin_ebd = True , default_chg_spin = [5.0 , 1.0 ]),
164168 # Repflow compression branches.
165169 dpa3_descriptor_api_case (a_compress_rate = 1 ),
166170 dpa3_descriptor_api_case (a_compress_e_rate = 2 ),
@@ -205,6 +209,7 @@ def data(self) -> dict:
205209 n_multi_edge_message ,
206210 precision ,
207211 add_chg_spin_ebd ,
212+ default_chg_spin ,
208213 ) = self .param
209214 return {
210215 "ntypes" : self .ntypes ,
@@ -246,6 +251,7 @@ def data(self) -> dict:
246251 "use_loc_mapping" : use_loc_mapping ,
247252 "trainable" : False ,
248253 "add_chg_spin_ebd" : add_chg_spin_ebd ,
254+ "default_chg_spin" : default_chg_spin ,
249255 }
250256
251257 @property
@@ -266,6 +272,7 @@ def skip_pt(self) -> bool:
266272 _n_multi_edge_message ,
267273 _precision ,
268274 _add_chg_spin_ebd ,
275+ _default_chg_spin ,
269276 ) = self .param
270277 return CommonTest .skip_pt
271278
@@ -287,6 +294,7 @@ def skip_pd(self) -> bool:
287294 _n_multi_edge_message ,
288295 _precision ,
289296 add_chg_spin_ebd ,
297+ _default_chg_spin ,
290298 ) = self .param
291299 return True if add_chg_spin_ebd else CommonTest .skip_pd
292300
@@ -308,6 +316,7 @@ def skip_dp(self) -> bool:
308316 _n_multi_edge_message ,
309317 _precision ,
310318 _add_chg_spin_ebd ,
319+ _default_chg_spin ,
311320 ) = self .param
312321 return CommonTest .skip_dp
313322
@@ -329,6 +338,7 @@ def skip_tf(self) -> bool:
329338 _n_multi_edge_message ,
330339 _precision ,
331340 _add_chg_spin_ebd ,
341+ _default_chg_spin ,
332342 ) = self .param
333343 return True
334344
@@ -394,8 +404,8 @@ def setUp(self) -> None:
394404 _n_multi_edge_message ,
395405 _precision ,
396406 add_chg_spin_ebd ,
407+ _default_chg_spin ,
397408 ) = self .param
398- # charge_spin for charge=5, spin=1 when add_chg_spin_ebd is True
399409 self .charge_spin = (
400410 np .array ([[5 , 1 ]], dtype = GLOBAL_NP_FLOAT_PRECISION )
401411 if add_chg_spin_ebd
@@ -500,6 +510,7 @@ def rtol(self) -> float:
500510 _n_multi_edge_message ,
501511 precision ,
502512 _add_chg_spin_ebd ,
513+ _default_chg_spin ,
503514 ) = self .param
504515 if precision == "float64" :
505516 return 1e-10
@@ -527,6 +538,7 @@ def atol(self) -> float:
527538 _n_multi_edge_message ,
528539 precision ,
529540 _add_chg_spin_ebd ,
541+ _default_chg_spin ,
530542 ) = self .param
531543 if precision == "float64" :
532544 return 1e-6 # need to fix in the future, see issue https://github.com/deepmodeling/deepmd-kit/issues/3786
@@ -563,6 +575,7 @@ def data(self) -> dict:
563575 n_multi_edge_message ,
564576 precision ,
565577 add_chg_spin_ebd ,
578+ default_chg_spin ,
566579 ) = self .param
567580 return {
568581 "ntypes" : self .ntypes ,
@@ -604,4 +617,5 @@ def data(self) -> dict:
604617 "use_loc_mapping" : use_loc_mapping ,
605618 "trainable" : False ,
606619 "add_chg_spin_ebd" : add_chg_spin_ebd ,
620+ "default_chg_spin" : default_chg_spin ,
607621 }
0 commit comments