Skip to content

Commit 6b275bf

Browse files
committed
fix neighbor-stat
1 parent 792361e commit 6b275bf

5 files changed

Lines changed: 907 additions & 757 deletions

File tree

deepmd/dpmodel/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
LmdbTestData,
1313
SameNlocBatchSampler,
1414
is_lmdb,
15+
make_neighbor_stat_data,
1516
)
1617
from .network import (
1718
EmbeddingNet,
@@ -75,6 +76,7 @@
7576
"make_embedding_network",
7677
"make_fitting_network",
7778
"make_multilayer_network",
79+
"make_neighbor_stat_data",
7880
"nlist_distinguish_types",
7981
"normalize_coord",
8082
"phys2inter",

deepmd/dpmodel/utils/lmdb_data.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,69 @@ def world_size(self) -> int:
991991
return self._world_size
992992

993993

994+
def make_neighbor_stat_data(
995+
lmdb_path: str,
996+
type_map: list[str] | None,
997+
max_frames: int = 2000,
998+
) -> Any:
999+
"""Create a duck-typed DeepmdDataSystem-like object for neighbor stat from LMDB.
1000+
1001+
Samples up to *max_frames* frames, groups them by nloc, and returns an
1002+
object whose attributes satisfy the interface expected by
1003+
``NeighborStat.iterator()`` and ``UpdateSel.get_nbor_stat()``.
1004+
"""
1005+
from types import (
1006+
SimpleNamespace,
1007+
)
1008+
1009+
reader = LmdbDataReader(lmdb_path, type_map=type_map)
1010+
nframes = len(reader)
1011+
rng = np.random.RandomState(42)
1012+
if nframes > max_frames:
1013+
indices = np.sort(rng.choice(nframes, max_frames, replace=False))
1014+
else:
1015+
indices = np.arange(nframes, dtype=np.int64)
1016+
1017+
# Read sampled frames, group by nloc
1018+
nloc_frames: dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray | None]]] = {}
1019+
for idx in indices:
1020+
frame = reader[int(idx)]
1021+
atype = frame["atype"]
1022+
nloc = len(atype)
1023+
nloc_frames.setdefault(nloc, []).append(
1024+
(frame["coord"], atype, frame.get("box"))
1025+
)
1026+
1027+
# Build per-nloc data_system proxies
1028+
data_systems = []
1029+
system_dirs: list[str] = []
1030+
for nloc, frames in nloc_frames.items():
1031+
coords = np.stack([c.reshape(nloc * 3) for c, _, _ in frames])
1032+
types = np.stack([a.reshape(nloc) for _, a, _ in frames])
1033+
has_box = frames[0][2] is not None
1034+
boxes = np.stack([b.reshape(9) for _, _, b in frames]) if has_box else None
1035+
set_data = {"coord": coords, "type": types, "box": boxes}
1036+
label = f"lmdb:{nloc}atoms"
1037+
proxy = SimpleNamespace(
1038+
dirs=[label],
1039+
pbc=has_box,
1040+
mixed_type=True,
1041+
get_natoms=lambda _nloc=nloc: _nloc,
1042+
_load_set=lambda _d, _sd=set_data: _sd,
1043+
)
1044+
data_systems.append(proxy)
1045+
system_dirs.append(label)
1046+
1047+
ntypes = len(type_map) if type_map else reader._ntypes
1048+
return SimpleNamespace(
1049+
system_dirs=system_dirs,
1050+
data_systems=data_systems,
1051+
get_batch=lambda: None,
1052+
get_ntypes=lambda: ntypes,
1053+
mixed_type=True,
1054+
)
1055+
1056+
9941057
class LmdbTestData:
9951058
"""LMDB-backed data reader for dp test.
9961059

deepmd/pt/entrypoints/main.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -372,22 +372,40 @@ def train(
372372

373373
if not multi_task:
374374
type_map = config["model"].get("type_map")
375-
train_data = get_data(
376-
config["training"]["training_data"], 0, type_map, None
377-
)
375+
training_systems = config["training"]["training_data"].get("systems")
376+
if training_systems is not None and is_lmdb(training_systems):
377+
from deepmd.dpmodel.utils.lmdb_data import (
378+
make_neighbor_stat_data,
379+
)
380+
381+
train_data = make_neighbor_stat_data(training_systems, type_map)
382+
else:
383+
train_data = get_data(
384+
config["training"]["training_data"], 0, type_map, None
385+
)
378386
config["model"], min_nbor_dist = BaseModel.update_sel(
379387
train_data, type_map, config["model"]
380388
)
381389
else:
382390
min_nbor_dist = {}
383391
for model_item in config["model"]["model_dict"]:
384392
type_map = config["model"]["model_dict"][model_item].get("type_map")
385-
train_data = get_data(
386-
config["training"]["data_dict"][model_item]["training_data"],
387-
0,
388-
type_map,
389-
None,
390-
)
393+
training_systems = config["training"]["data_dict"][model_item][
394+
"training_data"
395+
].get("systems")
396+
if training_systems is not None and is_lmdb(training_systems):
397+
from deepmd.dpmodel.utils.lmdb_data import (
398+
make_neighbor_stat_data,
399+
)
400+
401+
train_data = make_neighbor_stat_data(training_systems, type_map)
402+
else:
403+
train_data = get_data(
404+
config["training"]["data_dict"][model_item]["training_data"],
405+
0,
406+
type_map,
407+
None,
408+
)
391409
config["model"]["model_dict"][model_item], min_nbor_dist[model_item] = (
392410
BaseModel.update_sel(
393411
train_data, type_map, config["model"]["model_dict"][model_item]

0 commit comments

Comments
 (0)