Skip to content

Commit c3fc8e8

Browse files
committed
feat(pt/dpmodel): add max and filter mode for lmdb
1 parent 5d9cbdf commit c3fc8e8

4 files changed

Lines changed: 360 additions & 21 deletions

File tree

deepmd/dpmodel/utils/lmdb_data.py

Lines changed: 128 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,22 @@ class LmdbDataReader:
232232
type_map : list[str]
233233
Global type map from model config.
234234
batch_size : int or str
235-
Batch size. Supports int, "auto", "auto:N".
235+
Batch size rule used to derive per-nloc batch sizes. Supports:
236+
237+
- ``int``: fixed, identical batch size for every nloc group.
238+
- ``"auto"`` / ``"auto:N"``: ``ceil(N / nloc)`` per nloc group
239+
(``N=32`` for bare ``"auto"``). Acts as a *lower* bound —
240+
each batch has at least ``N`` atoms, but may exceed ``N``
241+
by up to ``nloc - 1``.
242+
- ``"max:N"``: ``max(1, floor(N / nloc))`` per nloc group.
243+
Acts as an *upper* bound for groups with ``nloc <= N``
244+
(batch has at most ``N`` atoms). For groups with
245+
``nloc > N`` the ``max(1, ...)`` floor kicks in: ``bsi=1``
246+
and a single-frame batch still carries ``nloc`` atoms,
247+
which exceeds ``N``.
248+
- ``"filter:N"``: same per-nloc formula as ``"max:N"`` **and**
249+
drops every frame whose ``nloc > N`` from the dataset. By
250+
construction every retained batch has at most ``N`` atoms.
236251
mixed_batch : bool
237252
If True, allow different nloc in the same batch (future).
238253
If False (default), enforce same-nloc-per-batch.
@@ -283,6 +298,10 @@ def __init__(
283298

284299
# Scan per-frame nloc only when needed for same-nloc batching.
285300
# For mixed_batch=True, skip the scan entirely (future: padding handles it).
301+
# We keep _frame_nlocs / _frame_system_ids indexable by the *original*
302+
# LMDB frame index even after filter:N: entries for dropped frames
303+
# simply never get referenced because _nloc_groups / _system_groups
304+
# no longer reference them.
286305
if not mixed_batch:
287306
# Fast path: use pre-computed frame_nlocs from metadata if available.
288307
# Falls back to scanning each frame's atom_types shape (~10 us/frame).
@@ -293,41 +312,95 @@ def __init__(
293312
self._frame_nlocs = _scan_frame_nlocs(
294313
self._env, self.nframes, self._frame_fmt, self._natoms
295314
)
296-
self._nloc_groups: dict[int, list[int]] = {}
297-
for idx, nloc in enumerate(self._frame_nlocs):
298-
self._nloc_groups.setdefault(nloc, []).append(idx)
299315
else:
300316
self._frame_nlocs = []
301-
self._nloc_groups = {}
302317

303-
# Parse frame_system_ids for auto_prob support
318+
# Parse frame_system_ids for auto_prob support. _nsystems must stay at
319+
# ``max(original_sid) + 1`` even after filter:N so that user-facing
320+
# auto_prob block slicing (e.g. ``prob_sys_size;0:284:0.5;284:842:0.5``)
321+
# keeps its meaning across filter thresholds.
304322
meta_sys_ids = meta.get("frame_system_ids")
305323
if meta_sys_ids is not None:
306324
self._frame_system_ids: list[int] | None = [int(s) for s in meta_sys_ids]
307325
self._nsystems = max(self._frame_system_ids) + 1
308-
self._system_groups: dict[int, list[int]] = {}
309-
for idx, sid in enumerate(self._frame_system_ids):
310-
self._system_groups.setdefault(sid, []).append(idx)
311-
self._system_nframes: list[int] = [
312-
len(self._system_groups.get(i, [])) for i in range(self._nsystems)
313-
]
314326
else:
315327
self._frame_system_ids = None
316328
self._nsystems = 1
317-
self._system_groups = {0: list(range(self.nframes))}
318-
self._system_nframes = [self.nframes]
319329

320-
# Parse batch_size spec
330+
# Parse batch_size spec. ``auto_rule`` and ``max_rule`` are mutually
331+
# exclusive; ``filter_rule`` implies ``max_rule`` plus dropping frames
332+
# whose nloc exceeds the threshold.
321333
self._auto_rule: int | None = None
334+
self._max_rule: int | None = None
335+
self._filter_rule: int | None = None
322336
if isinstance(batch_size, str):
323337
if batch_size == "auto":
324338
self._auto_rule = 32
325339
elif batch_size.startswith("auto:"):
326340
self._auto_rule = int(batch_size.split(":")[1])
341+
elif batch_size.startswith("max:"):
342+
self._max_rule = int(batch_size.split(":")[1])
343+
elif batch_size.startswith("filter:"):
344+
self._filter_rule = int(batch_size.split(":")[1])
345+
self._max_rule = self._filter_rule
327346
else:
328-
self._auto_rule = 32
329-
# Default batch_size uses first frame's nloc (for total_batch estimate)
347+
raise ValueError(
348+
f"Unsupported batch_size {batch_size!r}. "
349+
"Expected int, 'auto', 'auto:N', 'max:N', or 'filter:N'."
350+
)
351+
352+
# Determine which original-index frames survive the filter. Without
353+
# ``filter:N`` every frame is retained. ``mixed_batch=True`` has no
354+
# per-frame nloc info to filter against, so filter:N is a no-op there.
355+
if self._filter_rule is not None and not mixed_batch:
356+
retained_indices = [
357+
i for i, n in enumerate(self._frame_nlocs) if n <= self._filter_rule
358+
]
359+
n_dropped = self.nframes - len(retained_indices)
360+
if n_dropped > 0:
361+
log.info(
362+
f"LMDB filter:{self._filter_rule} drops {n_dropped}/"
363+
f"{self.nframes} frames with nloc > {self._filter_rule} "
364+
f"({self.lmdb_path})."
365+
)
366+
else:
367+
retained_indices = list(range(self.nframes))
368+
369+
# Group retained frames by nloc. _nloc_groups only contains nlocs
370+
# that passed the filter; its values stay as *original* LMDB frame
371+
# indices so __getitem__(index) keeps reading the right LMDB key.
372+
if not mixed_batch:
373+
self._nloc_groups: dict[int, list[int]] = {}
374+
for idx in retained_indices:
375+
self._nloc_groups.setdefault(self._frame_nlocs[idx], []).append(idx)
376+
else:
377+
self._nloc_groups = {}
378+
379+
# Group retained frames by system id. _system_nframes is indexed by
380+
# *original* sid and stays length _nsystems even if some systems are
381+
# fully dropped — those entries are simply zero so auto_prob block
382+
# slicing still parses predictably.
383+
if self._frame_system_ids is not None:
384+
self._system_groups: dict[int, list[int]] = {}
385+
for idx in retained_indices:
386+
sid = self._frame_system_ids[idx]
387+
self._system_groups.setdefault(sid, []).append(idx)
388+
self._system_nframes: list[int] = [
389+
len(self._system_groups.get(i, [])) for i in range(self._nsystems)
390+
]
391+
else:
392+
self._system_groups = {0: list(retained_indices)}
393+
self._system_nframes = [len(retained_indices)]
394+
395+
# nframes now reflects retained frames; __len__ returns this.
396+
self.nframes = len(retained_indices)
397+
398+
# Default batch_size used only by the index/total_batch estimate. The
399+
# sampler always goes through get_batch_size_for_nloc for real batches.
400+
if self._auto_rule is not None:
330401
self.batch_size = _compute_batch_size(self._natoms, self._auto_rule)
402+
elif self._max_rule is not None:
403+
self.batch_size = max(1, self._max_rule // max(self._natoms, 1))
331404
else:
332405
self.batch_size = int(batch_size)
333406

@@ -382,9 +455,19 @@ def __del__(self) -> None:
382455
_close_lmdb(path)
383456

384457
def get_batch_size_for_nloc(self, nloc: int) -> int:
385-
"""Get batch_size for a given nloc. Uses auto rule if configured."""
458+
"""Return the per-nloc batch size for the configured rule.
459+
460+
- ``auto`` / ``auto:N``: ``ceil(N / nloc)`` — may overshoot the
461+
atom budget by up to ``nloc - 1`` atoms.
462+
- ``max:N`` / ``filter:N``: ``max(1, floor(N / nloc))`` — never
463+
overshoots; clamps to 1 when a single frame already exceeds ``N``
464+
atoms.
465+
- fixed int: the same value for every nloc group.
466+
"""
386467
if self._auto_rule is not None:
387468
return _compute_batch_size(nloc, self._auto_rule)
469+
if self._max_rule is not None:
470+
return max(1, self._max_rule // max(nloc, 1))
388471
return self.batch_size
389472

390473
def __len__(self) -> int:
@@ -538,11 +621,19 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N
538621
def print_summary(self, name: str, prob: Any) -> None:
539622
"""Print basic dataset info."""
540623
n_groups = len(self._nloc_groups)
624+
if self._auto_rule is not None:
625+
bs_str = f"auto:{self._auto_rule}"
626+
elif self._filter_rule is not None:
627+
bs_str = f"filter:{self._filter_rule}"
628+
elif self._max_rule is not None:
629+
bs_str = f"max:{self._max_rule}"
630+
else:
631+
bs_str = str(self.batch_size)
541632

542633
log.info(
543634
f"LMDB {name}: {self.lmdb_path}, "
544635
f"{self.nframes} frames, {n_groups} nloc groups, "
545-
f"batch_size={'auto' if self._auto_rule else self.batch_size}, "
636+
f"batch_size={bs_str}, "
546637
f"mixed_batch={self.mixed_batch}"
547638
)
548639
# Print nloc groups in rows of ~10 for readability
@@ -646,6 +737,24 @@ def compute_block_targets(
646737
stt, end, weight = part.split(":")
647738
blocks.append((int(stt), int(end), float(weight)))
648739

740+
# Drop blocks that retain zero frames (can happen when ``filter:N``
741+
# eliminates every system in a block). prob_sys_size_ext's per-block
742+
# ``nbatch_block / sum(nbatch_block)`` would otherwise propagate NaN
743+
# when the whole block sums to zero. An all-zero dataset yields no
744+
# targets at all.
745+
nonempty = [
746+
(stt, end, weight)
747+
for stt, end, weight in blocks
748+
if sum(system_nframes[stt:end]) > 0
749+
]
750+
if not nonempty:
751+
return []
752+
if len(nonempty) < len(blocks):
753+
auto_prob_style = "prob_sys_size;" + ";".join(
754+
f"{stt}:{end}:{weight}" for stt, end, weight in nonempty
755+
)
756+
blocks = nonempty
757+
649758
# Compute per-system probabilities using the standard function
650759
sys_probs = prob_sys_size_ext(auto_prob_style, nsystems, system_nframes)
651760

deepmd/pt/utils/lmdb_dataset.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,14 @@ class LmdbDataset(Dataset):
112112
type_map : list[str]
113113
Global type map from model config.
114114
batch_size : int or str
115-
Batch size. Supports int, "auto", "auto:N".
115+
Batch size rule forwarded to :class:`LmdbDataReader`. Supports:
116+
117+
- ``int``: fixed batch size for every nloc group.
118+
- ``"auto"`` / ``"auto:N"``: ``ceil(N / nloc)`` per nloc group
119+
(``N=32`` for bare ``"auto"``).
120+
- ``"max:N"``: ``max(1, floor(N / nloc))`` per nloc group.
121+
- ``"filter:N"``: same per-nloc formula as ``"max:N"`` and drops
122+
every frame whose ``nloc > N`` from the dataset.
116123
mixed_batch : bool
117124
If True, allow different nloc in the same batch (future).
118125
If False (default), use SameNlocBatchSampler.

deepmd/utils/argcheck.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3757,7 +3757,9 @@ def validation_data_args() -> list[
37573757
- list: the length of which is the same as the {link_sys}. The batch size of each system is given by the elements of the list.\n\n\
37583758
- int: all {link_sys} use the same batch size.\n\n\
37593759
- string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
3760-
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.'
3760+
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\
3761+
- string "max:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no more than N.\n\n\
3762+
- string "filter:N": the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set.'
37613763
doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
37623764
- "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\
37633765
- "prob_sys_size" : the probability of a system is proportional to the number of batches in the system\n\n\

0 commit comments

Comments
 (0)