Skip to content

Commit 28b9300

Browse files
committed
fix efficiency
1 parent a465fab commit 28b9300

2 files changed

Lines changed: 158 additions & 75 deletions

File tree

deepmd/dpmodel/utils/lmdb_data.py

Lines changed: 103 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -475,16 +475,22 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N
475475

476476
def print_summary(self, name: str, prob: Any) -> None:
477477
"""Print basic dataset info."""
478-
unique_nlocs = sorted(self._nloc_groups.keys())
479-
nloc_info = ", ".join(
480-
f"{nloc}({len(idxs)})" for nloc, idxs in sorted(self._nloc_groups.items())
481-
)
478+
n_groups = len(self._nloc_groups)
479+
482480
log.info(
483481
f"LMDB {name}: {self.lmdb_path}, "
484-
f"{self.nframes} frames, nloc groups: [{nloc_info}], "
482+
f"{self.nframes} frames, {n_groups} nloc groups, "
485483
f"batch_size={'auto' if self._auto_rule else self.batch_size}, "
486484
f"mixed_batch={self.mixed_batch}"
487485
)
486+
# Print nloc groups in rows of ~10 for readability
487+
items = [
488+
f"{nloc}({len(idxs)})" for nloc, idxs in sorted(self._nloc_groups.items())
489+
]
490+
per_row = 10
491+
for i in range(0, len(items), per_row):
492+
row = ", ".join(items[i : i + per_row])
493+
log.info(f" nloc groups: {row}")
488494

489495
def set_noise(self, noise_settings: dict[str, Any]) -> None:
490496
"""No-op for now."""
@@ -616,9 +622,11 @@ def compute_block_targets(
616622

617623
def _expand_indices_by_blocks(
618624
indices: list[int],
619-
frame_system_ids: list[int],
625+
frame_system_ids: np.ndarray,
620626
block_targets: list[tuple[list[int], int]],
621627
rng: np.random.Generator,
628+
_block_total_actual: list[int] | None = None,
629+
_sid_to_blk_arr: np.ndarray | None = None,
622630
) -> list[int]:
623631
"""Expand frame indices according to block targets.
624632
@@ -630,78 +638,95 @@ def _expand_indices_by_blocks(
630638
----------
631639
indices : list[int]
632640
Frame indices in the current nloc group.
633-
frame_system_ids : list[int]
634-
Per-frame system id for the entire dataset.
641+
frame_system_ids : np.ndarray
642+
Per-frame system id for the entire dataset (int64 array).
635643
block_targets : list[tuple[list[int], int]]
636644
Per-block (system_ids, total_target_frames).
637645
rng : np.random.Generator
638646
RNG for remainder sampling.
647+
_block_total_actual : list[int] or None
648+
Pre-computed total actual frame count per block (across all nloc
649+
groups). When provided, avoids an O(N) scan of frame_system_ids.
650+
_sid_to_blk_arr : np.ndarray or None
651+
Pre-computed system-id to block-index lookup array. When provided,
652+
avoids rebuilding the mapping for each call.
639653
640654
Returns
641655
-------
642656
list[int]
643657
Expanded indices.
644658
"""
645-
# Build sys_id -> block_idx mapping
646-
sys_to_block: dict[int, int] = {}
647-
for blk_idx, (sys_ids, _target) in enumerate(block_targets):
648-
for sid in sys_ids:
649-
sys_to_block[sid] = blk_idx
650-
651-
# Partition indices by block
652-
block_indices: dict[int, list[int]] = {i: [] for i in range(len(block_targets))}
653-
unassigned: list[int] = []
654-
for idx in indices:
655-
sid = frame_system_ids[idx]
656-
blk = sys_to_block.get(sid)
657-
if blk is not None:
658-
block_indices[blk].append(idx)
659-
else:
660-
unassigned.append(idx)
661-
662-
# Compute total actual frames across all blocks (for proportional scaling)
663-
total_actual = sum(len(block_indices[i]) for i in range(len(block_targets)))
664-
total_target_all = sum(t for _, t in block_targets)
665-
666-
expanded: list[int] = list(unassigned)
667-
668-
for blk_idx, (sys_ids, block_total_target) in enumerate(block_targets):
669-
blk_idxs = block_indices[blk_idx]
659+
n_blocks = len(block_targets)
660+
661+
# Build sys_id -> block_idx lookup array
662+
if _sid_to_blk_arr is None:
663+
sys_to_block: dict[int, int] = {}
664+
for blk_idx, (sys_ids, _target) in enumerate(block_targets):
665+
for sid in sys_ids:
666+
sys_to_block[sid] = blk_idx
667+
max_sid = max(sys_to_block.keys()) + 1 if sys_to_block else 0
668+
_sid_to_blk_arr = np.full(max_sid, -1, dtype=np.int32)
669+
for sid, blk in sys_to_block.items():
670+
_sid_to_blk_arr[sid] = blk
671+
672+
# Partition indices by block using numpy for speed
673+
idx_arr = np.asarray(indices, dtype=np.int64)
674+
sid_arr = np.asarray(frame_system_ids, dtype=np.int64)
675+
# Vectorized lookup: get block id for each index
676+
idx_sids = sid_arr[idx_arr]
677+
idx_blks = _sid_to_blk_arr[idx_sids]
678+
679+
# Pre-compute block_total_actual if not provided
680+
if _block_total_actual is None:
681+
_block_total_actual = []
682+
for sys_ids, _ in block_targets:
683+
total = sum(int(np.sum(sid_arr == sid)) for sid in sys_ids)
684+
_block_total_actual.append(total)
685+
686+
expanded_parts: list[np.ndarray] = []
687+
688+
# Unassigned indices
689+
unassigned_mask = idx_blks == -1
690+
if np.any(unassigned_mask):
691+
expanded_parts.append(idx_arr[unassigned_mask])
692+
693+
for blk_idx in range(n_blocks):
694+
blk_mask = idx_blks == blk_idx
695+
blk_idxs = idx_arr[blk_mask]
670696
n_actual = len(blk_idxs)
671697
if n_actual == 0:
672698
continue
673699

674-
# Proportional target for this nloc subset of the block
675-
# block_total_target is for the entire block; scale by the fraction
676-
# of block frames that fall in this nloc group
677-
_, block_total_target_all = block_targets[blk_idx]
678-
# Get total frames in this block across all nloc groups
679-
block_total_actual = sum(
680-
1
681-
for i in range(len(frame_system_ids))
682-
if frame_system_ids[i] in set(sys_ids)
683-
)
684-
if block_total_actual > 0:
685-
target = round(block_total_target_all * n_actual / block_total_actual)
700+
_, block_total_target = block_targets[blk_idx]
701+
block_total_act = _block_total_actual[blk_idx]
702+
703+
# Proportional target for this nloc subset
704+
if block_total_act > 0:
705+
target = round(block_total_target * n_actual / block_total_act)
686706
else:
687707
target = n_actual
688708
target = max(target, n_actual) # never shrink
689709

690710
# Full copies + remainder
691711
deficit = target - n_actual
692712
if deficit <= 0:
693-
expanded.extend(blk_idxs)
713+
expanded_parts.append(blk_idxs)
694714
else:
695715
full_copies = deficit // n_actual
696716
remainder = deficit % n_actual
697717
# Original + full copies
698-
expanded.extend(blk_idxs * (1 + full_copies))
718+
if full_copies > 0:
719+
expanded_parts.append(np.tile(blk_idxs, 1 + full_copies))
720+
else:
721+
expanded_parts.append(blk_idxs)
699722
# Remainder: sample without replacement
700723
if remainder > 0:
701724
sampled = rng.choice(blk_idxs, size=remainder, replace=False)
702-
expanded.extend(sampled.tolist())
725+
expanded_parts.append(sampled)
703726

704-
return expanded
727+
if expanded_parts:
728+
return np.concatenate(expanded_parts).tolist()
729+
return []
705730

706731

707732
def _build_all_batches(
@@ -735,12 +760,39 @@ def _build_all_batches(
735760
"""
736761
# Build per-group batches
737762
group_batches: list[list[list[int]]] = []
763+
764+
# Pre-compute expensive objects once (avoids O(N) work per nloc group)
765+
block_total_actual: list[int] | None = None
766+
sid_arr: np.ndarray | None = None
767+
sid_to_blk_arr: np.ndarray | None = None
768+
if block_targets and reader.frame_system_ids is not None:
769+
block_total_actual = []
770+
for sys_ids, _ in block_targets:
771+
total = sum(reader.system_nframes[s] for s in sys_ids)
772+
block_total_actual.append(total)
773+
# Convert frame_system_ids to numpy once
774+
sid_arr = np.array(reader.frame_system_ids, dtype=np.int64)
775+
# Build sys_id -> block_idx lookup array once
776+
sys_to_block: dict[int, int] = {}
777+
for blk_idx, (sys_ids, _target) in enumerate(block_targets):
778+
for sid in sys_ids:
779+
sys_to_block[sid] = blk_idx
780+
max_sid = max(sys_to_block.keys()) + 1 if sys_to_block else 0
781+
sid_to_blk_arr = np.full(max_sid, -1, dtype=np.int32)
782+
for sid, blk in sys_to_block.items():
783+
sid_to_blk_arr[sid] = blk
784+
738785
for nloc in sorted(reader.nloc_groups.keys()):
739786
indices = list(reader.nloc_groups[nloc])
740787
# Expand indices by block targets if provided
741-
if block_targets and reader.frame_system_ids is not None:
788+
if block_targets and sid_arr is not None:
742789
indices = _expand_indices_by_blocks(
743-
indices, reader.frame_system_ids, block_targets, rng
790+
indices,
791+
sid_arr,
792+
block_targets,
793+
rng,
794+
_block_total_actual=block_total_actual,
795+
_sid_to_blk_arr=sid_to_blk_arr,
744796
)
745797
if shuffle:
746798
rng.shuffle(indices)

deepmd/pt/utils/lmdb_dataset.py

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
Any,
1010
)
1111

12-
import numpy as np
1312
import torch
1413
from torch.utils.data import (
1514
DataLoader,
@@ -230,38 +229,70 @@ def print_summary(self, name: str, prob: Any) -> None:
230229
block_lines = []
231230
total_original = 0
232231
total_target = 0
232+
# Pre-compute block_total_actual for proportional scaling
233+
block_total_actual: list[int] = []
233234
for sys_ids, target in self._block_targets:
234235
actual = sum(reader.system_nframes[s] for s in sys_ids)
236+
block_total_actual.append(actual)
235237
total_original += actual
236238
total_target += target
237-
sys_str = ",".join(str(s) for s in sys_ids)
238-
block_lines.append(f"sys[{sys_str}]:{actual}->{target}")
239-
# Expanded nloc groups (simulate one expansion to get counts)
240-
from deepmd.dpmodel.utils.lmdb_data import (
241-
_expand_indices_by_blocks,
242-
)
239+
# Compact range notation: sys[0-146] instead of sys[0,1,2,...,146]
240+
if len(sys_ids) > 3:
241+
sys_str = f"{sys_ids[0]}-{sys_ids[-1]}"
242+
else:
243+
sys_str = ",".join(str(s) for s in sys_ids)
244+
ratio = target / actual if actual > 0 else 0
245+
block_lines.append(
246+
f"sys[{sys_str}]({len(sys_ids)}sys): "
247+
f"{actual}->{target} (x{ratio:.2f})"
248+
)
243249

250+
# Build sys_id -> block_idx mapping
251+
sys_to_block: dict[int, int] = {}
252+
for blk_idx, (sys_ids, _) in enumerate(self._block_targets):
253+
for sid in sys_ids:
254+
sys_to_block[sid] = blk_idx
255+
256+
# Compute expanded nloc counts analytically (no actual expansion)
244257
expanded_nloc_info = {}
245258
for nloc, indices in sorted(reader.nloc_groups.items()):
246-
rng = np.random.default_rng(0)
247-
expanded = _expand_indices_by_blocks(
248-
list(indices),
249-
reader.frame_system_ids,
250-
self._block_targets,
251-
rng,
252-
)
253-
expanded_nloc_info[nloc] = len(expanded)
254-
nloc_str = ", ".join(
255-
f"{nloc}({orig}->{expanded_nloc_info[nloc]})"
256-
for nloc, orig in sorted(
257-
(n, len(idx)) for n, idx in reader.nloc_groups.items()
258-
)
259-
)
259+
if reader.frame_system_ids is None:
260+
expanded_nloc_info[nloc] = len(indices)
261+
continue
262+
# Count indices per block in this nloc group
263+
blk_counts: dict[int, int] = {}
264+
unassigned = 0
265+
for idx in indices:
266+
sid = reader.frame_system_ids[idx]
267+
blk = sys_to_block.get(sid)
268+
if blk is not None:
269+
blk_counts[blk] = blk_counts.get(blk, 0) + 1
270+
else:
271+
unassigned += 1
272+
expanded = unassigned
273+
for blk_idx, (_, blk_target) in enumerate(self._block_targets):
274+
n_actual = blk_counts.get(blk_idx, 0)
275+
if n_actual == 0:
276+
continue
277+
bta = block_total_actual[blk_idx]
278+
if bta > 0:
279+
t = max(round(blk_target * n_actual / bta), n_actual)
280+
else:
281+
t = n_actual
282+
expanded += t
283+
expanded_nloc_info[nloc] = expanded
284+
285+
total_expanded = sum(expanded_nloc_info.values())
286+
n_groups = len(reader.nloc_groups)
287+
ratio_all = total_expanded / total_original if total_original > 0 else 0
288+
260289
log.info(
261-
f"LMDB {name} auto_prob: {total_original}->{total_target} frames, "
262-
f"blocks: [{', '.join(block_lines)}], "
263-
f"nloc groups: [{nloc_str}]"
290+
f"LMDB {name} auto_prob: "
291+
f"{total_original}->{total_expanded} frames (x{ratio_all:.2f}), "
292+
f"{n_groups} nloc groups, {len(self._block_targets)} blocks:"
264293
)
294+
for bl in block_lines:
295+
log.info(f" {bl}")
265296

266297
def set_noise(self, noise_settings: dict[str, Any]) -> None:
267298
self._reader.set_noise(noise_settings)

0 commit comments

Comments (0)