Skip to content

Commit d9a4db4

Browse files
committed
fix type map
1 parent f91ba7d commit d9a4db4

1 file changed

Lines changed: 23 additions & 0 deletions

File tree

deepmd/dpmodel/utils/lmdb_data.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,26 @@ def __init__(
222222
self._natoms = sum(self._natoms_per_type)
223223
self._ntypes = len(type_map)
224224

225+
# Build type remapping if LMDB's type_map differs from model's type_map
226+
lmdb_type_map = meta.get("type_map")
227+
self._lmdb_type_map = lmdb_type_map
228+
self._type_remap: np.ndarray | None = None
229+
if lmdb_type_map is not None and list(lmdb_type_map) != list(type_map):
230+
# Build remap: lmdb_type_idx -> model_type_idx
231+
remap = np.empty(len(lmdb_type_map), dtype=np.int32)
232+
for i, name in enumerate(lmdb_type_map):
233+
if name not in type_map:
234+
raise ValueError(
235+
f"Element '{name}' in LMDB type_map {lmdb_type_map} "
236+
f"not found in model type_map {type_map}"
237+
)
238+
remap[i] = type_map.index(name)
239+
self._type_remap = remap
240+
log.info(
241+
f"Type remapping: LMDB {lmdb_type_map} -> model {type_map}, "
242+
f"remap={list(remap)}"
243+
)
244+
225245
# Persistent read-only transaction for __getitem__ (avoids per-read overhead).
226246
# Safe because we use num_workers=0 in DataLoader.
227247
self._txn = self._env.begin()
@@ -346,6 +366,9 @@ def __getitem__(self, index: int) -> dict[str, Any]:
346366
)
347367
if "atype" in frame and isinstance(frame["atype"], np.ndarray):
348368
frame["atype"] = frame["atype"].reshape(-1).astype(np.int64)
369+
# Remap atom types from LMDB's type_map to model's type_map
370+
if self._type_remap is not None:
371+
frame["atype"] = self._type_remap[frame["atype"]].astype(np.int64)
349372
if "virial" in frame and isinstance(frame["virial"], np.ndarray):
350373
frame["virial"] = (
351374
frame["virial"].reshape(9).astype(self._resolve_dtype("virial"))

0 commit comments

Comments
 (0)