Skip to content

Commit a3f548d

Browse files
authored
feat(pt): full validation support lmdb format (#5419)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Full validation now supports LMDB-backed validation datasets with snapshot caching and group-by-atom-count evaluation.
  * Datasets can register and expose data-requirement metadata for downstream tooling.
  * Added a view adapter to evaluate mixed atom-count validation frames group-by-group.
* **Refactor**
  * Consolidated per-atom-count test-data wrapper into a shared utility.
* **Tests**
  * Added unit and integration tests for mixed-atom-count views, data-requirement registration, and validation iteration.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 1f4db6f commit a3f548d

7 files changed

Lines changed: 346 additions & 27 deletions

File tree

deepmd/dpmodel/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
DistributedSameNlocBatchSampler,
1111
LmdbDataReader,
1212
LmdbTestData,
13+
LmdbTestDataNlocView,
1314
SameNlocBatchSampler,
1415
is_lmdb,
1516
make_neighbor_stat_data,
@@ -58,6 +59,7 @@
5859
"FittingNet",
5960
"LmdbDataReader",
6061
"LmdbTestData",
62+
"LmdbTestDataNlocView",
6163
"NativeLayer",
6264
"NativeNet",
6365
"NetworkCollection",

deepmd/dpmodel/utils/lmdb_data.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,11 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N
535535
for item in data_requirement:
536536
self._data_requirements[item["key"]] = item
537537

538+
@property
539+
def data_requirements(self) -> list[DataRequirementItem]:
540+
"""Registered data requirements in insertion order."""
541+
return list(self._data_requirements.values())
542+
538543
def print_summary(self, name: str, prob: Any) -> None:
539544
"""Print basic dataset info."""
540545
n_groups = len(self._nloc_groups)
@@ -1286,6 +1291,25 @@ def add(
12861291
"dtype": dtype,
12871292
}
12881293

1294+
def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None:
1295+
"""Register expected keys from ``DataRequirementItem`` objects.
1296+
1297+
Mirrors :meth:`LmdbDataReader.add_data_requirement` so the same
1298+
requirement list can be forwarded to both the training reader and
1299+
the full-validation test data.
1300+
"""
1301+
for item in data_requirement:
1302+
self.add(
1303+
item["key"],
1304+
ndof=item["ndof"],
1305+
atomic=item["atomic"],
1306+
must=item["must"],
1307+
high_prec=item["high_prec"],
1308+
repeat=item["repeat"],
1309+
default=item["default"],
1310+
dtype=item["dtype"],
1311+
)
1312+
12891313
def _resolve_dtype(self, key: str) -> np.dtype:
12901314
"""Resolve target dtype for a key using registered requirements."""
12911315
if key in self._requirements:
@@ -1444,6 +1468,28 @@ def _stack_frames(
14441468
return result
14451469

14461470

1471+
class LmdbTestDataNlocView:
1472+
"""Thin wrapper exposing a fixed-``nloc`` view of :class:`LmdbTestData`.
1473+
1474+
The underlying :class:`LmdbTestData` groups frames by atom count. This
1475+
view fixes one ``nloc`` group, so ``get_test()`` returns only the frames
1476+
with that atom count and all other attributes (``pbc``, ``mixed_type``,
1477+
…) are forwarded to the underlying object. It lets downstream consumers
1478+
that expect a ``DeepmdData``-style system (one fixed natoms per
1479+
``get_test()``) work on mixed-nloc LMDB datasets without changes.
1480+
"""
1481+
1482+
def __init__(self, lmdb_test_data: "LmdbTestData", nloc: int) -> None:
1483+
self._inner = lmdb_test_data
1484+
self._nloc = nloc
1485+
1486+
def __getattr__(self, name: str) -> Any:
1487+
return getattr(self._inner, name)
1488+
1489+
def get_test(self) -> dict[str, Any]:
1490+
return self._inner.get_test(nloc=self._nloc)
1491+
1492+
14471493
def merge_lmdb(
14481494
src_paths: list[str],
14491495
dst_path: str,

deepmd/entrypoints/test.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from deepmd.dpmodel.utils.lmdb_data import (
1919
LmdbTestData,
20+
LmdbTestDataNlocView,
2021
is_lmdb,
2122
)
2223
from deepmd.infer.deep_dipole import (
@@ -77,24 +78,6 @@
7778
log = logging.getLogger(__name__)
7879

7980

80-
class _LmdbTestDataNlocView:
81-
"""Thin wrapper that makes LmdbTestData.get_test() return a specific nloc group.
82-
83-
Delegates all attributes to the underlying LmdbTestData, but get_test()
84-
returns only frames with the specified nloc.
85-
"""
86-
87-
def __init__(self, lmdb_test_data: LmdbTestData, nloc: int) -> None:
88-
self._inner = lmdb_test_data
89-
self._nloc = nloc
90-
91-
def __getattr__(self, name: str) -> Any:
92-
return getattr(self._inner, name)
93-
94-
def get_test(self) -> dict:
95-
return self._inner.get_test(nloc=self._nloc)
96-
97-
9881
def test(
9982
*,
10083
model: str,
@@ -221,7 +204,7 @@ def test(
221204
for nloc_val in nloc_keys:
222205
label = f"{system} [nloc={nloc_val}]" if len(nloc_keys) > 1 else system
223206
# Create a thin wrapper that returns only this nloc group
224-
data_items.append((_LmdbTestDataNlocView(lmdb_data, nloc_val), label))
207+
data_items.append((LmdbTestDataNlocView(lmdb_data, nloc_val), label))
225208
else:
226209
data = DeepmdData(
227210
system,

deepmd/pt/train/validation.py

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
import torch.distributed as dist
2525

2626
from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT
27+
from deepmd.dpmodel.utils.lmdb_data import (
28+
LmdbTestData,
29+
LmdbTestDataNlocView,
30+
)
2731
from deepmd.pt.utils.auto_batch_size import (
2832
AutoBatchSize,
2933
)
@@ -35,6 +39,9 @@
3539
GLOBAL_PT_FLOAT_PRECISION,
3640
RESERVED_PRECISION_DICT,
3741
)
42+
from deepmd.pt.utils.lmdb_dataset import (
43+
LmdbDataset,
44+
)
3845
from deepmd.pt.utils.utils import (
3946
to_torch_tensor,
4047
)
@@ -55,6 +62,10 @@
5562
log = logging.getLogger(__name__)
5663

5764
if TYPE_CHECKING:
65+
from collections.abc import (
66+
Iterator,
67+
)
68+
5869
from deepmd.utils.data import (
5970
DeepmdData,
6071
)
@@ -229,6 +240,12 @@ def __init__(
229240
if self.rank == 0:
230241
self._initialize_best_checkpoints(restart_training=restart_training)
231242

243+
# Lazily-populated full test snapshot for LMDB validation. Mixed-nloc
244+
# LMDB datasets cannot be stacked as a single (nframes, natoms*3)
245+
# tensor, so we materialize frames grouped by nloc the first time
246+
# full validation runs and reuse the snapshot on subsequent calls.
247+
self._lmdb_test_data: LmdbTestData | None = None
248+
232249
def should_run(self, display_step: int) -> bool:
233250
"""Check whether the current step should trigger full validation."""
234251
if not self.enabled or self.start_step is None:
@@ -348,14 +365,10 @@ def evaluate_all_systems(self) -> dict[str, float]:
348365
if torch.cuda.is_available():
349366
torch.cuda.empty_cache()
350367

351-
system_metrics = []
352-
for dataset in self.validation_data.systems:
353-
if not isinstance(dataset, DeepmdDataSetForLoader):
354-
raise TypeError(
355-
"Full validation expects each dataset in validation_data.systems "
356-
f"to be DeepmdDataSetForLoader, got {type(dataset)!r}."
357-
)
358-
system_metrics.append(self._evaluate_system(dataset.data_system))
368+
system_metrics = [
369+
self._evaluate_system(data_system)
370+
for data_system in self._iter_validation_data_systems()
371+
]
359372

360373
aggregated = weighted_average([metric for metric in system_metrics if metric])
361374
return {
@@ -364,6 +377,53 @@ def evaluate_all_systems(self) -> dict[str, float]:
364377
if metric_key in aggregated
365378
}
366379

380+
def _iter_validation_data_systems(self) -> Iterator[Any]:
381+
"""Yield ``DeepmdData``-like systems to evaluate in this run.
382+
383+
- For ``DpLoaderSet``-style validation data, each entry in
384+
``validation_data.systems`` is a :class:`DeepmdDataSetForLoader`,
385+
and we forward its underlying ``DeepmdData`` instance.
386+
- For ``LmdbDataset`` validation data, we lazily materialize a
387+
:class:`LmdbTestData` snapshot (cached across calls) and yield one
388+
:class:`LmdbTestDataNlocView` per ``nloc`` group, so mixed-nloc
389+
frames can be stacked and evaluated group by group.
390+
"""
391+
validation_data = self.validation_data
392+
if isinstance(validation_data, LmdbDataset):
393+
lmdb_test_data = self._get_lmdb_test_data_snapshot(validation_data)
394+
for nloc in sorted(lmdb_test_data.nloc_groups.keys()):
395+
yield LmdbTestDataNlocView(lmdb_test_data, nloc)
396+
return
397+
398+
for dataset in validation_data.systems:
399+
if not isinstance(dataset, DeepmdDataSetForLoader):
400+
raise TypeError(
401+
"Full validation expects each dataset in validation_data.systems "
402+
f"to be DeepmdDataSetForLoader, got {type(dataset)!r}."
403+
)
404+
yield dataset.data_system
405+
406+
def _get_lmdb_test_data_snapshot(self, lmdb_dataset: LmdbDataset) -> LmdbTestData:
407+
"""Build (once) and return the cached LMDB test snapshot.
408+
409+
Reuses the ``type_map`` and previously-registered
410+
``DataRequirementItem`` entries from the validation dataset so that
411+
the full-validation snapshot sees exactly the same fields and
412+
dtypes as training batches.
413+
"""
414+
if self._lmdb_test_data is not None:
415+
return self._lmdb_test_data
416+
417+
self._lmdb_test_data = LmdbTestData(
418+
lmdb_dataset.lmdb_path,
419+
type_map=list(lmdb_dataset.type_map),
420+
shuffle_test=False,
421+
)
422+
data_requirements = lmdb_dataset.data_requirements
423+
if data_requirements:
424+
self._lmdb_test_data.add_data_requirement(data_requirements)
425+
return self._lmdb_test_data
426+
367427
def _evaluate_system(
368428
self, data_system: DeepmdData
369429
) -> dict[str, tuple[float, float]]:

deepmd/pt/utils/lmdb_dataset.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,14 @@ def mixed_type(self) -> bool:
212212
def batch_size(self) -> int:
213213
return self._reader.batch_size
214214

215+
@property
216+
def type_map(self) -> list[str]:
217+
return self._reader.type_map
218+
219+
@property
220+
def data_requirements(self) -> list[DataRequirementItem]:
221+
return self._reader.data_requirements
222+
215223
def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None:
216224
self._reader.add_data_requirement(data_requirement)
217225

source/tests/common/dpmodel/test_lmdb_data.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from deepmd.dpmodel.utils.lmdb_data import (
1515
LmdbDataReader,
1616
LmdbTestData,
17+
LmdbTestDataNlocView,
1718
SameNlocBatchSampler,
1819
_expand_indices_by_blocks,
1920
compute_block_targets,
@@ -415,6 +416,27 @@ def test_test_data_get_test_specific_nloc(self):
415416
r12 = td.get_test(nloc=12)
416417
self.assertEqual(r12["coord"].shape, (2, 12 * 3))
417418

419+
def test_test_data_nloc_view(self):
420+
"""LmdbTestDataNlocView delegates attributes and fixes nloc."""
421+
td = LmdbTestData(self._lmdb_path, type_map=self._type_map, shuffle_test=False)
422+
td.add("energy", 1, atomic=False, must=False, high_prec=True)
423+
view = LmdbTestDataNlocView(td, 9)
424+
425+
self.assertEqual(view.pbc, td.pbc)
426+
self.assertIs(view.nloc_groups, td.nloc_groups)
427+
428+
expected = td.get_test(nloc=9)
429+
actual = view.get_test()
430+
self.assertEqual(actual["coord"].shape, (4, 9 * 3))
431+
self.assertEqual(actual["type"].shape, (4, 9))
432+
self.assertEqual(actual.keys(), expected.keys())
433+
for key, expected_value in expected.items():
434+
actual_value = actual[key]
435+
if isinstance(expected_value, np.ndarray):
436+
np.testing.assert_array_equal(actual_value, expected_value)
437+
else:
438+
self.assertEqual(actual_value, expected_value)
439+
418440
def test_test_data_get_test_default_mixed(self):
419441
td = LmdbTestData(self._lmdb_path, type_map=self._type_map, shuffle_test=False)
420442
td.add("energy", 1, atomic=False, must=False, high_prec=True)
@@ -851,6 +873,66 @@ def test_testdata_repeat_applied(self):
851873
(self._nframes, self._natoms * 3),
852874
)
853875

876+
def test_testdata_add_data_requirement_matches_manual_add(self):
877+
"""DataRequirementItem forwarding matches manual requirement registration."""
878+
from deepmd.utils.data import (
879+
DataRequirementItem,
880+
)
881+
882+
requirements = [
883+
DataRequirementItem(
884+
"drdq",
885+
ndof=6,
886+
atomic=True,
887+
must=False,
888+
high_prec=False,
889+
repeat=2,
890+
default=1.25,
891+
dtype=np.float64,
892+
),
893+
DataRequirementItem(
894+
"aux",
895+
ndof=2,
896+
atomic=False,
897+
must=False,
898+
high_prec=False,
899+
repeat=3,
900+
default=-2.0,
901+
dtype=np.float32,
902+
),
903+
]
904+
manual = LmdbTestData(
905+
self._lmdb_path,
906+
type_map=self._type_map,
907+
shuffle_test=False,
908+
)
909+
forwarded = LmdbTestData(
910+
self._lmdb_path,
911+
type_map=self._type_map,
912+
shuffle_test=False,
913+
)
914+
for item in requirements:
915+
manual.add(
916+
item["key"],
917+
ndof=item["ndof"],
918+
atomic=item["atomic"],
919+
must=item["must"],
920+
high_prec=item["high_prec"],
921+
repeat=item["repeat"],
922+
default=item["default"],
923+
dtype=item["dtype"],
924+
)
925+
forwarded.add_data_requirement(requirements)
926+
927+
manual_result = manual.get_test()
928+
forwarded_result = forwarded.get_test()
929+
for item in requirements:
930+
key = item["key"]
931+
self.assertEqual(forwarded_result[f"find_{key}"], 0.0)
932+
self.assertEqual(forwarded_result[key].shape, manual_result[key].shape)
933+
self.assertEqual(forwarded_result[key].dtype, manual_result[key].dtype)
934+
np.testing.assert_array_equal(forwarded_result[key], manual_result[key])
935+
854936
def test_testdata_missing_key_not_found(self):
855937
"""Keys absent from LMDB frames get find_*=0.0 in get_test()."""
856938
tmpdir = tempfile.TemporaryDirectory()

0 commit comments

Comments
 (0)