Skip to content

Commit 86df0ca

Browse files
committed
fixup lmdb
1 parent c3fc8e8 commit 86df0ca

3 files changed

Lines changed: 119 additions & 44 deletions

File tree

deepmd/dpmodel/utils/lmdb_data.py

Lines changed: 70 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,10 @@ def _open_lmdb(path: str) -> lmdb.Environment:
6666
env, refcount = entry
6767
_ENV_CACHE[resolved] = (env, refcount + 1)
6868
return env
69-
env = lmdb.open(path, readonly=True, lock=False, readahead=False, meminit=False)
69+
# ``readahead=True`` lets the kernel batch-prefetch B+tree pages; this is
70+
# cheap on local SSDs and a major win on networked filesystems (vepfs /
71+
# NFS / Lustre) where each uncoalesced 4 KB page fault costs a full RPC.
72+
env = lmdb.open(path, readonly=True, lock=False, readahead=True, meminit=False)
7073
_ENV_CACHE[resolved] = (env, 1)
7174
return env
7275

@@ -298,33 +301,34 @@ def __init__(
298301

299302
# Scan per-frame nloc only when needed for same-nloc batching.
300303
# For mixed_batch=True, skip the scan entirely (future: padding handles it).
301-
# We keep _frame_nlocs / _frame_system_ids indexable by the *original*
302-
# LMDB frame index even after filter:N: entries for dropped frames
303-
# simply never get referenced because _nloc_groups / _system_groups
304-
# no longer reference them.
304+
# ``orig_frame_nlocs`` / ``orig_frame_system_ids`` are indexed by the
305+
# *original* LMDB frame index. After a potential ``filter:N`` drop we
306+
# rebuild ``self._frame_nlocs`` / ``self._frame_system_ids`` so they
307+
# are parallel arrays over the *dataset* index space (0..len(self)-1);
308+
# the dataset-to-original mapping lives in ``self._retained_keys``.
305309
if not mixed_batch:
306310
# Fast path: use pre-computed frame_nlocs from metadata if available.
307311
# Falls back to scanning each frame's atom_types shape (~10 us/frame).
308312
meta_nlocs = meta.get("frame_nlocs")
309313
if meta_nlocs is not None:
310-
self._frame_nlocs = [int(n) for n in meta_nlocs]
314+
orig_frame_nlocs = [int(n) for n in meta_nlocs]
311315
else:
312-
self._frame_nlocs = _scan_frame_nlocs(
316+
orig_frame_nlocs = _scan_frame_nlocs(
313317
self._env, self.nframes, self._frame_fmt, self._natoms
314318
)
315319
else:
316-
self._frame_nlocs = []
320+
orig_frame_nlocs = []
317321

318-
# Parse frame_system_ids for auto_prob support. _nsystems must stay at
319-
# ``max(original_sid) + 1`` even after filter:N so that user-facing
322+
# Parse frame_system_ids for auto_prob support. ``_nsystems`` must stay
323+
# at ``max(original_sid) + 1`` even after filter:N so that user-facing
320324
# auto_prob block slicing (e.g. ``prob_sys_size;0:284:0.5;284:842:0.5``)
321325
# keeps its meaning across filter thresholds.
322326
meta_sys_ids = meta.get("frame_system_ids")
323327
if meta_sys_ids is not None:
324-
self._frame_system_ids: list[int] | None = [int(s) for s in meta_sys_ids]
325-
self._nsystems = max(self._frame_system_ids) + 1
328+
orig_frame_system_ids: list[int] | None = [int(s) for s in meta_sys_ids]
329+
self._nsystems = max(orig_frame_system_ids) + 1
326330
else:
327-
self._frame_system_ids = None
331+
orig_frame_system_ids = None
328332
self._nsystems = 1
329333

330334
# Parse batch_size spec. ``auto_rule`` and ``max_rule`` are mutually
@@ -353,47 +357,66 @@ def __init__(
353357
# ``filter:N`` every frame is retained. ``mixed_batch=True`` has no
354358
# per-frame nloc info to filter against, so filter:N is a no-op there.
355359
if self._filter_rule is not None and not mixed_batch:
356-
retained_indices = [
357-
i for i, n in enumerate(self._frame_nlocs) if n <= self._filter_rule
360+
retained_keys = [
361+
i for i, n in enumerate(orig_frame_nlocs) if n <= self._filter_rule
358362
]
359-
n_dropped = self.nframes - len(retained_indices)
363+
n_dropped = self.nframes - len(retained_keys)
360364
if n_dropped > 0:
361365
log.info(
362366
f"LMDB filter:{self._filter_rule} drops {n_dropped}/"
363367
f"{self.nframes} frames with nloc > {self._filter_rule} "
364368
f"({self.lmdb_path})."
365369
)
366370
else:
367-
retained_indices = list(range(self.nframes))
371+
retained_keys = list(range(self.nframes))
372+
373+
# Dataset-index → original LMDB frame key. ``__getitem__`` looks up
374+
# this table so that ``reader[i]`` is a valid LMDB read for every
375+
# ``0 <= i < len(reader)``, no matter how many frames were filtered.
376+
self._retained_keys: list[int] = retained_keys
377+
378+
# Re-key _frame_nlocs / _frame_system_ids into the dataset-index
379+
# space so that every downstream consumer (nloc_groups, system_groups,
380+
# SameNlocBatchSampler, _expand_indices_by_blocks) operates in a
381+
# single, self-consistent indexing scheme.
382+
if not mixed_batch:
383+
self._frame_nlocs = [orig_frame_nlocs[k] for k in retained_keys]
384+
else:
385+
self._frame_nlocs = []
386+
387+
if orig_frame_system_ids is not None:
388+
self._frame_system_ids: list[int] | None = [
389+
orig_frame_system_ids[k] for k in retained_keys
390+
]
391+
else:
392+
self._frame_system_ids = None
368393

369-
# Group retained frames by nloc. _nloc_groups only contains nlocs
370-
# that passed the filter; its values stay as *original* LMDB frame
371-
# indices so __getitem__(index) keeps reading the right LMDB key.
394+
# Group retained frames by nloc using dataset indices (0..len-1).
372395
if not mixed_batch:
373396
self._nloc_groups: dict[int, list[int]] = {}
374-
for idx in retained_indices:
375-
self._nloc_groups.setdefault(self._frame_nlocs[idx], []).append(idx)
397+
for ds_idx, nloc in enumerate(self._frame_nlocs):
398+
self._nloc_groups.setdefault(nloc, []).append(ds_idx)
376399
else:
377400
self._nloc_groups = {}
378401

379-
# Group retained frames by system id. _system_nframes is indexed by
380-
# *original* sid and stays length _nsystems even if some systems are
381-
# fully dropped — those entries are simply zero so auto_prob block
382-
# slicing still parses predictably.
402+
# Group retained frames by original system id; the sid numbering is
403+
# preserved (no compression) so user-facing auto_prob slices stay
404+
# meaningful across filter thresholds. Fully-dropped systems appear
405+
# as zero-frame entries in ``_system_nframes``.
383406
if self._frame_system_ids is not None:
384407
self._system_groups: dict[int, list[int]] = {}
385-
for idx in retained_indices:
386-
sid = self._frame_system_ids[idx]
387-
self._system_groups.setdefault(sid, []).append(idx)
408+
for ds_idx, sid in enumerate(self._frame_system_ids):
409+
self._system_groups.setdefault(sid, []).append(ds_idx)
388410
self._system_nframes: list[int] = [
389411
len(self._system_groups.get(i, [])) for i in range(self._nsystems)
390412
]
391413
else:
392-
self._system_groups = {0: list(retained_indices)}
393-
self._system_nframes = [len(retained_indices)]
414+
self._system_groups = {0: list(range(len(retained_keys)))}
415+
self._system_nframes = [len(retained_keys)]
394416

395-
# nframes now reflects retained frames; __len__ returns this.
396-
self.nframes = len(retained_indices)
417+
# nframes now reflects retained frames; __len__ returns this and the
418+
# valid index domain for __getitem__ is [0, self.nframes).
419+
self.nframes = len(retained_keys)
397420

398421
# Default batch_size used only by the index/total_batch estimate. The
399422
# sampler always goes through get_batch_size_for_nloc for real batches.
@@ -474,11 +497,21 @@ def __len__(self) -> int:
474497
return self.nframes
475498

476499
def __getitem__(self, index: int) -> dict[str, Any]:
477-
"""Read frame from LMDB, decode, remap keys, return dict of numpy arrays."""
478-
key = format(index, self._frame_fmt).encode()
500+
"""Read frame from LMDB, decode, remap keys, return dict of numpy arrays.
501+
502+
``index`` is a dataset-level index in ``[0, len(self))``. Under
503+
``filter:N`` the LMDB key space may have gaps (dropped frames), so
504+
we translate through ``self._retained_keys`` before hitting LMDB.
505+
"""
506+
if index < 0 or index >= self.nframes:
507+
raise IndexError(f"dataset index {index} out of range [0, {self.nframes})")
508+
original_key = self._retained_keys[index]
509+
key = format(original_key, self._frame_fmt).encode()
479510
raw = self._txn.get(key)
480511
if raw is None:
481-
raise IndexError(f"Frame {index} not found in LMDB")
512+
raise IndexError(
513+
f"Frame {original_key} not found in LMDB (dataset index {index})"
514+
)
482515
frame = _decode_frame(raw)
483516
frame = _remap_keys(frame)
484517

@@ -607,7 +640,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
607640
np.float32(1.0) if extra_key in frame else np.float32(0.0)
608641
)
609642

610-
frame["fid"] = index
643+
frame["fid"] = original_key
611644

612645
return frame
613646

deepmd/utils/argcheck.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3677,8 +3677,8 @@ def training_data_args() -> list[
36773677
- string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
36783678
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\
36793679
- string "mixed:N": the batch data will be sampled from all systems and merged into a mixed system with the batch size N. Only support the se_atten descriptor for TensorFlow backend.\n\n\
3680-
- string "max:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no more than N.\n\n\
3681-
- string "filter:N": the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set.\n\n\
3680+
- string "max:N": automatically determines the batch size so that `batch_size * natoms` is at most `N`. `natoms` is the per-system atom count for npy data and the per-frame nloc for LMDB data. When a single system/frame already has more than `N` atoms, the batch size clamps to 1 and that batch will exceed `N`.\n\n\
3681+
- string "filter:N": the same as `"max:N"` but additionally drops data whose atom count exceeds `N`. For npy data this removes whole systems with natoms > `N`; for LMDB data this removes individual frames with nloc > `N`.\n\n\
36823682
If MPI is used, the value should be considered as the batch size per task.'
36833683
doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
36843684
- "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\
@@ -3758,8 +3758,8 @@ def validation_data_args() -> list[
37583758
- int: all {link_sys} use the same batch size.\n\n\
37593759
- string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
37603760
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\
3761-
- string "max:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no more than N.\n\n\
3762-
- string "filter:N": the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set.'
3761+
- string "max:N": automatically determines the batch size so that `batch_size * natoms` is at most `N`. `natoms` is the per-system atom count for npy data and the per-frame nloc for LMDB data. When a single system/frame already has more than `N` atoms, the batch size clamps to 1 and that batch will exceed `N`.\n\n\
3762+
- string "filter:N": the same as `"max:N"` but additionally drops data whose atom count exceeds `N`. For npy data this removes whole systems with natoms > `N`; for LMDB data this removes individual frames with nloc > `N`.'
37633763
doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
37643764
- "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\
37653765
- "prob_sys_size" : the probability of a system is proportional to the number of batches in the system\n\n\

source/tests/common/dpmodel/test_lmdb_data.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,46 @@ def test_filter_preserves_system_id_numbering(self):
821821
# after re-normalisation → no expansion needed.
822822
self.assertEqual(block_targets, [])
823823

824+
def test_filter_dataset_index_is_contiguous_and_live(self):
825+
"""After filter:N, every i in range(len(reader)) is a live retrievable frame.
826+
827+
Regression test for the earlier indexing bug where ``len(reader)`` shrank
828+
to the retained count but ``__getitem__`` still indexed the original
829+
LMDB key space. Under filter:10 the mixed-nloc LMDB drops the two
830+
12-atom frames at original keys 8 & 9; we check here that:
831+
832+
* every dataset index ``0..len(reader)-1`` decodes without raising
833+
and never returns a filtered-out frame, and
834+
* ``fid`` reports the stable original LMDB key, not the dataset
835+
index (so downstream logs survive the remap), and
836+
* out-of-range indices still raise IndexError.
837+
"""
838+
reader = LmdbDataReader(
839+
self._mixed_path, self._type_map, batch_size="filter:10"
840+
)
841+
self.assertEqual(len(reader), 8)
842+
self.assertEqual(len(reader._retained_keys), 8)
843+
self.assertEqual(reader._retained_keys, [0, 1, 2, 3, 4, 5, 6, 7])
844+
845+
seen_fids = []
846+
for i in range(len(reader)):
847+
frame = reader[i]
848+
self.assertLessEqual(frame["atype"].shape[0], 10)
849+
self.assertEqual(
850+
frame["fid"],
851+
reader._retained_keys[i],
852+
msg=f"fid should be the original LMDB key, not dataset index {i}",
853+
)
854+
seen_fids.append(frame["fid"])
855+
# Dropped original keys (8, 9) must never appear as fids.
856+
self.assertNotIn(8, seen_fids)
857+
self.assertNotIn(9, seen_fids)
858+
859+
with self.assertRaises(IndexError):
860+
reader[len(reader)]
861+
with self.assertRaises(IndexError):
862+
reader[-1]
863+
824864
def test_sampler_with_filter(self):
825865
"""SameNlocBatchSampler only emits retained, same-nloc frames."""
826866
reader = LmdbDataReader(
@@ -841,9 +881,11 @@ def test_sampler_with_filter(self):
841881
for batch in all_batches:
842882
nlocs = {reader.frame_nlocs[idx] for idx in batch}
843883
self.assertEqual(len(nlocs), 1)
844-
# The 12-atom frames (indices 8, 9) are never reached.
845-
for idx in (8, 9):
846-
self.assertNotIn(idx, all_indices)
884+
# The 12-atom frames were at original LMDB keys 8, 9; they must
885+
# never be reachable via any emitted dataset index.
886+
reached_original_keys = {reader._retained_keys[idx] for idx in all_indices}
887+
for original_key in (8, 9):
888+
self.assertNotIn(original_key, reached_original_keys)
847889

848890
def test_auto_prob_with_filter_still_works(self):
849891
"""compute_block_targets + sampler survive a fully-dropped block."""

0 commit comments

Comments
 (0)