File tree Expand file tree Collapse file tree
source/tests/consistent/loss Expand file tree Collapse file tree Original file line number Diff line number Diff line change 5757
5858
5959@parameterized (
60- (False , True ), # use_huber
60+ (False , False ), # huber, enable_atom_ener_coeff
61+ (True , False ),
62+ (False , True ),
63+ (True , True ),
6164)
6265class TestEner (CommonTest , LossTest , unittest .TestCase ):
6366 @property
6467 def data (self ) -> dict :
65- (use_huber ,) = self .param
68+ (use_huber , enable_atom_ener_coeff ) = self .param
6669 return {
6770 "start_pref_e" : 0.02 ,
6871 "limit_pref_e" : 1.0 ,
@@ -75,6 +78,7 @@ def data(self) -> dict:
7578 "start_pref_pf" : 1.0 if not use_huber else 0.0 ,
7679 "limit_pref_pf" : 1.0 if not use_huber else 0.0 ,
7780 "use_huber" : use_huber ,
81+ "enable_atom_ener_coeff" : enable_atom_ener_coeff ,
7882 }
7983
8084 skip_tf = CommonTest .skip_tf
@@ -124,11 +128,13 @@ def setUp(self) -> None:
124128 self .natoms ,
125129 )
126130 ),
131+ "atom_ener_coeff" : rng .random ((self .nframes , self .natoms )),
127132 "atom_pref" : np .ones ((self .nframes , self .natoms , 3 )),
128133 "find_energy" : 1.0 ,
129134 "find_force" : 1.0 ,
130135 "find_virial" : 1.0 ,
131136 "find_atom_ener" : 1.0 ,
137+ "find_atom_ener_coeff" : 1.0 ,
132138 "find_atom_pref" : 1.0 ,
133139 }
134140
You can’t perform that action at this time.
0 commit comments