Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions deepmd/dpmodel/utils/lmdb_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,11 @@ def mixed_type(self) -> bool:
"""LMDB datasets are always mixed_type (frames may have different compositions)."""
return True

@property
def type_map(self) -> list[str]:
"""Model-side type map used when constructing the reader."""
return self._type_map

@property
def nloc_groups(self) -> dict[int, list[int]]:
"""Nloc → list of frame indices."""
Expand Down Expand Up @@ -608,6 +613,41 @@ def system_nframes(self) -> list[int]:
return self._system_nframes


def collate_lmdb_frames(frames: list[dict[str, Any]]) -> dict[str, Any]:
"""Stack a list of per-frame dicts into a single batch dict.

Backend-agnostic via ``array_api_compat``: works for numpy, torch, jax,
etc. The array library is inferred from the first frame's ``coord``.

Conventions match :func:`deepmd.dpmodel.utils.batch.normalize_batch`:
``find_*`` flags are taken from the first frame (constant within a
batch); ``fid`` is collected as a list; ``type`` is dropped (callers
should already use ``atype``); other arrays are stacked along axis 0.
A ``sid`` placeholder is appended.
"""
import array_api_compat

if not frames:
raise ValueError("collate_lmdb_frames requires at least one frame")

xp = array_api_compat.array_namespace(frames[0]["coord"])
dev = array_api_compat.device(frames[0]["coord"])
out: dict[str, Any] = {}
for key in frames[0]:
if key.startswith("find_"):
out[key] = frames[0][key]
elif key == "fid":
out[key] = [f[key] for f in frames]
elif key == "type":
continue
elif frames[0][key] is None:
out[key] = None
else:
out[key] = xp.stack([f[key] for f in frames])
out["sid"] = xp.asarray([0], dtype=xp.int64, device=dev)
return out


def compute_block_targets(
auto_prob_style: str,
nsystems: int,
Expand Down
49 changes: 23 additions & 26 deletions deepmd/pt/utils/lmdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@
Dataset,
Sampler,
)
from torch.utils.data._utils.collate import (
collate_tensor_fn,
)

from deepmd.dpmodel.utils.lmdb_data import (
LmdbDataReader,
LmdbTestData,
SameNlocBatchSampler,
collate_lmdb_frames,
compute_block_targets,
is_lmdb,
)
Expand All @@ -42,13 +40,17 @@


def _collate_lmdb_batch(batch: list[dict[str, Any]]) -> dict[str, Any]:
"""Collate a list of frame dicts into a batch dict.
"""Collate a list of frame dicts into a torch batch dict.

All frames in the batch must have the same nloc (enforced by
SameNlocBatchSampler when mixed_batch=False).
Pre-converts per-frame numpy arrays to CPU torch tensors (zero-copy when
dtype matches) and delegates stacking to the backend-agnostic
:func:`collate_lmdb_frames`. With torch tensors as input, the shared
collate yields a torch dict (``sid`` becomes a torch tensor automatically
via ``array_api_compat``).

For mixed_batch=True, this function would need padding + mask.
Currently raises NotImplementedError for that case.
All frames in the batch must have the same nloc (enforced by
SameNlocBatchSampler when mixed_batch=False). For mixed_batch=True,
raises NotImplementedError.
"""
if len(batch) > 1:
atypes = [d.get("atype") for d in batch if d.get("atype") is not None]
Expand All @@ -59,24 +61,19 @@ def _collate_lmdb_batch(batch: list[dict[str, Any]]) -> dict[str, Any]:
"Padding + mask in collate_fn needed."
)

example = batch[0]
result: dict[str, Any] = {}
for key in example:
if "find_" in key:
result[key] = batch[0][key]
elif key == "fid":
result[key] = [d[key] for d in batch]
elif key == "type":
continue
elif batch[0][key] is None:
result[key] = None
else:
with torch.device("cpu"):
result[key] = collate_tensor_fn(
[torch.as_tensor(d[key]) for d in batch]
)
result["sid"] = torch.tensor([0], dtype=torch.long, device="cpu")
return result
with torch.device("cpu"):
torch_frames: list[dict[str, Any]] = []
for f in batch:
tf: dict[str, Any] = {}
for key, val in f.items():
if key.startswith("find_") or key == "fid" or key == "type":
tf[key] = val
elif val is None:
tf[key] = None
else:
tf[key] = torch.as_tensor(val)
torch_frames.append(tf)
return collate_lmdb_frames(torch_frames)


class _SameNlocBatchSamplerTorch(Sampler):
Expand Down
147 changes: 97 additions & 50 deletions deepmd/pt_expt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@

import h5py

from deepmd.dpmodel.utils.lmdb_data import (
is_lmdb,
)
from deepmd.pt_expt.train import (
training,
)
from deepmd.pt_expt.utils.lmdb_dataset import (
LmdbDataSystem,
)
from deepmd.utils.argcheck import (
normalize,
)
Expand All @@ -35,6 +41,84 @@
log = logging.getLogger(__name__)


def _detect_lmdb_path(systems_raw: Any) -> str | None:
"""Return the LMDB path when ``systems_raw`` is a scalar LMDB string.

Returns ``None`` for non-LMDB inputs. Raises ``ValueError`` if
``systems_raw`` is a list containing any LMDB path, so both
``_get_neighbor_stat_data`` and ``_build_data_system`` fail with the
same clear message instead of the opaque error from
:func:`process_systems` / :class:`DeepmdData`.
"""
if isinstance(systems_raw, str) and is_lmdb(systems_raw):
return systems_raw
if isinstance(systems_raw, list) and any(
isinstance(s, str) and is_lmdb(s) for s in systems_raw
):
raise ValueError(
"LMDB datasets must be passed as a scalar 'systems' string "
"(e.g. 'systems': '/path/to/data.lmdb'); list-form systems "
"with LMDB paths are not supported."
)
return None


def _get_neighbor_stat_data(
dataset_params: dict[str, Any],
type_map: list[str] | None,
) -> Any:
"""Return a data proxy suitable for ``BaseModel.update_sel`` (neighbor stat).

Routes a scalar LMDB ``systems`` path through dpmodel's
``make_neighbor_stat_data``; falls back to the legacy ``get_data`` for
npy/HDF5 directories.
"""
lmdb_path = _detect_lmdb_path(dataset_params.get("systems"))
if lmdb_path is not None:
from deepmd.dpmodel.utils.lmdb_data import (
make_neighbor_stat_data,
)

return make_neighbor_stat_data(lmdb_path, type_map)
return get_data(dataset_params, 0, type_map, None)
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def _build_data_system(
dataset_params: dict[str, Any],
type_map: list[str],
seed: int | None = None,
) -> DeepmdDataSystem | LmdbDataSystem:
"""Build a data system from dataset config, routing LMDB paths to LmdbDataSystem.

A scalar ``systems`` value pointing at an LMDB directory triggers the
LMDB adapter; otherwise we fall through to the legacy
:class:`DeepmdDataSystem` path with system expansion.
"""
systems_raw = dataset_params["systems"]
lmdb_path = _detect_lmdb_path(systems_raw)
if lmdb_path is not None:
return LmdbDataSystem(
lmdb_path=lmdb_path,
type_map=type_map,
batch_size=dataset_params["batch_size"],
auto_prob_style=dataset_params.get("auto_prob"),
seed=seed,
)
systems = process_systems(
systems_raw,
patterns=dataset_params.get("rglob_patterns", None),
)
return DeepmdDataSystem(
systems=systems,
batch_size=dataset_params["batch_size"],
test_size=1,
type_map=type_map,
trn_all_set=True,
sys_probs=dataset_params.get("sys_probs", None),
auto_prob_style=dataset_params.get("auto_prob", "prob_sys_size"),
)


def get_trainer(
config: dict[str, Any],
init_model: str | None = None,
Expand All @@ -48,39 +132,23 @@ def get_trainer(
training_params = config["training"]
multi_task = "model_dict" in model_params

data_seed = training_params.get("seed", None)

if not multi_task:
type_map = model_params["type_map"]

# ----- training data ------------------------------------------------
training_dataset_params = training_params["training_data"]
training_systems = process_systems(
training_dataset_params["systems"],
patterns=training_dataset_params.get("rglob_patterns", None),
)
train_data = DeepmdDataSystem(
systems=training_systems,
batch_size=training_dataset_params["batch_size"],
test_size=1,
type_map=type_map,
trn_all_set=True,
sys_probs=training_dataset_params.get("sys_probs", None),
auto_prob_style=training_dataset_params.get("auto_prob", "prob_sys_size"),
train_data = _build_data_system(
training_dataset_params, type_map, seed=data_seed
)

# ----- validation data ----------------------------------------------
validation_data = None
validation_dataset_params = training_params.get("validation_data", None)
if validation_dataset_params is not None:
val_systems = process_systems(
validation_dataset_params["systems"],
patterns=validation_dataset_params.get("rglob_patterns", None),
)
validation_data = DeepmdDataSystem(
systems=val_systems,
batch_size=validation_dataset_params["batch_size"],
test_size=1,
type_map=type_map,
trn_all_set=True,
validation_data = _build_data_system(
validation_dataset_params, type_map, seed=data_seed
)

# ----- stat file path -----------------------------------------------
Expand All @@ -103,34 +171,15 @@ def get_trainer(
data_params = training_params["data_dict"][model_key]

# training data
td_params = data_params["training_data"]
training_systems = process_systems(
td_params["systems"],
patterns=td_params.get("rglob_patterns", None),
)
train_data[model_key] = DeepmdDataSystem(
systems=training_systems,
batch_size=td_params["batch_size"],
test_size=1,
type_map=type_map,
trn_all_set=True,
sys_probs=td_params.get("sys_probs", None),
auto_prob_style=td_params.get("auto_prob", "prob_sys_size"),
train_data[model_key] = _build_data_system(
data_params["training_data"], type_map, seed=data_seed
)

# validation data
vd_params = data_params.get("validation_data", None)
if vd_params is not None:
val_systems = process_systems(
vd_params["systems"],
patterns=vd_params.get("rglob_patterns", None),
)
validation_data[model_key] = DeepmdDataSystem(
systems=val_systems,
batch_size=vd_params["batch_size"],
test_size=1,
type_map=type_map,
trn_all_set=True,
validation_data[model_key] = _build_data_system(
vd_params, type_map, seed=data_seed
)
else:
validation_data[model_key] = None
Expand Down Expand Up @@ -261,20 +310,18 @@ def train(

if not multi_task:
type_map = config["model"].get("type_map")
train_data = get_data(
config["training"]["training_data"], 0, type_map, None
train_data = _get_neighbor_stat_data(
config["training"]["training_data"], type_map
)
config["model"], _ = BaseModel.update_sel(
train_data, type_map, config["model"]
)
else:
for model_key in config["model"]["model_dict"]:
type_map = config["model"]["model_dict"][model_key]["type_map"]
train_data = get_data(
train_data = _get_neighbor_stat_data(
config["training"]["data_dict"][model_key]["training_data"],
0,
type_map,
None,
)
config["model"]["model_dict"][model_key], _ = BaseModel.update_sel(
train_data,
Expand Down
Loading
Loading