@@ -222,6 +222,26 @@ def __init__(
222222 self ._natoms = sum (self ._natoms_per_type )
223223 self ._ntypes = len (type_map )
224224
225+ # Build type remapping if LMDB's type_map differs from model's type_map
226+ lmdb_type_map = meta .get ("type_map" )
227+ self ._lmdb_type_map = lmdb_type_map
228+ self ._type_remap : np .ndarray | None = None
229+ if lmdb_type_map is not None and list (lmdb_type_map ) != list (type_map ):
230+ # Build remap: lmdb_type_idx -> model_type_idx
231+ remap = np .empty (len (lmdb_type_map ), dtype = np .int32 )
232+ for i , name in enumerate (lmdb_type_map ):
233+ if name not in type_map :
234+ raise ValueError (
235+ f"Element '{ name } ' in LMDB type_map { lmdb_type_map } "
236+ f"not found in model type_map { type_map } "
237+ )
238+ remap [i ] = type_map .index (name )
239+ self ._type_remap = remap
240+ log .info (
241+ f"Type remapping: LMDB { lmdb_type_map } -> model { type_map } , "
242+ f"remap={ list (remap )} "
243+ )
244+
225245 # Persistent read-only transaction for __getitem__ (avoids per-read overhead).
226246 # Safe because we use num_workers=0 in DataLoader.
227247 self ._txn = self ._env .begin ()
@@ -346,6 +366,9 @@ def __getitem__(self, index: int) -> dict[str, Any]:
346366 )
347367 if "atype" in frame and isinstance (frame ["atype" ], np .ndarray ):
348368 frame ["atype" ] = frame ["atype" ].reshape (- 1 ).astype (np .int64 )
369+ # Remap atom types from LMDB's type_map to model's type_map
370+ if self ._type_remap is not None :
371+ frame ["atype" ] = self ._type_remap [frame ["atype" ]].astype (np .int64 )
349372 if "virial" in frame and isinstance (frame ["virial" ], np .ndarray ):
350373 frame ["virial" ] = (
351374 frame ["virial" ].reshape (9 ).astype (self ._resolve_dtype ("virial" ))
0 commit comments