Skip to content

Commit 2bad443

Browse files
committed
feat: try add UT
1 parent c04d685 commit 2bad443

1 file changed

Lines changed: 42 additions & 0 deletions

File tree

tests/test_deepmd_mixed.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,5 +455,47 @@ def tearDown(self):
455455
shutil.rmtree("tmp.deepmd.mixed.single")
456456

457457

458+
class TestMixedSystemWithFparamAparam(unittest.TestCase, CompLabeledSys, IsNoPBC):
459+
def setUp(self):
460+
self.places = 6
461+
self.e_places = 6
462+
self.f_places = 6
463+
self.v_places = 6
464+
465+
system_1 = dpdata.LabeledSystem(
466+
"gaussian/methane.gaussianlog", fmt="gaussian/log"
467+
)
468+
469+
tmp_data = system_1.data.copy()
470+
nframes = tmp_data["coords"].shape[0]
471+
natoms = tmp_data["atom_types"].shape[0]
472+
473+
tmp_data["fparam"] = np.random.random([nframes, 2])
474+
tmp_data["aparam"] = np.random.random([nframes, natoms, 3])
475+
476+
self.system_1 = dpdata.LabeledSystem(data=tmp_data)
477+
478+
self.system_1.to("deepmd/npy/mixed", "tmp.deepmd.fparam.aparam")
479+
self.system_2 = dpdata.LabeledSystem("tmp.deepmd.fparam.aparam", fmt="deepmd/npy/mixed")
480+
481+
def tearDown(self):
482+
if os.path.exists("tmp.deepmd.fparam.aparam"):
483+
shutil.rmtree("tmp.deepmd.fparam.aparam")
484+
485+
def test_fparam_exists(self):
486+
self.assertTrue("fparam" in self.system_1.data)
487+
self.assertTrue("fparam" in self.system_2.data)
488+
np.testing.assert_almost_equal(
489+
self.system_1.data["fparam"], self.system_2.data["fparam"], decimal=self.places
490+
)
491+
492+
def test_aparam_exists(self):
493+
self.assertTrue("aparam" in self.system_1.data)
494+
self.assertTrue("aparam" in self.system_2.data)
495+
np.testing.assert_almost_equal(
496+
self.system_1.data["aparam"], self.system_2.data["aparam"], decimal=self.places
497+
)
498+
499+
458500
if __name__ == "__main__":
459501
unittest.main()

0 commit comments

Comments
 (0)