Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 128 additions & 19 deletions deepmd/dpmodel/utils/lmdb_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,22 @@ class LmdbDataReader:
type_map : list[str]
Global type map from model config.
batch_size : int or str
Batch size. Supports int, "auto", "auto:N".
Batch size rule used to derive per-nloc batch sizes. Supports:

- ``int``: fixed, identical batch size for every nloc group.
- ``"auto"`` / ``"auto:N"``: ``ceil(N / nloc)`` per nloc group
(``N=32`` for bare ``"auto"``). Acts as a *lower* bound —
each batch has at least ``N`` atoms, but may exceed ``N``
by up to ``nloc - 1``.
- ``"max:N"``: ``max(1, floor(N / nloc))`` per nloc group.
Acts as an *upper* bound for groups with ``nloc <= N``
(batch has at most ``N`` atoms). For groups with
``nloc > N`` the ``max(1, ...)`` floor kicks in: ``bsi=1``
and a single-frame batch still carries ``nloc`` atoms,
which exceeds ``N``.
- ``"filter:N"``: same per-nloc formula as ``"max:N"`` **and**
drops every frame whose ``nloc > N`` from the dataset. By
construction every retained batch has at most ``N`` atoms.
mixed_batch : bool
If True, allow different nloc in the same batch (future).
If False (default), enforce same-nloc-per-batch.
Expand Down Expand Up @@ -283,6 +298,10 @@ def __init__(

# Scan per-frame nloc only when needed for same-nloc batching.
# For mixed_batch=True, skip the scan entirely (future: padding handles it).
# We keep _frame_nlocs / _frame_system_ids indexable by the *original*
# LMDB frame index even after filter:N: entries for dropped frames
# are simply never looked up, because _nloc_groups / _system_groups
# no longer contain their indices.
if not mixed_batch:
# Fast path: use pre-computed frame_nlocs from metadata if available.
# Falls back to scanning each frame's atom_types shape (~10 us/frame).
Expand All @@ -293,41 +312,95 @@ def __init__(
self._frame_nlocs = _scan_frame_nlocs(
self._env, self.nframes, self._frame_fmt, self._natoms
)
self._nloc_groups: dict[int, list[int]] = {}
for idx, nloc in enumerate(self._frame_nlocs):
self._nloc_groups.setdefault(nloc, []).append(idx)
else:
self._frame_nlocs = []
self._nloc_groups = {}

# Parse frame_system_ids for auto_prob support
# Parse frame_system_ids for auto_prob support. _nsystems must stay at
# ``max(original_sid) + 1`` even after filter:N so that user-facing
# auto_prob block slicing (e.g. ``prob_sys_size;0:284:0.5;284:842:0.5``)
# keeps its meaning across filter thresholds.
meta_sys_ids = meta.get("frame_system_ids")
if meta_sys_ids is not None:
self._frame_system_ids: list[int] | None = [int(s) for s in meta_sys_ids]
self._nsystems = max(self._frame_system_ids) + 1
self._system_groups: dict[int, list[int]] = {}
for idx, sid in enumerate(self._frame_system_ids):
self._system_groups.setdefault(sid, []).append(idx)
self._system_nframes: list[int] = [
len(self._system_groups.get(i, [])) for i in range(self._nsystems)
]
else:
self._frame_system_ids = None
self._nsystems = 1
self._system_groups = {0: list(range(self.nframes))}
self._system_nframes = [self.nframes]

# Parse batch_size spec
# Parse batch_size spec. ``auto_rule`` and ``max_rule`` are mutually
# exclusive; ``filter_rule`` implies ``max_rule`` plus dropping frames
# whose nloc exceeds the threshold.
self._auto_rule: int | None = None
self._max_rule: int | None = None
self._filter_rule: int | None = None
if isinstance(batch_size, str):
if batch_size == "auto":
self._auto_rule = 32
elif batch_size.startswith("auto:"):
self._auto_rule = int(batch_size.split(":")[1])
elif batch_size.startswith("max:"):
self._max_rule = int(batch_size.split(":")[1])
elif batch_size.startswith("filter:"):
self._filter_rule = int(batch_size.split(":")[1])
self._max_rule = self._filter_rule
else:
self._auto_rule = 32
# Default batch_size uses first frame's nloc (for total_batch estimate)
raise ValueError(
f"Unsupported batch_size {batch_size!r}. "
"Expected int, 'auto', 'auto:N', 'max:N', or 'filter:N'."
)
Comment thread
OutisLi marked this conversation as resolved.

# Determine which original-index frames survive the filter. Without
# ``filter:N`` every frame is retained. ``mixed_batch=True`` has no
# per-frame nloc info to filter against, so no frames are dropped there;
# ``filter:N`` then only sets the ``max:N`` batch-size rule.
if self._filter_rule is not None and not mixed_batch:
retained_indices = [
i for i, n in enumerate(self._frame_nlocs) if n <= self._filter_rule
]
n_dropped = self.nframes - len(retained_indices)
if n_dropped > 0:
log.info(
f"LMDB filter:{self._filter_rule} drops {n_dropped}/"
f"{self.nframes} frames with nloc > {self._filter_rule} "
f"({self.lmdb_path})."
)
else:
retained_indices = list(range(self.nframes))

# Group retained frames by nloc. _nloc_groups only contains nlocs
# that passed the filter; its values stay as *original* LMDB frame
# indices so __getitem__(index) keeps reading the right LMDB key.
if not mixed_batch:
self._nloc_groups: dict[int, list[int]] = {}
for idx in retained_indices:
self._nloc_groups.setdefault(self._frame_nlocs[idx], []).append(idx)
else:
self._nloc_groups = {}

# Group retained frames by system id. _system_nframes is indexed by
# *original* sid and stays length _nsystems even if some systems are
# fully dropped — those entries are simply zero so auto_prob block
# slicing still parses predictably.
if self._frame_system_ids is not None:
self._system_groups: dict[int, list[int]] = {}
for idx in retained_indices:
sid = self._frame_system_ids[idx]
self._system_groups.setdefault(sid, []).append(idx)
self._system_nframes: list[int] = [
len(self._system_groups.get(i, [])) for i in range(self._nsystems)
]
else:
self._system_groups = {0: list(retained_indices)}
self._system_nframes = [len(retained_indices)]

# nframes now reflects retained frames; __len__ returns this.
self.nframes = len(retained_indices)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

Comment thread
OutisLi marked this conversation as resolved.
Outdated
# Default batch_size used only by the index/total_batch estimate. The
# sampler always goes through get_batch_size_for_nloc for real batches.
if self._auto_rule is not None:
self.batch_size = _compute_batch_size(self._natoms, self._auto_rule)
elif self._max_rule is not None:
self.batch_size = max(1, self._max_rule // max(self._natoms, 1))
else:
self.batch_size = int(batch_size)

Expand Down Expand Up @@ -382,9 +455,19 @@ def __del__(self) -> None:
_close_lmdb(path)

def get_batch_size_for_nloc(self, nloc: int) -> int:
    """Return the per-nloc batch size for the configured rule.

    The rule is selected from the reader's parsed ``batch_size`` spec:

    - ``auto`` / ``auto:N``: delegates to ``_compute_batch_size``
      (documented as ``ceil(N / nloc)``), which may overshoot the atom
      budget by up to ``nloc - 1`` atoms.
    - ``max:N`` / ``filter:N``: ``max(1, floor(N / nloc))``. The batch
      stays within ``N`` atoms whenever ``nloc <= N``. When a single
      frame already exceeds ``N`` atoms the result is clamped to 1, so
      that one-frame batch carries ``nloc > N`` atoms.
    - fixed int: the same value for every nloc group.

    Parameters
    ----------
    nloc : int
        Number of local atoms per frame in the group being batched.

    Returns
    -------
    int
        Number of frames per batch for this ``nloc`` (always >= 1 for
        the ``max:N`` / ``filter:N`` rules).
    """
    if self._auto_rule is not None:
        return _compute_batch_size(nloc, self._auto_rule)
    if self._max_rule is not None:
        # max(nloc, 1) guards the floor division against nloc == 0; the
        # outer max(1, ...) keeps at least one frame per batch.
        return max(1, self._max_rule // max(nloc, 1))
    return self.batch_size

def __len__(self) -> int:
Expand Down Expand Up @@ -538,11 +621,19 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N
def print_summary(self, name: str, prob: Any) -> None:
"""Print basic dataset info."""
n_groups = len(self._nloc_groups)
if self._auto_rule is not None:
bs_str = f"auto:{self._auto_rule}"
elif self._filter_rule is not None:
bs_str = f"filter:{self._filter_rule}"
elif self._max_rule is not None:
bs_str = f"max:{self._max_rule}"
else:
bs_str = str(self.batch_size)

log.info(
f"LMDB {name}: {self.lmdb_path}, "
f"{self.nframes} frames, {n_groups} nloc groups, "
f"batch_size={'auto' if self._auto_rule else self.batch_size}, "
f"batch_size={bs_str}, "
f"mixed_batch={self.mixed_batch}"
)
# Print nloc groups in rows of ~10 for readability
Expand Down Expand Up @@ -646,6 +737,24 @@ def compute_block_targets(
stt, end, weight = part.split(":")
blocks.append((int(stt), int(end), float(weight)))

# Drop blocks that retain zero frames (can happen when ``filter:N``
# eliminates every system in a block). prob_sys_size_ext's per-block
# ``nbatch_block / sum(nbatch_block)`` would otherwise propagate NaN
# when the whole block sums to zero. An all-zero dataset yields no
# targets at all.
nonempty = [
(stt, end, weight)
for stt, end, weight in blocks
if sum(system_nframes[stt:end]) > 0
]
if not nonempty:
return []
if len(nonempty) < len(blocks):
auto_prob_style = "prob_sys_size;" + ";".join(
f"{stt}:{end}:{weight}" for stt, end, weight in nonempty
)
blocks = nonempty
Comment thread
OutisLi marked this conversation as resolved.

# Compute per-system probabilities using the standard function
sys_probs = prob_sys_size_ext(auto_prob_style, nsystems, system_nframes)

Expand Down
9 changes: 8 additions & 1 deletion deepmd/pt/utils/lmdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,14 @@ class LmdbDataset(Dataset):
type_map : list[str]
Global type map from model config.
batch_size : int or str
Batch size. Supports int, "auto", "auto:N".
Batch size rule forwarded to :class:`LmdbDataReader`. Supports:

- ``int``: fixed batch size for every nloc group.
- ``"auto"`` / ``"auto:N"``: ``ceil(N / nloc)`` per nloc group
(``N=32`` for bare ``"auto"``).
- ``"max:N"``: ``max(1, floor(N / nloc))`` per nloc group.
- ``"filter:N"``: same per-nloc formula as ``"max:N"`` and drops
every frame whose ``nloc > N`` from the dataset.
mixed_batch : bool
If True, allow different nloc in the same batch (future).
If False (default), use SameNlocBatchSampler.
Expand Down
4 changes: 3 additions & 1 deletion deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -3757,7 +3757,9 @@ def validation_data_args() -> list[
- list: the length of which is the same as the {link_sys}. The batch size of each system is given by the elements of the list.\n\n\
- int: all {link_sys} use the same batch size.\n\n\
- string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.'
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\
- string "max:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no more than N.\n\n\
- string "filter:N": the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set.'
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
Comment thread
OutisLi marked this conversation as resolved.
Outdated
doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
- "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\
- "prob_sys_size" : the probability of a system is proportional to the number of batches in the system\n\n\
Expand Down
Loading
Loading