@@ -716,14 +716,47 @@ def __init__(
716716 self .nframes , self ._frame_fmt , self ._natoms_per_type = _parse_metadata (meta )
717717 self ._natoms = sum (self ._natoms_per_type )
718718
719+ # Build type remapping if LMDB's type_map differs from model's type_map
720+ lmdb_type_map = meta .get ("type_map" )
721+ self ._lmdb_type_map = lmdb_type_map
722+ self ._type_remap : np .ndarray | None = None
723+ if (
724+ lmdb_type_map is not None
725+ and self ._type_map
726+ and list (lmdb_type_map ) != list (self ._type_map )
727+ ):
728+ remap = np .empty (len (lmdb_type_map ), dtype = np .int32 )
729+ for i , name in enumerate (lmdb_type_map ):
730+ if name not in self ._type_map :
731+ raise ValueError (
732+ f"Element '{ name } ' in LMDB type_map { lmdb_type_map } "
733+ f"not found in model type_map { self ._type_map } "
734+ )
735+ remap [i ] = self ._type_map .index (name )
736+ self ._type_remap = remap
737+ log .info (
738+ f"LmdbTestData type remapping: LMDB { lmdb_type_map } -> "
739+ f"model { self ._type_map } , remap={ list (remap )} "
740+ )
741+
719742 # Read all frames
720743 self ._frames : list [dict [str , Any ]] = []
721744 with self ._env .begin () as txn :
722745 for i in range (self .nframes ):
723746 key = format (i , self ._frame_fmt ).encode ()
724747 raw = txn .get (key )
725748 if raw is not None :
726- self ._frames .append (_remap_keys (_decode_frame (raw )))
749+ frame = _remap_keys (_decode_frame (raw ))
750+ # Apply type remapping to atype
751+ if (
752+ self ._type_remap is not None
753+ and "atype" in frame
754+ and isinstance (frame ["atype" ], np .ndarray )
755+ ):
756+ frame ["atype" ] = self ._type_remap [
757+ frame ["atype" ].reshape (- 1 )
758+ ].astype (np .int64 )
759+ self ._frames .append (frame )
727760
728761 # Shuffle if requested
729762 if shuffle_test :
0 commit comments