Skip to content

Commit f461f43

Browse files
committed
fix: UT
1 parent 420bef7 commit f461f43

1 file changed

Lines changed: 4 additions & 8 deletions

File tree

source/tests/pt/model/test_linear_atomic_model_stat.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,16 +233,12 @@ def test_linear_atomic_model_stat_with_bias(self) -> None:
233233
linear_model.compute_or_load_out_stat(
234234
self.merged_output_stat, stat_file_path=self.stat_file_path
235235
)
236-
# bias applied to sub atomic models.
237236
ener_bias = np.array([1.0, 3.0]).reshape(2, 1)
238-
linear_ret = []
239-
for idx, md in enumerate(linear_model.models):
240-
ret = md.forward_common_atomic(*args)
241-
ret = to_numpy_array(ret["energy"])
242-
linear_ret.append(ret_no_bias[idx] + ener_bias[at])
243-
np.testing.assert_almost_equal((ret_no_bias[idx] + ener_bias[at]), ret)
237+
ret = to_numpy_array(linear_model.forward_common_atomic(*args)["energy"])
238+
np.testing.assert_almost_equal((ret0 + ener_bias[at]), ret)
239+
244240

245241
# linear model not adding bias again
246242
ret1 = linear_model.forward_common_atomic(*args)
247243
ret1 = to_numpy_array(ret1["energy"])
248-
np.testing.assert_almost_equal(np.mean(np.stack(linear_ret), axis=0), ret1)
244+
np.testing.assert_almost_equal(ret, ret1)

0 commit comments

Comments
 (0)