@@ -455,5 +455,47 @@ def tearDown(self):
455455 shutil .rmtree ("tmp.deepmd.mixed.single" )
456456
457457
458+ class TestMixedSystemWithFparamAparam (unittest .TestCase , CompLabeledSys , IsNoPBC ):
459+ def setUp (self ):
460+ self .places = 6
461+ self .e_places = 6
462+ self .f_places = 6
463+ self .v_places = 6
464+
465+ system_1 = dpdata .LabeledSystem (
466+ "gaussian/methane.gaussianlog" , fmt = "gaussian/log"
467+ )
468+
469+ tmp_data = system_1 .data .copy ()
470+ nframes = tmp_data ["coords" ].shape [0 ]
471+ natoms = tmp_data ["atom_types" ].shape [0 ]
472+
473+ tmp_data ["fparam" ] = np .random .random ([nframes , 2 ])
474+ tmp_data ["aparam" ] = np .random .random ([nframes , natoms , 3 ])
475+
476+ self .system_1 = dpdata .LabeledSystem (data = tmp_data )
477+
478+ self .system_1 .to ("deepmd/npy/mixed" , "tmp.deepmd.fparam.aparam" )
479+ self .system_2 = dpdata .LabeledSystem ("tmp.deepmd.fparam.aparam" , fmt = "deepmd/npy/mixed" )
480+
481+ def tearDown (self ):
482+ if os .path .exists ("tmp.deepmd.fparam.aparam" ):
483+ shutil .rmtree ("tmp.deepmd.fparam.aparam" )
484+
485+ def test_fparam_exists (self ):
486+ self .assertTrue ("fparam" in self .system_1 .data )
487+ self .assertTrue ("fparam" in self .system_2 .data )
488+ np .testing .assert_almost_equal (
489+ self .system_1 .data ["fparam" ], self .system_2 .data ["fparam" ], decimal = self .places
490+ )
491+
492+ def test_aparam_exists (self ):
493+ self .assertTrue ("aparam" in self .system_1 .data )
494+ self .assertTrue ("aparam" in self .system_2 .data )
495+ np .testing .assert_almost_equal (
496+ self .system_1 .data ["aparam" ], self .system_2 .data ["aparam" ], decimal = self .places
497+ )
498+
499+
458500if __name__ == "__main__" :
459501 unittest .main ()
0 commit comments