Skip to content

Commit fc02035

Browse files
committed
fixup lmdb
1 parent c3fc8e8 commit fc02035

3 files changed

Lines changed: 115 additions & 43 deletions

File tree

deepmd/dpmodel/utils/lmdb_data.py

Lines changed: 66 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -298,33 +298,34 @@ def __init__(
298298

299299
# Scan per-frame nloc only when needed for same-nloc batching.
300300
# For mixed_batch=True, skip the scan entirely (future: padding handles it).
301-
# We keep _frame_nlocs / _frame_system_ids indexable by the *original*
302-
# LMDB frame index even after filter:N: entries for dropped frames
303-
# simply never get referenced because _nloc_groups / _system_groups
304-
# no longer reference them.
301+
# ``orig_frame_nlocs`` / ``orig_frame_system_ids`` are indexed by the
302+
# *original* LMDB frame index. After a potential ``filter:N`` drop we
303+
# rebuild ``self._frame_nlocs`` / ``self._frame_system_ids`` so they
304+
# are parallel arrays over the *dataset* index space ``[0, len(self))``;
305+
# the dataset-to-original mapping lives in ``self._retained_keys``.
305306
if not mixed_batch:
306307
# Fast path: use pre-computed frame_nlocs from metadata if available.
307308
# Falls back to scanning each frame's atom_types shape (~10 us/frame).
308309
meta_nlocs = meta.get("frame_nlocs")
309310
if meta_nlocs is not None:
310-
self._frame_nlocs = [int(n) for n in meta_nlocs]
311+
orig_frame_nlocs = [int(n) for n in meta_nlocs]
311312
else:
312-
self._frame_nlocs = _scan_frame_nlocs(
313+
orig_frame_nlocs = _scan_frame_nlocs(
313314
self._env, self.nframes, self._frame_fmt, self._natoms
314315
)
315316
else:
316-
self._frame_nlocs = []
317+
orig_frame_nlocs = []
317318

318-
# Parse frame_system_ids for auto_prob support. _nsystems must stay at
319-
# ``max(original_sid) + 1`` even after filter:N so that user-facing
319+
# Parse frame_system_ids for auto_prob support. ``_nsystems`` must stay
320+
# at ``max(original_sid) + 1`` even after filter:N so that user-facing
320321
# auto_prob block slicing (e.g. ``prob_sys_size;0:284:0.5;284:842:0.5``)
321322
# keeps its meaning across filter thresholds.
322323
meta_sys_ids = meta.get("frame_system_ids")
323324
if meta_sys_ids is not None:
324-
self._frame_system_ids: list[int] | None = [int(s) for s in meta_sys_ids]
325-
self._nsystems = max(self._frame_system_ids) + 1
325+
orig_frame_system_ids: list[int] | None = [int(s) for s in meta_sys_ids]
326+
self._nsystems = max(orig_frame_system_ids) + 1
326327
else:
327-
self._frame_system_ids = None
328+
orig_frame_system_ids = None
328329
self._nsystems = 1
329330

330331
# Parse batch_size spec. ``auto_rule`` and ``max_rule`` are mutually
@@ -353,47 +354,66 @@ def __init__(
353354
# ``filter:N`` every frame is retained. ``mixed_batch=True`` has no
354355
# per-frame nloc info to filter against, so filter:N is a no-op there.
355356
if self._filter_rule is not None and not mixed_batch:
356-
retained_indices = [
357-
i for i, n in enumerate(self._frame_nlocs) if n <= self._filter_rule
357+
retained_keys = [
358+
i for i, n in enumerate(orig_frame_nlocs) if n <= self._filter_rule
358359
]
359-
n_dropped = self.nframes - len(retained_indices)
360+
n_dropped = self.nframes - len(retained_keys)
360361
if n_dropped > 0:
361362
log.info(
362363
f"LMDB filter:{self._filter_rule} drops {n_dropped}/"
363364
f"{self.nframes} frames with nloc > {self._filter_rule} "
364365
f"({self.lmdb_path})."
365366
)
366367
else:
367-
retained_indices = list(range(self.nframes))
368+
retained_keys = list(range(self.nframes))
369+
370+
# Dataset-index → original LMDB frame key. ``__getitem__`` looks up
371+
# this table so that ``reader[i]`` is a valid LMDB read for every
372+
# ``0 <= i < len(reader)``, no matter how many frames were filtered.
373+
self._retained_keys: list[int] = retained_keys
374+
375+
# Re-key _frame_nlocs / _frame_system_ids into the dataset-index
376+
# space so that every downstream consumer (nloc_groups, system_groups,
377+
# SameNlocBatchSampler, _expand_indices_by_blocks) operates in a
378+
# single, self-consistent indexing scheme.
379+
if not mixed_batch:
380+
self._frame_nlocs = [orig_frame_nlocs[k] for k in retained_keys]
381+
else:
382+
self._frame_nlocs = []
383+
384+
if orig_frame_system_ids is not None:
385+
self._frame_system_ids: list[int] | None = [
386+
orig_frame_system_ids[k] for k in retained_keys
387+
]
388+
else:
389+
self._frame_system_ids = None
368390

369-
# Group retained frames by nloc. _nloc_groups only contains nlocs
370-
# that passed the filter; its values stay as *original* LMDB frame
371-
# indices so __getitem__(index) keeps reading the right LMDB key.
391+
# Group retained frames by nloc using dataset indices (0..len-1).
372392
if not mixed_batch:
373393
self._nloc_groups: dict[int, list[int]] = {}
374-
for idx in retained_indices:
375-
self._nloc_groups.setdefault(self._frame_nlocs[idx], []).append(idx)
394+
for ds_idx, nloc in enumerate(self._frame_nlocs):
395+
self._nloc_groups.setdefault(nloc, []).append(ds_idx)
376396
else:
377397
self._nloc_groups = {}
378398

379-
# Group retained frames by system id. _system_nframes is indexed by
380-
# *original* sid and stays length _nsystems even if some systems are
381-
# fully dropped — those entries are simply zero so auto_prob block
382-
# slicing still parses predictably.
399+
# Group retained frames by original system id; the sid numbering is
400+
# preserved (no compression) so user-facing auto_prob slices stay
401+
# meaningful across filter thresholds. Fully-dropped systems appear
402+
# as zero-frame entries in ``_system_nframes``.
383403
if self._frame_system_ids is not None:
384404
self._system_groups: dict[int, list[int]] = {}
385-
for idx in retained_indices:
386-
sid = self._frame_system_ids[idx]
387-
self._system_groups.setdefault(sid, []).append(idx)
405+
for ds_idx, sid in enumerate(self._frame_system_ids):
406+
self._system_groups.setdefault(sid, []).append(ds_idx)
388407
self._system_nframes: list[int] = [
389408
len(self._system_groups.get(i, [])) for i in range(self._nsystems)
390409
]
391410
else:
392-
self._system_groups = {0: list(retained_indices)}
393-
self._system_nframes = [len(retained_indices)]
411+
self._system_groups = {0: list(range(len(retained_keys)))}
412+
self._system_nframes = [len(retained_keys)]
394413

395-
# nframes now reflects retained frames; __len__ returns this.
396-
self.nframes = len(retained_indices)
414+
# nframes now reflects retained frames; __len__ returns this and the
415+
# valid index domain for __getitem__ is [0, self.nframes).
416+
self.nframes = len(retained_keys)
397417

398418
# Default batch_size used only by the index/total_batch estimate. The
399419
# sampler always goes through get_batch_size_for_nloc for real batches.
@@ -474,11 +494,21 @@ def __len__(self) -> int:
474494
return self.nframes
475495

476496
def __getitem__(self, index: int) -> dict[str, Any]:
477-
"""Read frame from LMDB, decode, remap keys, return dict of numpy arrays."""
478-
key = format(index, self._frame_fmt).encode()
497+
"""Read frame from LMDB, decode, remap keys, return dict of numpy arrays.
498+
499+
``index`` is a dataset-level index in ``[0, len(self))``. Under
500+
``filter:N`` the LMDB key space may have gaps (dropped frames), so
501+
we translate through ``self._retained_keys`` before hitting LMDB.
502+
"""
503+
if index < 0 or index >= self.nframes:
504+
raise IndexError(f"dataset index {index} out of range [0, {self.nframes})")
505+
original_key = self._retained_keys[index]
506+
key = format(original_key, self._frame_fmt).encode()
479507
raw = self._txn.get(key)
480508
if raw is None:
481-
raise IndexError(f"Frame {index} not found in LMDB")
509+
raise IndexError(
510+
f"Frame {original_key} not found in LMDB (dataset index {index})"
511+
)
482512
frame = _decode_frame(raw)
483513
frame = _remap_keys(frame)
484514

@@ -607,7 +637,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
607637
np.float32(1.0) if extra_key in frame else np.float32(0.0)
608638
)
609639

610-
frame["fid"] = index
640+
frame["fid"] = original_key
611641

612642
return frame
613643

deepmd/utils/argcheck.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3677,8 +3677,8 @@ def training_data_args() -> list[
36773677
- string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
36783678
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\
36793679
- string "mixed:N": the batch data will be sampled from all systems and merged into a mixed system with the batch size N. Only support the se_atten descriptor for TensorFlow backend.\n\n\
3680-
- string "max:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no more than N.\n\n\
3681-
- string "filter:N": the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set.\n\n\
3680+
- string "max:N": automatically determines the batch size so that `batch_size * natoms` is at most `N`. `natoms` is the per-system atom count for npy data and the per-frame nloc for LMDB data. When a single system/frame already has more than `N` atoms, the batch size clamps to 1 and that batch will exceed `N`.\n\n\
3681+
- string "filter:N": the same as `"max:N"` but additionally drops data whose atom count exceeds `N`. For npy data this removes whole systems with natoms > `N`; for LMDB data this removes individual frames with nloc > `N`.\n\n\
36823682
If MPI is used, the value should be considered as the batch size per task.'
36833683
doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
36843684
- "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\
@@ -3758,8 +3758,8 @@ def validation_data_args() -> list[
37583758
- int: all {link_sys} use the same batch size.\n\n\
37593759
- string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
37603760
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\
3761-
- string "max:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no more than N.\n\n\
3762-
- string "filter:N": the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set.'
3761+
- string "max:N": automatically determines the batch size so that `batch_size * natoms` is at most `N`. `natoms` is the per-system atom count for npy data and the per-frame nloc for LMDB data. When a single system/frame already has more than `N` atoms, the batch size clamps to 1 and that batch will exceed `N`.\n\n\
3762+
- string "filter:N": the same as `"max:N"` but additionally drops data whose atom count exceeds `N`. For npy data this removes whole systems with natoms > `N`; for LMDB data this removes individual frames with nloc > `N`.'
37633763
doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
37643764
- "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\
37653765
- "prob_sys_size" : the probability of a system is proportional to the number of batches in the system\n\n\

source/tests/common/dpmodel/test_lmdb_data.py

Lines changed: 45 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -821,6 +821,46 @@ def test_filter_preserves_system_id_numbering(self):
821821
# after re-normalisation → no expansion needed.
822822
self.assertEqual(block_targets, [])
823823

824+
def test_filter_dataset_index_is_contiguous_and_live(self):
825+
"""After filter:N, every i in range(len(reader)) is a live retrievable frame.
826+
827+
Regression for the earlier indexing bug where ``len(reader)`` shrank
828+
to the retained count but ``__getitem__`` still indexed the original
829+
LMDB key space. Under filter:10 the mixed-nloc LMDB drops the two
830+
12-atom frames at original keys 8 & 9; we check here that:
831+
832+
* every dataset index ``0..len(reader)-1`` decodes without raising
833+
and never returns a filtered-out frame, and
834+
* ``fid`` reports the stable original LMDB key, not the dataset
835+
index (so downstream logs survive the remap), and
836+
* out-of-range indices still raise IndexError.
837+
"""
838+
reader = LmdbDataReader(
839+
self._mixed_path, self._type_map, batch_size="filter:10"
840+
)
841+
self.assertEqual(len(reader), 8)
842+
self.assertEqual(len(reader._retained_keys), 8)
843+
self.assertEqual(reader._retained_keys, [0, 1, 2, 3, 4, 5, 6, 7])
844+
845+
seen_fids = []
846+
for i in range(len(reader)):
847+
frame = reader[i]
848+
self.assertLessEqual(frame["atype"].shape[0], 10)
849+
self.assertEqual(
850+
frame["fid"],
851+
reader._retained_keys[i],
852+
msg=f"fid should be the original LMDB key, not dataset index {i}",
853+
)
854+
seen_fids.append(frame["fid"])
855+
# Dropped original keys (8, 9) must never appear as fids.
856+
self.assertNotIn(8, seen_fids)
857+
self.assertNotIn(9, seen_fids)
858+
859+
with self.assertRaises(IndexError):
860+
reader[len(reader)]
861+
with self.assertRaises(IndexError):
862+
reader[-1]
863+
824864
def test_sampler_with_filter(self):
825865
"""SameNlocBatchSampler only emits retained, same-nloc frames."""
826866
reader = LmdbDataReader(
@@ -841,9 +881,11 @@ def test_sampler_with_filter(self):
841881
for batch in all_batches:
842882
nlocs = {reader.frame_nlocs[idx] for idx in batch}
843883
self.assertEqual(len(nlocs), 1)
844-
# The 12-atom frames (indices 8, 9) are never reached.
845-
for idx in (8, 9):
846-
self.assertNotIn(idx, all_indices)
884+
# The 12-atom frames were at original LMDB keys 8, 9; they must
885+
# never be reachable via any emitted dataset index.
886+
reached_original_keys = {reader._retained_keys[idx] for idx in all_indices}
887+
for original_key in (8, 9):
888+
self.assertNotIn(original_key, reached_original_keys)
847889

848890
def test_auto_prob_with_filter_still_works(self):
849891
"""compute_block_targets + sampler survive a fully-dropped block."""

0 commit comments

Comments
 (0)