Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 168 additions & 29 deletions deepmd/dpmodel/utils/lmdb_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,22 @@ class LmdbDataReader:
type_map : list[str]
Global type map from model config.
batch_size : int or str
Batch size. Supports int, "auto", "auto:N".
Batch size rule used to derive per-nloc batch sizes. Supports:

- ``int``: fixed, identical batch size for every nloc group.
- ``"auto"`` / ``"auto:N"``: ``ceil(N / nloc)`` per nloc group
(``N=32`` for bare ``"auto"``). Acts as a *lower* bound —
each batch has at least ``N`` atoms, but may exceed ``N``
by up to ``nloc - 1``.
- ``"max:N"``: ``max(1, floor(N / nloc))`` per nloc group.
Acts as an *upper* bound for groups with ``nloc <= N``
(batch has at most ``N`` atoms). For groups with
``nloc > N`` the ``max(1, ...)`` floor kicks in: ``bsi=1``
and a single-frame batch still carries ``nloc`` atoms,
which exceeds ``N``.
- ``"filter:N"``: same per-nloc formula as ``"max:N"`` **and**
drops every frame whose ``nloc > N`` from the dataset when
``mixed_batch=False``, so every retained batch has at most ``N``
atoms. With ``mixed_batch=True`` no per-frame nloc is scanned and
no frames are dropped.
mixed_batch : bool
If True, allow different nloc in the same batch (future).
If False (default), enforce same-nloc-per-batch.
Expand Down Expand Up @@ -283,51 +298,129 @@ def __init__(

# Scan per-frame nloc only when needed for same-nloc batching.
# For mixed_batch=True, skip the scan entirely (future: padding handles it).
# ``orig_frame_nlocs`` / ``orig_frame_system_ids`` are indexed by the
# *original* LMDB frame index. After a potential ``filter:N`` drop we
# rebuild ``self._frame_nlocs`` / ``self._frame_system_ids`` so they
# are parallel arrays over the *dataset* index space (0..len(self));
# the dataset-to-original mapping lives in ``self._retained_keys``.
if not mixed_batch:
# Fast path: use pre-computed frame_nlocs from metadata if available.
# Falls back to scanning each frame's atom_types shape (~10 us/frame).
meta_nlocs = meta.get("frame_nlocs")
if meta_nlocs is not None:
self._frame_nlocs = [int(n) for n in meta_nlocs]
orig_frame_nlocs = [int(n) for n in meta_nlocs]
else:
self._frame_nlocs = _scan_frame_nlocs(
orig_frame_nlocs = _scan_frame_nlocs(
self._env, self.nframes, self._frame_fmt, self._natoms
)
self._nloc_groups: dict[int, list[int]] = {}
for idx, nloc in enumerate(self._frame_nlocs):
self._nloc_groups.setdefault(nloc, []).append(idx)
else:
self._frame_nlocs = []
self._nloc_groups = {}
orig_frame_nlocs = []

# Parse frame_system_ids for auto_prob support
# Parse frame_system_ids for auto_prob support. ``_nsystems`` must stay
# at ``max(original_sid) + 1`` even after filter:N so that user-facing
# auto_prob block slicing (e.g. ``prob_sys_size;0:284:0.5;284:842:0.5``)
# keeps its meaning across filter thresholds.
meta_sys_ids = meta.get("frame_system_ids")
if meta_sys_ids is not None:
self._frame_system_ids: list[int] | None = [int(s) for s in meta_sys_ids]
self._nsystems = max(self._frame_system_ids) + 1
self._system_groups: dict[int, list[int]] = {}
for idx, sid in enumerate(self._frame_system_ids):
self._system_groups.setdefault(sid, []).append(idx)
self._system_nframes: list[int] = [
len(self._system_groups.get(i, [])) for i in range(self._nsystems)
]
orig_frame_system_ids: list[int] | None = [int(s) for s in meta_sys_ids]
self._nsystems = max(orig_frame_system_ids) + 1
else:
self._frame_system_ids = None
orig_frame_system_ids = None
self._nsystems = 1
self._system_groups = {0: list(range(self.nframes))}
self._system_nframes = [self.nframes]

# Parse batch_size spec
# Parse batch_size spec. ``auto_rule`` and ``max_rule`` are mutually
# exclusive; ``filter_rule`` implies ``max_rule`` plus dropping frames
# whose nloc exceeds the threshold.
self._auto_rule: int | None = None
self._max_rule: int | None = None
self._filter_rule: int | None = None
if isinstance(batch_size, str):
if batch_size == "auto":
self._auto_rule = 32
elif batch_size.startswith("auto:"):
self._auto_rule = int(batch_size.split(":")[1])
elif batch_size.startswith("max:"):
self._max_rule = int(batch_size.split(":")[1])
elif batch_size.startswith("filter:"):
self._filter_rule = int(batch_size.split(":")[1])
self._max_rule = self._filter_rule
else:
self._auto_rule = 32
# Default batch_size uses first frame's nloc (for total_batch estimate)
raise ValueError(
f"Unsupported batch_size {batch_size!r}. "
"Expected int, 'auto', 'auto:N', 'max:N', or 'filter:N'."
)
Comment thread
OutisLi marked this conversation as resolved.

# Determine which original-index frames survive the filter. Without
# ``filter:N`` every frame is retained. ``mixed_batch=True`` has no
# per-frame nloc info to filter against, so filter:N is a no-op there.
if self._filter_rule is not None and not mixed_batch:
retained_keys = [
i for i, n in enumerate(orig_frame_nlocs) if n <= self._filter_rule
]
n_dropped = self.nframes - len(retained_keys)
if n_dropped > 0:
log.info(
f"LMDB filter:{self._filter_rule} drops {n_dropped}/"
f"{self.nframes} frames with nloc > {self._filter_rule} "
f"({self.lmdb_path})."
)
else:
retained_keys = list(range(self.nframes))
Comment thread
OutisLi marked this conversation as resolved.

# Dataset-index → original LMDB frame key. ``__getitem__`` looks up
# this table so that ``reader[i]`` is a valid LMDB read for every
# ``0 <= i < len(reader)``, no matter how many frames were filtered.
self._retained_keys: list[int] = retained_keys

# Re-key _frame_nlocs / _frame_system_ids into the dataset-index
# space so that every downstream consumer (nloc_groups, system_groups,
# SameNlocBatchSampler, _expand_indices_by_blocks) operates in a
# single, self-consistent indexing scheme.
if not mixed_batch:
self._frame_nlocs = [orig_frame_nlocs[k] for k in retained_keys]
else:
self._frame_nlocs = []

if orig_frame_system_ids is not None:
self._frame_system_ids: list[int] | None = [
orig_frame_system_ids[k] for k in retained_keys
]
else:
self._frame_system_ids = None

# Group retained frames by nloc using dataset indices (0..len-1).
if not mixed_batch:
self._nloc_groups: dict[int, list[int]] = {}
for ds_idx, nloc in enumerate(self._frame_nlocs):
self._nloc_groups.setdefault(nloc, []).append(ds_idx)
else:
self._nloc_groups = {}

# Group retained frames by original system id; the sid numbering is
# preserved (no compression) so user-facing auto_prob slices stay
# meaningful across filter thresholds. Fully-dropped systems appear
# as zero-frame entries in ``_system_nframes``.
if self._frame_system_ids is not None:
self._system_groups: dict[int, list[int]] = {}
for ds_idx, sid in enumerate(self._frame_system_ids):
self._system_groups.setdefault(sid, []).append(ds_idx)
self._system_nframes: list[int] = [
len(self._system_groups.get(i, [])) for i in range(self._nsystems)
]
else:
self._system_groups = {0: list(range(len(retained_keys)))}
self._system_nframes = [len(retained_keys)]

# nframes now reflects retained frames; __len__ returns this and the
# valid index domain for __getitem__ is [0, self.nframes).
self.nframes = len(retained_keys)

# Default batch_size used only by the index/total_batch estimate. The
# sampler always goes through get_batch_size_for_nloc for real batches.
if self._auto_rule is not None:
self.batch_size = _compute_batch_size(self._natoms, self._auto_rule)
elif self._max_rule is not None:
self.batch_size = max(1, self._max_rule // max(self._natoms, 1))
else:
self.batch_size = int(batch_size)

Expand Down Expand Up @@ -382,20 +475,40 @@ def __del__(self) -> None:
_close_lmdb(path)

def get_batch_size_for_nloc(self, nloc: int) -> int:
"""Get batch_size for a given nloc. Uses auto rule if configured."""
"""Return the per-nloc batch size for the configured rule.

- ``auto`` / ``auto:N``: ``ceil(N / nloc)`` — may overshoot the
atom budget by up to ``nloc - 1`` atoms.
- ``max:N`` / ``filter:N``: ``max(1, floor(N / nloc))`` — stays within
``N`` atoms whenever ``nloc <= N``; clamps to 1 when a single frame
already exceeds ``N`` atoms, in which case that batch overshoots ``N``.
- fixed int: the same value for every nloc group.
"""
if self._auto_rule is not None:
return _compute_batch_size(nloc, self._auto_rule)
if self._max_rule is not None:
return max(1, self._max_rule // max(nloc, 1))
Comment thread
OutisLi marked this conversation as resolved.
return self.batch_size

def __len__(self) -> int:
return self.nframes

def __getitem__(self, index: int) -> dict[str, Any]:
"""Read frame from LMDB, decode, remap keys, return dict of numpy arrays."""
key = format(index, self._frame_fmt).encode()
"""Read frame from LMDB, decode, remap keys, return dict of numpy arrays.

``index`` is a dataset-level index in ``[0, len(self))``. Under
``filter:N`` the LMDB key space may have gaps (dropped frames), so
we translate through ``self._retained_keys`` before hitting LMDB.
"""
if index < 0 or index >= self.nframes:
raise IndexError(f"dataset index {index} out of range [0, {self.nframes})")
original_key = self._retained_keys[index]
key = format(original_key, self._frame_fmt).encode()
raw = self._txn.get(key)
if raw is None:
raise IndexError(f"Frame {index} not found in LMDB")
raise IndexError(
f"Frame {original_key} not found in LMDB (dataset index {index})"
)
frame = _decode_frame(raw)
frame = _remap_keys(frame)

Expand Down Expand Up @@ -524,7 +637,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
np.float32(1.0) if extra_key in frame else np.float32(0.0)
)

frame["fid"] = index
frame["fid"] = original_key

return frame

Expand All @@ -538,11 +651,19 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N
def print_summary(self, name: str, prob: Any) -> None:
"""Print basic dataset info."""
n_groups = len(self._nloc_groups)
if self._auto_rule is not None:
bs_str = f"auto:{self._auto_rule}"
elif self._filter_rule is not None:
bs_str = f"filter:{self._filter_rule}"
elif self._max_rule is not None:
bs_str = f"max:{self._max_rule}"
else:
bs_str = str(self.batch_size)

log.info(
f"LMDB {name}: {self.lmdb_path}, "
f"{self.nframes} frames, {n_groups} nloc groups, "
f"batch_size={'auto' if self._auto_rule else self.batch_size}, "
f"batch_size={bs_str}, "
f"mixed_batch={self.mixed_batch}"
)
# Print nloc groups in rows of ~10 for readability
Expand Down Expand Up @@ -646,6 +767,24 @@ def compute_block_targets(
stt, end, weight = part.split(":")
blocks.append((int(stt), int(end), float(weight)))

# Drop blocks that retain zero frames (can happen when ``filter:N``
# eliminates every system in a block). prob_sys_size_ext's per-block
# ``nbatch_block / sum(nbatch_block)`` would otherwise propagate NaN
# when the whole block sums to zero. An all-zero dataset yields no
# targets at all.
nonempty = [
(stt, end, weight)
for stt, end, weight in blocks
if sum(system_nframes[stt:end]) > 0
]
if not nonempty:
return []
if len(nonempty) < len(blocks):
auto_prob_style = "prob_sys_size;" + ";".join(
f"{stt}:{end}:{weight}" for stt, end, weight in nonempty
)
blocks = nonempty
Comment thread
OutisLi marked this conversation as resolved.

# Compute per-system probabilities using the standard function
sys_probs = prob_sys_size_ext(auto_prob_style, nsystems, system_nframes)

Expand Down
9 changes: 8 additions & 1 deletion deepmd/pt/utils/lmdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,14 @@ class LmdbDataset(Dataset):
type_map : list[str]
Global type map from model config.
batch_size : int or str
Batch size. Supports int, "auto", "auto:N".
Batch size rule forwarded to :class:`LmdbDataReader`. Supports:

- ``int``: fixed batch size for every nloc group.
- ``"auto"`` / ``"auto:N"``: ``ceil(N / nloc)`` per nloc group
(``N=32`` for bare ``"auto"``).
- ``"max:N"``: ``max(1, floor(N / nloc))`` per nloc group.
- ``"filter:N"``: same per-nloc formula as ``"max:N"`` and drops
every frame whose ``nloc > N`` from the dataset.
mixed_batch : bool
If True, allow different nloc in the same batch (future).
If False (default), use SameNlocBatchSampler.
Expand Down
8 changes: 5 additions & 3 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -3677,8 +3677,8 @@ def training_data_args() -> list[
- string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\
- string "mixed:N": the batch data will be sampled from all systems and merged into a mixed system with the batch size N. Only support the se_atten descriptor for TensorFlow backend.\n\n\
- string "max:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no more than N.\n\n\
- string "filter:N": the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set.\n\n\
- string "max:N": automatically determines the batch size so that `batch_size * natoms` is at most `N`. `natoms` is the per-system atom count for npy data and the per-frame nloc for LMDB data. When a single system/frame already has more than `N` atoms, the batch size clamps to 1 and that batch will exceed `N`.\n\n\
- string "filter:N": the same as `"max:N"` but additionally drops data whose atom count exceeds `N`. For npy data this removes whole systems with natoms > `N`; for LMDB data this removes individual frames with nloc > `N`.\n\n\
If MPI is used, the value should be considered as the batch size per task.'
doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
- "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\
Expand Down Expand Up @@ -3757,7 +3757,9 @@ def validation_data_args() -> list[
- list: the length of which is the same as the {link_sys}. The batch size of each system is given by the elements of the list.\n\n\
- int: all {link_sys} use the same batch size.\n\n\
- string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.'
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\
- string "max:N": automatically determines the batch size so that `batch_size * natoms` is at most `N`. `natoms` is the per-system atom count for npy data and the per-frame nloc for LMDB data. When a single system/frame already has more than `N` atoms, the batch size clamps to 1 and that batch will exceed `N`.\n\n\
- string "filter:N": the same as `"max:N"` but additionally drops data whose atom count exceeds `N`. For npy data this removes whole systems with natoms > `N`; for LMDB data this removes individual frames with nloc > `N`.'
doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
- "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\
- "prob_sys_size" : the probability of a system is proportional to the number of batches in the system\n\n\
Expand Down
Loading
Loading