Skip to content

Commit 78e2d5f

Browse files
committed
add ut for extra keys
1 parent 24b9060 commit 78e2d5f

2 files changed

Lines changed: 157 additions & 3 deletions

File tree

deepmd/dpmodel/utils/lmdb_data.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,13 +1289,13 @@ def _stack_frames(
12891289
np.stack(atypes) if atypes else np.zeros((0, natoms), dtype=np.int64)
12901290
)
12911291

1292-
# Dynamically discover all data keys present in frames, plus
1292+
# Dynamically discover all data keys from the first frame, plus
12931293
# any registered requirements. Structural keys (coord, box, type)
12941294
# are excluded — they are already handled above.
12951295
_structural_keys = frozenset({"coord", "box", "atype"})
12961296
all_keys: dict[str, dict[str, Any]] = {}
1297-
for f in frames:
1298-
for fk in f:
1297+
if frames:
1298+
for fk in frames[0]:
12991299
if fk in _structural_keys or fk.startswith("find_"):
13001300
continue
13011301
if fk not in all_keys:

source/tests/common/dpmodel/test_lmdb_data.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,5 +709,159 @@ def test_sampling_large_dataset(self):
709709
tmpdir.cleanup()
710710

711711

712+
def _create_lmdb_with_extra_keys(
713+
path: str, nframes: int = 5, natoms: int = 6, extra_keys: dict | None = None
714+
) -> str:
715+
"""Create a test LMDB with extra per-frame keys (e.g. atom_pref, fparam).
716+
717+
Parameters
718+
----------
719+
extra_keys : dict
720+
key -> (shape_fn, dtype) where shape_fn(natoms) returns the array shape.
721+
Example: {"atom_pref": (lambda n: (n,), np.float64)}
722+
"""
723+
n_type0 = max(1, natoms // 3)
724+
n_type1 = natoms - n_type0
725+
extra_keys = extra_keys or {}
726+
env = lmdb.open(path, map_size=10 * 1024 * 1024)
727+
with env.begin(write=True) as txn:
728+
meta = {
729+
"nframes": nframes,
730+
"frame_idx_fmt": "012d",
731+
"type_map": ["O", "H"],
732+
"system_info": {"natoms": [n_type0, n_type1]},
733+
}
734+
txn.put(b"__metadata__", msgpack.packb(meta, use_bin_type=True))
735+
rng = np.random.RandomState(0)
736+
for i in range(nframes):
737+
frame = _make_frame(natoms=natoms, seed=i)
738+
for ek, (shape_fn, dtype) in extra_keys.items():
739+
arr = rng.rand(*shape_fn(natoms)).astype(dtype)
740+
frame[ek] = {
741+
"type": str(arr.dtype),
742+
"shape": list(arr.shape),
743+
"data": arr.tobytes(),
744+
}
745+
txn.put(
746+
format(i, "012d").encode(),
747+
msgpack.packb(frame, use_bin_type=True),
748+
)
749+
env.close()
750+
return path
751+
752+
753+
# ============================================================
754+
# Dynamic find_* and repeat tests
755+
# ============================================================
756+
757+
758+
class TestDynamicKeysAndRepeat(unittest.TestCase):
759+
"""Test auto-discovery of find_* flags and repeat handling."""
760+
761+
@classmethod
762+
def setUpClass(cls):
763+
cls._tmpdir = tempfile.TemporaryDirectory()
764+
cls._natoms = 6
765+
cls._nframes = 5
766+
cls._lmdb_path = _create_lmdb_with_extra_keys(
767+
f"{cls._tmpdir.name}/extra.lmdb",
768+
nframes=cls._nframes,
769+
natoms=cls._natoms,
770+
extra_keys={
771+
"atom_pref": (lambda n: (n,), np.float64),
772+
"fparam": (lambda n: (3,), np.float64),
773+
},
774+
)
775+
cls._type_map = ["O", "H"]
776+
777+
@classmethod
778+
def tearDownClass(cls):
779+
cls._tmpdir.cleanup()
780+
781+
# --- LmdbDataReader ---
782+
783+
def test_reader_find_flags_auto_detected(self):
784+
"""Extra keys in frame get find_*=1.0 automatically."""
785+
reader = LmdbDataReader(self._lmdb_path, self._type_map)
786+
frame = reader[0]
787+
self.assertEqual(frame["find_atom_pref"], np.float32(1.0))
788+
self.assertEqual(frame["find_fparam"], np.float32(1.0))
789+
self.assertEqual(frame["find_energy"], np.float32(1.0))
790+
# Keys not in frame get find_*=0.0
791+
self.assertEqual(frame["find_aparam"], np.float32(0.0))
792+
self.assertEqual(frame["find_spin"], np.float32(0.0))
793+
794+
def test_reader_repeat_applied(self):
795+
"""DataRequirementItem with repeat=3 expands atom_pref from (natoms,) to (natoms*3,)."""
796+
from deepmd.utils.data import (
797+
DataRequirementItem,
798+
)
799+
800+
reader = LmdbDataReader(self._lmdb_path, self._type_map)
801+
reader.add_data_requirement(
802+
[
803+
DataRequirementItem(
804+
"atom_pref",
805+
ndof=1,
806+
atomic=True,
807+
must=False,
808+
high_prec=False,
809+
repeat=3,
810+
),
811+
]
812+
)
813+
frame = reader[0]
814+
self.assertEqual(frame["atom_pref"].shape, (self._natoms * 3,))
815+
816+
def test_reader_repeat_default_fill(self):
817+
"""Missing key with repeat fills correct shape."""
818+
from deepmd.utils.data import (
819+
DataRequirementItem,
820+
)
821+
822+
reader = LmdbDataReader(self._lmdb_path, self._type_map)
823+
reader.add_data_requirement(
824+
[
825+
DataRequirementItem(
826+
"drdq", ndof=6, atomic=True, must=False, high_prec=False, repeat=2
827+
),
828+
]
829+
)
830+
frame = reader[0]
831+
self.assertEqual(frame["find_drdq"], np.float32(0.0))
832+
self.assertEqual(frame["drdq"].shape, (self._natoms * 6 * 2,))
833+
834+
# --- LmdbTestData ---
835+
836+
def test_testdata_find_flags_auto_detected(self):
837+
"""LmdbTestData.get_test() discovers extra keys dynamically."""
838+
td = LmdbTestData(self._lmdb_path, type_map=self._type_map, shuffle_test=False)
839+
result = td.get_test()
840+
self.assertEqual(result["find_atom_pref"], 1.0)
841+
self.assertEqual(result["find_fparam"], 1.0)
842+
self.assertIn("atom_pref", result)
843+
self.assertIn("fparam", result)
844+
845+
def test_testdata_repeat_applied(self):
846+
"""LmdbTestData respects repeat=3 for atom_pref."""
847+
td = LmdbTestData(self._lmdb_path, type_map=self._type_map, shuffle_test=False)
848+
td.add("atom_pref", 1, atomic=True, must=False, high_prec=False, repeat=3)
849+
result = td.get_test()
850+
self.assertEqual(
851+
result["atom_pref"].shape,
852+
(self._nframes, self._natoms * 3),
853+
)
854+
855+
def test_testdata_missing_key_not_found(self):
856+
"""Keys absent from LMDB frames get find_*=0.0 in get_test()."""
857+
tmpdir = tempfile.TemporaryDirectory()
858+
path = _create_lmdb(f"{tmpdir.name}/plain.lmdb", nframes=3, natoms=6)
859+
td = LmdbTestData(path, type_map=["O", "H"], shuffle_test=False)
860+
result = td.get_test()
861+
# atom_pref is not in the plain LMDB
862+
self.assertEqual(result.get("find_atom_pref", 0.0), 0.0)
863+
tmpdir.cleanup()
864+
865+
712866
if __name__ == "__main__":
713867
unittest.main()

0 commit comments

Comments
 (0)