Skip to content

Commit b615276

Browse files
committed
feat(pt): full validation support for LMDB format
1 parent d14233e commit b615276

4 files changed

Lines changed: 114 additions & 27 deletions

File tree

deepmd/dpmodel/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
DistributedSameNlocBatchSampler,
1111
LmdbDataReader,
1212
LmdbTestData,
13+
LmdbTestDataNlocView,
1314
SameNlocBatchSampler,
1415
is_lmdb,
1516
make_neighbor_stat_data,
@@ -58,6 +59,7 @@
5859
"FittingNet",
5960
"LmdbDataReader",
6061
"LmdbTestData",
62+
"LmdbTestDataNlocView",
6163
"NativeLayer",
6264
"NativeNet",
6365
"NetworkCollection",

deepmd/dpmodel/utils/lmdb_data.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,6 +1286,25 @@ def add(
12861286
"dtype": dtype,
12871287
}
12881288

1289+
def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None:
    """Register the expected data keys described by ``DataRequirementItem``s.

    Mirrors :meth:`LmdbDataReader.add_data_requirement`, so one
    requirement list can be handed to both the training reader and the
    full-validation test data without translation.
    """
    # Every per-key option is forwarded verbatim to ``add`` so the
    # snapshot registers exactly the same field spec as the reader.
    option_names = ("ndof", "atomic", "must", "high_prec", "repeat", "default", "dtype")
    for entry in data_requirement:
        self.add(entry["key"], **{opt: entry[opt] for opt in option_names})
1307+
12891308
def _resolve_dtype(self, key: str) -> np.dtype:
12901309
"""Resolve target dtype for a key using registered requirements."""
12911310
if key in self._requirements:
@@ -1444,6 +1463,28 @@ def _stack_frames(
14441463
return result
14451464

14461465

1466+
class LmdbTestDataNlocView:
    """Fixed-``nloc`` view over a :class:`LmdbTestData` instance.

    ``LmdbTestData`` groups its frames by atom count. This view pins one
    ``nloc`` group: ``get_test()`` returns only the frames with that atom
    count, while every other attribute lookup (``pbc``, ``mixed_type``,
    ...) falls through to the wrapped object. Consumers that expect a
    ``DeepmdData``-style system — a single fixed ``natoms`` per
    ``get_test()`` — can therefore operate on mixed-nloc LMDB datasets
    without changes.
    """

    def __init__(self, lmdb_test_data: "LmdbTestData", nloc: int) -> None:
        self._inner = lmdb_test_data
        self._nloc = nloc

    def __getattr__(self, name: str) -> Any:
        # Only reached for names not found on the view itself; forward
        # everything else to the wrapped LmdbTestData.
        return getattr(self._inner, name)

    def get_test(self) -> dict[str, Any]:
        """Return test frames of the pinned ``nloc`` group only."""
        return self._inner.get_test(nloc=self._nloc)
1486+
1487+
14471488
def merge_lmdb(
14481489
src_paths: list[str],
14491490
dst_path: str,

deepmd/entrypoints/test.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from deepmd.dpmodel.utils.lmdb_data import (
1919
LmdbTestData,
20+
LmdbTestDataNlocView,
2021
is_lmdb,
2122
)
2223
from deepmd.infer.deep_dipole import (
@@ -77,24 +78,6 @@
7778
log = logging.getLogger(__name__)
7879

7980

80-
class _LmdbTestDataNlocView:
81-
"""Thin wrapper that makes LmdbTestData.get_test() return a specific nloc group.
82-
83-
Delegates all attributes to the underlying LmdbTestData, but get_test()
84-
returns only frames with the specified nloc.
85-
"""
86-
87-
def __init__(self, lmdb_test_data: LmdbTestData, nloc: int) -> None:
88-
self._inner = lmdb_test_data
89-
self._nloc = nloc
90-
91-
def __getattr__(self, name: str) -> Any:
92-
return getattr(self._inner, name)
93-
94-
def get_test(self) -> dict:
95-
return self._inner.get_test(nloc=self._nloc)
96-
97-
9881
def test(
9982
*,
10083
model: str,
@@ -221,7 +204,7 @@ def test(
221204
for nloc_val in nloc_keys:
222205
label = f"{system} [nloc={nloc_val}]" if len(nloc_keys) > 1 else system
223206
# Create a thin wrapper that returns only this nloc group
224-
data_items.append((_LmdbTestDataNlocView(lmdb_data, nloc_val), label))
207+
data_items.append((LmdbTestDataNlocView(lmdb_data, nloc_val), label))
225208
else:
226209
data = DeepmdData(
227210
system,

deepmd/pt/train/validation.py

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
import torch.distributed as dist
2525

2626
from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT
27+
from deepmd.dpmodel.utils.lmdb_data import (
28+
LmdbTestData,
29+
LmdbTestDataNlocView,
30+
)
2731
from deepmd.pt.utils.auto_batch_size import (
2832
AutoBatchSize,
2933
)
@@ -35,6 +39,9 @@
3539
GLOBAL_PT_FLOAT_PRECISION,
3640
RESERVED_PRECISION_DICT,
3741
)
42+
from deepmd.pt.utils.lmdb_dataset import (
43+
LmdbDataset,
44+
)
3845
from deepmd.pt.utils.utils import (
3946
to_torch_tensor,
4047
)
@@ -55,6 +62,10 @@
5562
log = logging.getLogger(__name__)
5663

5764
if TYPE_CHECKING:
65+
from collections.abc import (
66+
Iterator,
67+
)
68+
5869
from deepmd.utils.data import (
5970
DeepmdData,
6071
)
@@ -229,6 +240,12 @@ def __init__(
229240
if self.rank == 0:
230241
self._initialize_best_checkpoints(restart_training=restart_training)
231242

243+
# Lazily-populated full test snapshot for LMDB validation. Mixed-nloc
244+
# LMDB datasets cannot be stacked as a single (nframes, natoms*3)
245+
# tensor, so we materialize frames grouped by nloc the first time
246+
# full validation runs and reuse the snapshot on subsequent calls.
247+
self._lmdb_test_data: LmdbTestData | None = None
248+
232249
def should_run(self, display_step: int) -> bool:
233250
"""Check whether the current step should trigger full validation."""
234251
if not self.enabled or self.start_step is None:
@@ -348,14 +365,10 @@ def evaluate_all_systems(self) -> dict[str, float]:
348365
if torch.cuda.is_available():
349366
torch.cuda.empty_cache()
350367

351-
system_metrics = []
352-
for dataset in self.validation_data.systems:
353-
if not isinstance(dataset, DeepmdDataSetForLoader):
354-
raise TypeError(
355-
"Full validation expects each dataset in validation_data.systems "
356-
f"to be DeepmdDataSetForLoader, got {type(dataset)!r}."
357-
)
358-
system_metrics.append(self._evaluate_system(dataset.data_system))
368+
system_metrics = [
369+
self._evaluate_system(data_system)
370+
for data_system in self._iter_validation_data_systems()
371+
]
359372

360373
aggregated = weighted_average([metric for metric in system_metrics if metric])
361374
return {
@@ -364,6 +377,54 @@ def evaluate_all_systems(self) -> dict[str, float]:
364377
if metric_key in aggregated
365378
}
366379

380+
def _iter_validation_data_systems(self) -> Iterator[Any]:
    """Yield the ``DeepmdData``-like systems evaluated in this run.

    Two shapes of validation data are supported:

    * ``DpLoaderSet``-style: each entry of ``validation_data.systems``
      must be a :class:`DeepmdDataSetForLoader`; its underlying
      ``DeepmdData`` instance is yielded directly.
    * :class:`LmdbDataset`: a cached :class:`LmdbTestData` snapshot is
      materialized lazily and one :class:`LmdbTestDataNlocView` per
      ``nloc`` group is yielded, so mixed-nloc frames can be stacked
      and evaluated group by group.
    """
    validation_data = self.validation_data
    if isinstance(validation_data, LmdbDataset):
        snapshot = self._get_lmdb_test_data_snapshot(validation_data)
        # Deterministic evaluation order: nloc groups from smallest to largest.
        for nloc in sorted(snapshot.nloc_groups):
            yield LmdbTestDataNlocView(snapshot, nloc)
        return

    for dataset in validation_data.systems:
        if not isinstance(dataset, DeepmdDataSetForLoader):
            raise TypeError(
                "Full validation expects each dataset in validation_data.systems "
                f"to be DeepmdDataSetForLoader, got {type(dataset)!r}."
            )
        yield dataset.data_system
405+
406+
def _get_lmdb_test_data_snapshot(self, lmdb_dataset: LmdbDataset) -> LmdbTestData:
    """Build (once) and return the cached LMDB test snapshot.

    The ``type_map`` and previously registered ``DataRequirementItem``
    entries are taken from ``LmdbDataset._reader``, so the
    full-validation snapshot sees exactly the same fields and dtypes as
    training batches.
    """
    if self._lmdb_test_data is None:
        # NOTE(review): reaches into private reader state (_reader,
        # _type_map, _data_requirements) — coupled to LmdbDataset internals.
        reader = lmdb_dataset._reader
        self._lmdb_test_data = LmdbTestData(
            lmdb_dataset.lmdb_path,
            type_map=list(reader._type_map),
            shuffle_test=False,
        )
        requirements = list(reader._data_requirements.values())
        if requirements:
            self._lmdb_test_data.add_data_requirement(requirements)
    return self._lmdb_test_data
427+
367428
def _evaluate_system(
368429
self, data_system: DeepmdData
369430
) -> dict[str, tuple[float, float]]:

0 commit comments

Comments
 (0)