Skip to content

Commit 750cbb2

Browse files
author
Han Wang
committed
add ut
1 parent f085332 commit 750cbb2

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

source/tests/consistent/loss/test_ener.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,15 @@
5757

5858

5959
@parameterized(
60-
(False, True), # use_huber
60+
(False, False), # huber, enable_atom_ener_coeff
61+
(True, False),
62+
(False, True),
63+
(True, True),
6164
)
6265
class TestEner(CommonTest, LossTest, unittest.TestCase):
6366
@property
6467
def data(self) -> dict:
65-
(use_huber,) = self.param
68+
(use_huber, enable_atom_ener_coeff) = self.param
6669
return {
6770
"start_pref_e": 0.02,
6871
"limit_pref_e": 1.0,
@@ -75,6 +78,7 @@ def data(self) -> dict:
7578
"start_pref_pf": 1.0 if not use_huber else 0.0,
7679
"limit_pref_pf": 1.0 if not use_huber else 0.0,
7780
"use_huber": use_huber,
81+
"enable_atom_ener_coeff": enable_atom_ener_coeff,
7882
}
7983

8084
skip_tf = CommonTest.skip_tf
@@ -124,11 +128,13 @@ def setUp(self) -> None:
124128
self.natoms,
125129
)
126130
),
131+
"atom_ener_coeff": rng.random((self.nframes, self.natoms)),
127132
"atom_pref": np.ones((self.nframes, self.natoms, 3)),
128133
"find_energy": 1.0,
129134
"find_force": 1.0,
130135
"find_virial": 1.0,
131136
"find_atom_ener": 1.0,
137+
"find_atom_ener_coeff": 1.0,
132138
"find_atom_pref": 1.0,
133139
}
134140

0 commit comments

Comments
 (0)