Skip to content

Commit 9d63816

Browse files
authored
feat(pt/dpmodel): add max and filter mode for lmdb (#5413)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * batch_size accepts "max:N" and "filter:N" in addition to "auto"/"auto:N"; batch-size calculation honors per-frame atom counts. * print_summary explicitly reports the active batch-size rule. * **Bug Fixes** * Dataset length, indexing, and returned frame IDs consistently reflect filtering; filtering preserves original system numbering. * Empty probability blocks are removed and weights renormalized so sampling remains valid even when systems/frames are fully dropped. * "filter:N" usage is disallowed with mixed-batch mode. * **Documentation** * Updated batch_size docs and validation help to describe "max:N" and "filter:N" semantics. * **Tests** * Added tests covering max/filter behaviors, filtering effects on indexing and sampling, error cases for invalid batch_size strings, and handling of fully filtered systems. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent a3f548d commit 9d63816

4 files changed

Lines changed: 545 additions & 34 deletions

File tree

deepmd/dpmodel/utils/lmdb_data.py

Lines changed: 225 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -208,6 +208,29 @@ def _compute_batch_size(nloc: int, rule: int) -> int:
208208
return max(bsi, 1)
209209

210210

211+
def _parse_positive_rule(spec: str, prefix: str) -> int:
212+
"""Parse the ``N`` in ``<prefix>N`` and require ``N > 0``.
213+
214+
Rejects missing/non-integer/non-positive ``N`` up front so that
215+
misconfigurations (``"filter:"``, ``"filter:0"``, ``"max:-5"``) fail at
216+
construction time instead of silently producing an empty dataset or a
217+
batch_size=1 fallback downstream.
218+
"""
219+
_, _, raw = spec.partition(":")
220+
try:
221+
n = int(raw)
222+
except ValueError:
223+
raise ValueError(
224+
f"Unsupported batch_size {spec!r}. "
225+
f"Expected '{prefix}N' with N a positive integer."
226+
) from None
227+
if n <= 0:
228+
raise ValueError(
229+
f"Unsupported batch_size {spec!r}: N must be a positive integer, got {n}."
230+
)
231+
return n
232+
233+
211234
class LmdbDataReader:
212235
"""Framework-agnostic LMDB dataset reader.
213236
@@ -232,7 +255,22 @@ class LmdbDataReader:
232255
type_map : list[str]
233256
Global type map from model config.
234257
batch_size : int or str
235-
Batch size. Supports int, "auto", "auto:N".
258+
Batch size rule used to derive per-nloc batch sizes. Supports:
259+
260+
- ``int``: fixed, identical batch size for every nloc group.
261+
- ``"auto"`` / ``"auto:N"``: ``ceil(N / nloc)`` per nloc group
262+
(``N=32`` for bare ``"auto"``). Acts as a *lower* bound —
263+
each batch has at least ``N`` atoms, but may exceed ``N``
264+
by up to ``nloc - 1``.
265+
- ``"max:N"``: ``max(1, floor(N / nloc))`` per nloc group.
266+
Acts as an *upper* bound for groups with ``nloc <= N``
267+
(batch has at most ``N`` atoms). For groups with
268+
``nloc > N`` the ``max(1, ...)`` floor kicks in: ``bsi=1``
269+
and a single-frame batch still carries ``nloc`` atoms,
270+
which exceeds ``N``.
271+
- ``"filter:N"``: same per-nloc formula as ``"max:N"`` **and**
272+
drops every frame whose ``nloc > N`` from the dataset. By
273+
construction every retained batch has at most ``N`` atoms.
236274
mixed_batch : bool
237275
If True, allow different nloc in the same batch (future).
238276
If False (default), enforce same-nloc-per-batch.
@@ -283,51 +321,139 @@ def __init__(
283321

284322
# Scan per-frame nloc only when needed for same-nloc batching.
285323
# For mixed_batch=True, skip the scan entirely (future: padding handles it).
324+
# ``orig_frame_nlocs`` / ``orig_frame_system_ids`` are indexed by the
325+
# *original* LMDB frame index. After a potential ``filter:N`` drop we
326+
# rebuild ``self._frame_nlocs`` / ``self._frame_system_ids`` so they
327+
# are parallel arrays over the *dataset* index space (0..len(self));
328+
# the dataset-to-original mapping lives in ``self._retained_keys``.
286329
if not mixed_batch:
287330
# Fast path: use pre-computed frame_nlocs from metadata if available.
288331
# Falls back to scanning each frame's atom_types shape (~10 us/frame).
289332
meta_nlocs = meta.get("frame_nlocs")
290333
if meta_nlocs is not None:
291-
self._frame_nlocs = [int(n) for n in meta_nlocs]
334+
orig_frame_nlocs = [int(n) for n in meta_nlocs]
292335
else:
293-
self._frame_nlocs = _scan_frame_nlocs(
336+
orig_frame_nlocs = _scan_frame_nlocs(
294337
self._env, self.nframes, self._frame_fmt, self._natoms
295338
)
296-
self._nloc_groups: dict[int, list[int]] = {}
297-
for idx, nloc in enumerate(self._frame_nlocs):
298-
self._nloc_groups.setdefault(nloc, []).append(idx)
299339
else:
300-
self._frame_nlocs = []
301-
self._nloc_groups = {}
340+
orig_frame_nlocs = []
302341

303-
# Parse frame_system_ids for auto_prob support
342+
# Parse frame_system_ids for auto_prob support. ``_nsystems`` must stay
343+
# at ``max(original_sid) + 1`` even after filter:N so that user-facing
344+
# auto_prob block slicing (e.g. ``prob_sys_size;0:284:0.5;284:842:0.5``)
345+
# keeps its meaning across filter thresholds.
304346
meta_sys_ids = meta.get("frame_system_ids")
305347
if meta_sys_ids is not None:
306-
self._frame_system_ids: list[int] | None = [int(s) for s in meta_sys_ids]
307-
self._nsystems = max(self._frame_system_ids) + 1
308-
self._system_groups: dict[int, list[int]] = {}
309-
for idx, sid in enumerate(self._frame_system_ids):
310-
self._system_groups.setdefault(sid, []).append(idx)
311-
self._system_nframes: list[int] = [
312-
len(self._system_groups.get(i, [])) for i in range(self._nsystems)
313-
]
348+
orig_frame_system_ids: list[int] | None = [int(s) for s in meta_sys_ids]
349+
self._nsystems = max(orig_frame_system_ids) + 1
314350
else:
315-
self._frame_system_ids = None
351+
orig_frame_system_ids = None
316352
self._nsystems = 1
317-
self._system_groups = {0: list(range(self.nframes))}
318-
self._system_nframes = [self.nframes]
319353

320-
# Parse batch_size spec
354+
# Parse batch_size spec. ``auto_rule`` and ``max_rule`` are mutually
355+
# exclusive; ``filter_rule`` implies ``max_rule`` plus dropping frames
356+
# whose nloc exceeds the threshold.
321357
self._auto_rule: int | None = None
358+
self._max_rule: int | None = None
359+
self._filter_rule: int | None = None
322360
if isinstance(batch_size, str):
323361
if batch_size == "auto":
324362
self._auto_rule = 32
325363
elif batch_size.startswith("auto:"):
326-
self._auto_rule = int(batch_size.split(":")[1])
364+
self._auto_rule = _parse_positive_rule(batch_size, "auto:")
365+
elif batch_size.startswith("max:"):
366+
self._max_rule = _parse_positive_rule(batch_size, "max:")
367+
elif batch_size.startswith("filter:"):
368+
self._filter_rule = _parse_positive_rule(batch_size, "filter:")
369+
self._max_rule = self._filter_rule
327370
else:
328-
self._auto_rule = 32
329-
# Default batch_size uses first frame's nloc (for total_batch estimate)
371+
raise ValueError(
372+
f"Unsupported batch_size {batch_size!r}. "
373+
"Expected int, 'auto', 'auto:N', 'max:N', or 'filter:N'."
374+
)
375+
376+
# ``filter:N`` needs per-frame nloc to drop oversized frames; the
377+
# ``mixed_batch=True`` fast path skips the nloc scan entirely, so the
378+
# two options are incompatible. Fail fast rather than silently
379+
# retaining every frame and breaking the documented contract.
380+
if self._filter_rule is not None and mixed_batch:
381+
raise ValueError(
382+
"batch_size='filter:N' is incompatible with mixed_batch=True: "
383+
"per-frame nloc is unavailable in the mixed-batch fast path. "
384+
"Use mixed_batch=False, or switch to 'max:N' / a fixed int."
385+
)
386+
387+
# Determine which original-index frames survive the filter. Without
388+
# ``filter:N`` every frame is retained.
389+
if self._filter_rule is not None:
390+
retained_keys = [
391+
i for i, n in enumerate(orig_frame_nlocs) if n <= self._filter_rule
392+
]
393+
n_dropped = self.nframes - len(retained_keys)
394+
if n_dropped > 0:
395+
log.info(
396+
f"LMDB filter:{self._filter_rule} drops {n_dropped}/"
397+
f"{self.nframes} frames with nloc > {self._filter_rule} "
398+
f"({self.lmdb_path})."
399+
)
400+
else:
401+
retained_keys = list(range(self.nframes))
402+
403+
# Dataset-index → original LMDB frame key. ``__getitem__`` looks up
404+
# this table so that ``reader[i]`` is a valid LMDB read for every
405+
# ``0 <= i < len(reader)``, no matter how many frames were filtered.
406+
self._retained_keys: list[int] = retained_keys
407+
408+
# Re-key _frame_nlocs / _frame_system_ids into the dataset-index
409+
# space so that every downstream consumer (nloc_groups, system_groups,
410+
# SameNlocBatchSampler, _expand_indices_by_blocks) operates in a
411+
# single, self-consistent indexing scheme.
412+
if not mixed_batch:
413+
self._frame_nlocs = [orig_frame_nlocs[k] for k in retained_keys]
414+
else:
415+
self._frame_nlocs = []
416+
417+
if orig_frame_system_ids is not None:
418+
self._frame_system_ids: list[int] | None = [
419+
orig_frame_system_ids[k] for k in retained_keys
420+
]
421+
else:
422+
self._frame_system_ids = None
423+
424+
# Group retained frames by nloc using dataset indices (0..len-1).
425+
if not mixed_batch:
426+
self._nloc_groups: dict[int, list[int]] = {}
427+
for ds_idx, nloc in enumerate(self._frame_nlocs):
428+
self._nloc_groups.setdefault(nloc, []).append(ds_idx)
429+
else:
430+
self._nloc_groups = {}
431+
432+
# Group retained frames by original system id; the sid numbering is
433+
# preserved (no compression) so user-facing auto_prob slices stay
434+
# meaningful across filter thresholds. Fully-dropped systems appear
435+
# as zero-frame entries in ``_system_nframes``.
436+
if self._frame_system_ids is not None:
437+
self._system_groups: dict[int, list[int]] = {}
438+
for ds_idx, sid in enumerate(self._frame_system_ids):
439+
self._system_groups.setdefault(sid, []).append(ds_idx)
440+
self._system_nframes: list[int] = [
441+
len(self._system_groups.get(i, [])) for i in range(self._nsystems)
442+
]
443+
else:
444+
self._system_groups = {0: list(range(len(retained_keys)))}
445+
self._system_nframes = [len(retained_keys)]
446+
447+
# nframes now reflects retained frames; __len__ returns this and the
448+
# valid index domain for __getitem__ is [0, self.nframes).
449+
self.nframes = len(retained_keys)
450+
451+
# Default batch_size used only by the index/total_batch estimate. The
452+
# sampler always goes through get_batch_size_for_nloc for real batches.
453+
if self._auto_rule is not None:
330454
self.batch_size = _compute_batch_size(self._natoms, self._auto_rule)
455+
elif self._max_rule is not None:
456+
self.batch_size = max(1, self._max_rule // max(self._natoms, 1))
331457
else:
332458
self.batch_size = int(batch_size)
333459

@@ -382,20 +508,44 @@ def __del__(self) -> None:
382508
_close_lmdb(path)
383509

384510
def get_batch_size_for_nloc(self, nloc: int) -> int:
385-
"""Get batch_size for a given nloc. Uses auto rule if configured."""
511+
"""Return the per-nloc batch size for the configured rule.
512+
513+
- ``auto`` / ``auto:N``: ``ceil(N / nloc)`` — may overshoot the
514+
atom budget by up to ``nloc - 1`` atoms.
515+
- ``max:N``: ``max(1, floor(N / nloc))``. Acts as an upper bound
516+
for groups with ``nloc <= N`` (batch has at most ``N`` atoms).
517+
For groups with ``nloc > N`` the floor clamps to 1 and the
518+
single-frame batch still carries ``nloc`` atoms, exceeding ``N``.
519+
- ``filter:N``: same per-nloc formula as ``max:N``; by
520+
construction every retained group satisfies ``nloc <= N`` so
521+
no overshoot occurs.
522+
- fixed int: the same value for every nloc group.
523+
"""
386524
if self._auto_rule is not None:
387525
return _compute_batch_size(nloc, self._auto_rule)
526+
if self._max_rule is not None:
527+
return max(1, self._max_rule // max(nloc, 1))
388528
return self.batch_size
389529

390530
def __len__(self) -> int:
391531
return self.nframes
392532

393533
def __getitem__(self, index: int) -> dict[str, Any]:
394-
"""Read frame from LMDB, decode, remap keys, return dict of numpy arrays."""
395-
key = format(index, self._frame_fmt).encode()
534+
"""Read frame from LMDB, decode, remap keys, return dict of numpy arrays.
535+
536+
``index`` is a dataset-level index in ``[0, len(self))``. Under
537+
``filter:N`` the LMDB key space may have gaps (dropped frames), so
538+
we translate through ``self._retained_keys`` before hitting LMDB.
539+
"""
540+
if index < 0 or index >= self.nframes:
541+
raise IndexError(f"dataset index {index} out of range [0, {self.nframes})")
542+
original_key = self._retained_keys[index]
543+
key = format(original_key, self._frame_fmt).encode()
396544
raw = self._txn.get(key)
397545
if raw is None:
398-
raise IndexError(f"Frame {index} not found in LMDB")
546+
raise IndexError(
547+
f"Frame {original_key} not found in LMDB (dataset index {index})"
548+
)
399549
frame = _decode_frame(raw)
400550
frame = _remap_keys(frame)
401551

@@ -524,7 +674,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
524674
np.float32(1.0) if extra_key in frame else np.float32(0.0)
525675
)
526676

527-
frame["fid"] = index
677+
frame["fid"] = original_key
528678

529679
return frame
530680

@@ -543,11 +693,19 @@ def data_requirements(self) -> list[DataRequirementItem]:
543693
def print_summary(self, name: str, prob: Any) -> None:
544694
"""Print basic dataset info."""
545695
n_groups = len(self._nloc_groups)
696+
if self._auto_rule is not None:
697+
bs_str = f"auto:{self._auto_rule}"
698+
elif self._filter_rule is not None:
699+
bs_str = f"filter:{self._filter_rule}"
700+
elif self._max_rule is not None:
701+
bs_str = f"max:{self._max_rule}"
702+
else:
703+
bs_str = str(self.batch_size)
546704

547705
log.info(
548706
f"LMDB {name}: {self.lmdb_path}, "
549707
f"{self.nframes} frames, {n_groups} nloc groups, "
550-
f"batch_size={'auto' if self._auto_rule else self.batch_size}, "
708+
f"batch_size={bs_str}, "
551709
f"mixed_batch={self.mixed_batch}"
552710
)
553711
# Print nloc groups in rows of ~10 for readability
@@ -691,6 +849,43 @@ def compute_block_targets(
691849
stt, end, weight = part.split(":")
692850
blocks.append((int(stt), int(end), float(weight)))
693851

852+
# Drop blocks that retain zero frames (can happen when ``filter:N``
853+
# eliminates every system in a block). prob_sys_size_ext's per-block
854+
# ``nbatch_block / sum(nbatch_block)`` would otherwise propagate NaN
855+
# when the whole block sums to zero. An all-zero dataset yields no
856+
# targets at all.
857+
nonempty = [
858+
(stt, end, weight)
859+
for stt, end, weight in blocks
860+
if sum(system_nframes[stt:end]) > 0
861+
]
862+
if not nonempty:
863+
log.info(
864+
"compute_block_targets: all blocks are empty in "
865+
f"{auto_prob_style!r}; dataset has no retained frames."
866+
)
867+
return []
868+
if len(nonempty) < len(blocks):
869+
# Rewriting auto_prob_style silently re-normalises the remaining
870+
# weights so they sum to 1.0 — e.g. ``0:3:0.8;3:10:0.2`` with block
871+
# ``0:3`` empty becomes effectively weight 1.0 on block ``3:10``.
872+
# Surface this reweighting so operators can correlate it with the
873+
# preceding ``filter:N`` log line.
874+
dropped = [
875+
f"{stt}:{end}:{weight}"
876+
for (stt, end, weight) in blocks
877+
if (stt, end, weight) not in nonempty
878+
]
879+
log.info(
880+
"compute_block_targets: dropping empty blocks (all systems have "
881+
f"0 frames, likely after filter:N): {dropped}. Remaining block "
882+
"weights will be renormalised to sum to 1.0."
883+
)
884+
auto_prob_style = "prob_sys_size;" + ";".join(
885+
f"{stt}:{end}:{weight}" for stt, end, weight in nonempty
886+
)
887+
blocks = nonempty
888+
694889
# Compute per-system probabilities using the standard function
695890
sys_probs = prob_sys_size_ext(auto_prob_style, nsystems, system_nframes)
696891

deepmd/pt/utils/lmdb_dataset.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -109,7 +109,14 @@ class LmdbDataset(Dataset):
109109
type_map : list[str]
110110
Global type map from model config.
111111
batch_size : int or str
112-
Batch size. Supports int, "auto", "auto:N".
112+
Batch size rule forwarded to :class:`LmdbDataReader`. Supports:
113+
114+
- ``int``: fixed batch size for every nloc group.
115+
- ``"auto"`` / ``"auto:N"``: ``ceil(N / nloc)`` per nloc group
116+
(``N=32`` for bare ``"auto"``).
117+
- ``"max:N"``: ``max(1, floor(N / nloc))`` per nloc group.
118+
- ``"filter:N"``: same per-nloc formula as ``"max:N"`` and drops
119+
every frame whose ``nloc > N`` from the dataset.
113120
mixed_batch : bool
114121
If True, allow different nloc in the same batch (future).
115122
If False (default), use SameNlocBatchSampler.

deepmd/utils/argcheck.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3718,8 +3718,8 @@ def training_data_args() -> list[
37183718
- string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
37193719
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\
37203720
- string "mixed:N": the batch data will be sampled from all systems and merged into a mixed system with the batch size N. Only support the se_atten descriptor for TensorFlow backend.\n\n\
3721-
- string "max:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no more than N.\n\n\
3722-
- string "filter:N": the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set.\n\n\
3721+
- string "max:N": automatically determines the batch size so that `batch_size * natoms` is at most `N`. `natoms` is the per-system atom count for npy data and the per-frame nloc for LMDB data. When a single system/frame already has more than `N` atoms, the batch size clamps to 1 and that batch will exceed `N`.\n\n\
3722+
- string "filter:N": the same as `"max:N"` but additionally drops data whose atom count exceeds `N`. For npy data this removes whole systems with natoms > `N`; for LMDB data this removes individual frames with nloc > `N`.\n\n\
37233723
If MPI is used, the value should be considered as the batch size per task.'
37243724
doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
37253725
- "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\
@@ -3798,7 +3798,9 @@ def validation_data_args() -> list[
37983798
- list: the length of which is the same as the {link_sys}. The batch size of each system is given by the elements of the list.\n\n\
37993799
- int: all {link_sys} use the same batch size.\n\n\
38003800
- string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
3801-
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.'
3801+
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\
3802+
- string "max:N": automatically determines the batch size so that `batch_size * natoms` is at most `N`. `natoms` is the per-system atom count for npy data and the per-frame nloc for LMDB data. When a single system/frame already has more than `N` atoms, the batch size clamps to 1 and that batch will exceed `N`.\n\n\
3803+
- string "filter:N": the same as `"max:N"` but additionally drops data whose atom count exceeds `N`. For npy data this removes whole systems with natoms > `N`; for LMDB data this removes individual frames with nloc > `N`.'
38023804
doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
38033805
- "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\
38043806
- "prob_sys_size" : the probability of a system is proportional to the number of batches in the system\n\n\

0 commit comments

Comments (0)