@@ -709,5 +709,159 @@ def test_sampling_large_dataset(self):
709709 tmpdir .cleanup ()
710710
711711
712+ def _create_lmdb_with_extra_keys (
713+ path : str , nframes : int = 5 , natoms : int = 6 , extra_keys : dict | None = None
714+ ) -> str :
715+ """Create a test LMDB with extra per-frame keys (e.g. atom_pref, fparam).
716+
717+ Parameters
718+ ----------
719+ extra_keys : dict
720+ key -> (shape_fn, dtype) where shape_fn(natoms) returns the array shape.
721+ Example: {"atom_pref": (lambda n: (n,), np.float64)}
722+ """
723+ n_type0 = max (1 , natoms // 3 )
724+ n_type1 = natoms - n_type0
725+ extra_keys = extra_keys or {}
726+ env = lmdb .open (path , map_size = 10 * 1024 * 1024 )
727+ with env .begin (write = True ) as txn :
728+ meta = {
729+ "nframes" : nframes ,
730+ "frame_idx_fmt" : "012d" ,
731+ "type_map" : ["O" , "H" ],
732+ "system_info" : {"natoms" : [n_type0 , n_type1 ]},
733+ }
734+ txn .put (b"__metadata__" , msgpack .packb (meta , use_bin_type = True ))
735+ rng = np .random .RandomState (0 )
736+ for i in range (nframes ):
737+ frame = _make_frame (natoms = natoms , seed = i )
738+ for ek , (shape_fn , dtype ) in extra_keys .items ():
739+ arr = rng .rand (* shape_fn (natoms )).astype (dtype )
740+ frame [ek ] = {
741+ "type" : str (arr .dtype ),
742+ "shape" : list (arr .shape ),
743+ "data" : arr .tobytes (),
744+ }
745+ txn .put (
746+ format (i , "012d" ).encode (),
747+ msgpack .packb (frame , use_bin_type = True ),
748+ )
749+ env .close ()
750+ return path
751+
752+
753+ # ============================================================
754+ # Dynamic find_* and repeat tests
755+ # ============================================================
756+
757+
758+ class TestDynamicKeysAndRepeat (unittest .TestCase ):
759+ """Test auto-discovery of find_* flags and repeat handling."""
760+
761+ @classmethod
762+ def setUpClass (cls ):
763+ cls ._tmpdir = tempfile .TemporaryDirectory ()
764+ cls ._natoms = 6
765+ cls ._nframes = 5
766+ cls ._lmdb_path = _create_lmdb_with_extra_keys (
767+ f"{ cls ._tmpdir .name } /extra.lmdb" ,
768+ nframes = cls ._nframes ,
769+ natoms = cls ._natoms ,
770+ extra_keys = {
771+ "atom_pref" : (lambda n : (n ,), np .float64 ),
772+ "fparam" : (lambda n : (3 ,), np .float64 ),
773+ },
774+ )
775+ cls ._type_map = ["O" , "H" ]
776+
777+ @classmethod
778+ def tearDownClass (cls ):
779+ cls ._tmpdir .cleanup ()
780+
781+ # --- LmdbDataReader ---
782+
783+ def test_reader_find_flags_auto_detected (self ):
784+ """Extra keys in frame get find_*=1.0 automatically."""
785+ reader = LmdbDataReader (self ._lmdb_path , self ._type_map )
786+ frame = reader [0 ]
787+ self .assertEqual (frame ["find_atom_pref" ], np .float32 (1.0 ))
788+ self .assertEqual (frame ["find_fparam" ], np .float32 (1.0 ))
789+ self .assertEqual (frame ["find_energy" ], np .float32 (1.0 ))
790+ # Keys not in frame get find_*=0.0
791+ self .assertEqual (frame ["find_aparam" ], np .float32 (0.0 ))
792+ self .assertEqual (frame ["find_spin" ], np .float32 (0.0 ))
793+
794+ def test_reader_repeat_applied (self ):
795+ """DataRequirementItem with repeat=3 expands atom_pref from (natoms,) to (natoms*3,)."""
796+ from deepmd .utils .data import (
797+ DataRequirementItem ,
798+ )
799+
800+ reader = LmdbDataReader (self ._lmdb_path , self ._type_map )
801+ reader .add_data_requirement (
802+ [
803+ DataRequirementItem (
804+ "atom_pref" ,
805+ ndof = 1 ,
806+ atomic = True ,
807+ must = False ,
808+ high_prec = False ,
809+ repeat = 3 ,
810+ ),
811+ ]
812+ )
813+ frame = reader [0 ]
814+ self .assertEqual (frame ["atom_pref" ].shape , (self ._natoms * 3 ,))
815+
816+ def test_reader_repeat_default_fill (self ):
817+ """Missing key with repeat fills correct shape."""
818+ from deepmd .utils .data import (
819+ DataRequirementItem ,
820+ )
821+
822+ reader = LmdbDataReader (self ._lmdb_path , self ._type_map )
823+ reader .add_data_requirement (
824+ [
825+ DataRequirementItem (
826+ "drdq" , ndof = 6 , atomic = True , must = False , high_prec = False , repeat = 2
827+ ),
828+ ]
829+ )
830+ frame = reader [0 ]
831+ self .assertEqual (frame ["find_drdq" ], np .float32 (0.0 ))
832+ self .assertEqual (frame ["drdq" ].shape , (self ._natoms * 6 * 2 ,))
833+
834+ # --- LmdbTestData ---
835+
836+ def test_testdata_find_flags_auto_detected (self ):
837+ """LmdbTestData.get_test() discovers extra keys dynamically."""
838+ td = LmdbTestData (self ._lmdb_path , type_map = self ._type_map , shuffle_test = False )
839+ result = td .get_test ()
840+ self .assertEqual (result ["find_atom_pref" ], 1.0 )
841+ self .assertEqual (result ["find_fparam" ], 1.0 )
842+ self .assertIn ("atom_pref" , result )
843+ self .assertIn ("fparam" , result )
844+
845+ def test_testdata_repeat_applied (self ):
846+ """LmdbTestData respects repeat=3 for atom_pref."""
847+ td = LmdbTestData (self ._lmdb_path , type_map = self ._type_map , shuffle_test = False )
848+ td .add ("atom_pref" , 1 , atomic = True , must = False , high_prec = False , repeat = 3 )
849+ result = td .get_test ()
850+ self .assertEqual (
851+ result ["atom_pref" ].shape ,
852+ (self ._nframes , self ._natoms * 3 ),
853+ )
854+
855+ def test_testdata_missing_key_not_found (self ):
856+ """Keys absent from LMDB frames get find_*=0.0 in get_test()."""
857+ tmpdir = tempfile .TemporaryDirectory ()
858+ path = _create_lmdb (f"{ tmpdir .name } /plain.lmdb" , nframes = 3 , natoms = 6 )
859+ td = LmdbTestData (path , type_map = ["O" , "H" ], shuffle_test = False )
860+ result = td .get_test ()
861+ # atom_pref is not in the plain LMDB
862+ self .assertEqual (result .get ("find_atom_pref" , 0.0 ), 0.0 )
863+ tmpdir .cleanup ()
864+
865+
712866if __name__ == "__main__" :
713867 unittest .main ()
0 commit comments