Skip to content

Commit a27d0ae

Browse files
committed
fixup
1 parent b615276 commit a27d0ae

5 files changed

Lines changed: 236 additions & 4 deletions

File tree

deepmd/dpmodel/utils/lmdb_data.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,11 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N
535535
for item in data_requirement:
536536
self._data_requirements[item["key"]] = item
537537

538+
@property
def data_requirements(self) -> list[DataRequirementItem]:
    """Return the registered data requirements as a new list.

    The items are the values of ``self._data_requirements`` in the
    mapping's iteration order (insertion order). A fresh list is built on
    every access, so mutating the result does not touch the registry.
    """
    registered = self._data_requirements
    return [*registered.values()]
542+
538543
def print_summary(self, name: str, prob: Any) -> None:
539544
"""Print basic dataset info."""
540545
n_groups = len(self._nloc_groups)

deepmd/pt/train/validation.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -407,20 +407,19 @@ def _get_lmdb_test_data_snapshot(self, lmdb_dataset: LmdbDataset) -> LmdbTestDat
407407
"""Build (once) and return the cached LMDB test snapshot.
408408
409409
Reuses the ``type_map`` and previously-registered
410-
``DataRequirementItem`` entries on ``LmdbDataset._reader`` so that
410+
``DataRequirementItem`` entries from the validation dataset so that
411411
the full-validation snapshot sees exactly the same fields and
412412
dtypes as training batches.
413413
"""
414414
if self._lmdb_test_data is not None:
415415
return self._lmdb_test_data
416416

417-
reader = lmdb_dataset._reader
418417
self._lmdb_test_data = LmdbTestData(
419418
lmdb_dataset.lmdb_path,
420-
type_map=list(reader._type_map),
419+
type_map=list(lmdb_dataset.type_map),
421420
shuffle_test=False,
422421
)
423-
data_requirements = list(reader._data_requirements.values())
422+
data_requirements = lmdb_dataset.data_requirements
424423
if data_requirements:
425424
self._lmdb_test_data.add_data_requirement(data_requirements)
426425
return self._lmdb_test_data

deepmd/pt/utils/lmdb_dataset.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,14 @@ def mixed_type(self) -> bool:
212212
def batch_size(self) -> int:
213213
return self._reader.batch_size
214214

215+
@property
def type_map(self) -> list[str]:
    """Type map of the wrapped reader (``self._reader.type_map``)."""
    reader = self._reader
    return reader.type_map
218+
219+
@property
def data_requirements(self) -> list[DataRequirementItem]:
    """Data requirements registered on the wrapped reader.

    Simply forwards to ``self._reader.data_requirements``.
    """
    reader = self._reader
    return reader.data_requirements
222+
215223
def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None:
216224
self._reader.add_data_requirement(data_requirement)
217225

source/tests/common/dpmodel/test_lmdb_data.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from deepmd.dpmodel.utils.lmdb_data import (
1515
LmdbDataReader,
1616
LmdbTestData,
17+
LmdbTestDataNlocView,
1718
SameNlocBatchSampler,
1819
_expand_indices_by_blocks,
1920
compute_block_targets,
@@ -415,6 +416,27 @@ def test_test_data_get_test_specific_nloc(self):
415416
r12 = td.get_test(nloc=12)
416417
self.assertEqual(r12["coord"].shape, (2, 12 * 3))
417418

419+
def test_test_data_nloc_view(self):
    """An ``LmdbTestDataNlocView`` mirrors its parent and pins ``nloc``."""
    nloc = 9
    parent = LmdbTestData(
        self._lmdb_path, type_map=self._type_map, shuffle_test=False
    )
    parent.add("energy", 1, atomic=False, must=False, high_prec=True)
    view = LmdbTestDataNlocView(parent, nloc)

    # The view is expected to expose the parent's attributes unchanged.
    self.assertEqual(view.pbc, parent.pbc)
    self.assertIs(view.nloc_groups, parent.nloc_groups)

    # Calling get_test() on the view must match get_test(nloc=...) on the parent.
    reference = parent.get_test(nloc=nloc)
    from_view = view.get_test()
    self.assertEqual(from_view["coord"].shape, (4, nloc * 3))
    self.assertEqual(from_view["type"].shape, (4, nloc))
    self.assertEqual(from_view.keys(), reference.keys())
    for name in reference:
        want = reference[name]
        got = from_view[name]
        if not isinstance(want, np.ndarray):
            self.assertEqual(got, want)
            continue
        np.testing.assert_array_equal(got, want)
439+
418440
def test_test_data_get_test_default_mixed(self):
419441
td = LmdbTestData(self._lmdb_path, type_map=self._type_map, shuffle_test=False)
420442
td.add("energy", 1, atomic=False, must=False, high_prec=True)
@@ -851,6 +873,66 @@ def test_testdata_repeat_applied(self):
851873
(self._nframes, self._natoms * 3),
852874
)
853875

876+
def test_testdata_add_data_requirement_matches_manual_add(self):
    """``add_data_requirement`` must behave like calling ``add`` per item."""
    from deepmd.utils.data import (
        DataRequirementItem,
    )

    drdq = DataRequirementItem(
        "drdq",
        ndof=6,
        atomic=True,
        must=False,
        high_prec=False,
        repeat=2,
        default=1.25,
        dtype=np.float64,
    )
    aux = DataRequirementItem(
        "aux",
        ndof=2,
        atomic=False,
        must=False,
        high_prec=False,
        repeat=3,
        default=-2.0,
        dtype=np.float32,
    )
    requirements = [drdq, aux]

    def new_test_data():
        return LmdbTestData(
            self._lmdb_path,
            type_map=self._type_map,
            shuffle_test=False,
        )

    manual = new_test_data()
    forwarded = new_test_data()

    # Register each requirement by hand on one instance ...
    option_names = (
        "ndof",
        "atomic",
        "must",
        "high_prec",
        "repeat",
        "default",
        "dtype",
    )
    for item in requirements:
        options = {name: item[name] for name in option_names}
        manual.add(item["key"], **options)
    # ... and in one call through the forwarding API on the other.
    forwarded.add_data_requirement(requirements)

    by_hand = manual.get_test()
    by_forward = forwarded.get_test()
    for key in [item["key"] for item in requirements]:
        self.assertEqual(by_forward[f"find_{key}"], 0.0)
        got = by_forward[key]
        want = by_hand[key]
        self.assertEqual(got.shape, want.shape)
        self.assertEqual(got.dtype, want.dtype)
        np.testing.assert_array_equal(got, want)
935+
854936
def test_testdata_missing_key_not_found(self):
855937
"""Keys absent from LMDB frames get find_*=0.0 in get_test()."""
856938
tmpdir = tempfile.TemporaryDirectory()

source/tests/pt/test_validation.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88
from pathlib import (
99
Path,
1010
)
11+
from unittest.mock import (
12+
patch,
13+
)
1114

15+
import lmdb
16+
import msgpack
17+
import numpy as np
1218
import torch
1319
from dargs.dargs import (
1420
ArgumentValueError,
@@ -20,6 +26,9 @@
2026
FullValidator,
2127
resolve_full_validation_start_step,
2228
)
29+
from deepmd.pt.utils.lmdb_dataset import (
30+
LmdbDataset,
31+
)
2332
from deepmd.utils.argcheck import (
2433
normalize,
2534
)
@@ -45,6 +54,85 @@ def get_dim_aparam(self) -> int:
4554
return 0
4655

4756

57+
def _make_lmdb_frame(natoms: int, seed: int) -> dict:
58+
"""Create one synthetic LMDB frame for full-validation tests."""
59+
rng = np.random.RandomState(seed)
60+
n_type0 = max(1, natoms // 3)
61+
n_type1 = natoms - n_type0
62+
atype = np.array([0] * n_type0 + [1] * n_type1, dtype=np.int64)
63+
return {
64+
"atom_names": ["O", "H"],
65+
"atom_numbs": [
66+
{
67+
"type": "<i8",
68+
"shape": (1,),
69+
"data": np.array([n_type0], dtype=np.int64).tobytes(),
70+
},
71+
{
72+
"type": "<i8",
73+
"shape": (1,),
74+
"data": np.array([n_type1], dtype=np.int64).tobytes(),
75+
},
76+
],
77+
"atom_types": {
78+
"type": "<i8",
79+
"shape": (natoms,),
80+
"data": atype.tobytes(),
81+
},
82+
"coords": {
83+
"type": "<f8",
84+
"shape": (natoms, 3),
85+
"data": rng.randn(natoms, 3).astype(np.float64).tobytes(),
86+
},
87+
"cells": {
88+
"type": "<f8",
89+
"shape": (3, 3),
90+
"data": (np.eye(3) * 10.0).astype(np.float64).tobytes(),
91+
},
92+
"energies": {
93+
"type": "<f8",
94+
"shape": (1,),
95+
"data": rng.randn(1).astype(np.float64).tobytes(),
96+
},
97+
"forces": {
98+
"type": "<f8",
99+
"shape": (natoms, 3),
100+
"data": rng.randn(natoms, 3).astype(np.float64).tobytes(),
101+
},
102+
}
103+
104+
105+
def _create_mixed_nloc_lmdb(path: str) -> str:
    """Create a mixed-nloc LMDB dataset with 6, 9, and 12-atom frames.

    Writes a ``__metadata__`` record followed by ten msgpack-encoded
    frames (4 x 6 atoms, 4 x 9 atoms, 2 x 12 atoms), each keyed by its
    zero-padded frame index.

    Parameters
    ----------
    path : str
        Location of the LMDB environment to create.

    Returns
    -------
    str
        ``path``, unchanged, for convenient chaining.
    """
    frame_specs = [(6, 4), (9, 4), (12, 2)]
    # One source of truth for the key format: it is both advertised in the
    # metadata and used to build the frame keys.
    frame_idx_fmt = "012d"
    natoms_per_frame = [
        natoms for natoms, count in frame_specs for _ in range(count)
    ]
    metadata = {
        "nframes": len(natoms_per_frame),
        "frame_idx_fmt": frame_idx_fmt,
        "type_map": ["O", "H"],
        "system_info": {
            "natoms": [2, 4],
            "formula": "mixed",
        },
    }

    env = lmdb.open(path, map_size=10 * 1024 * 1024)
    try:
        with env.begin(write=True) as txn:
            txn.put(b"__metadata__", msgpack.packb(metadata, use_bin_type=True))
            for frame_idx, natoms in enumerate(natoms_per_frame):
                txn.put(
                    format(frame_idx, frame_idx_fmt).encode(),
                    msgpack.packb(
                        _make_lmdb_frame(natoms=natoms, seed=frame_idx),
                        use_bin_type=True,
                    ),
                )
    finally:
        # Close the environment even when a write fails, so the handle is
        # not left open on error.
        env.close()
    return path
134+
135+
48136
def _make_single_task_config() -> dict:
49137
return {
50138
"model": deepcopy(model_se_e2_a),
@@ -192,6 +280,56 @@ def test_full_validator_restores_top_k_checkpoints(self) -> None:
192280
["best.ckpt-10.t-2.pt", "best.ckpt-20.t-1.pt"],
193281
)
194282

283+
def test_full_validator_lmdb_full_validation_iterates_nloc_groups(self) -> None:
    """Check that LMDB full validation evaluates every nloc group separately."""
    validating_params = {
        "full_validation": True,
        "validation_freq": 1,
        "save_best": False,
        "max_best_ckpt": 1,
        "validation_metric": "E:MAE",
        "full_val_file": "val.log",
        "full_val_start": 0.0,
    }
    seen_natoms = []

    def fake_evaluate_system(data_system):
        # Record the atom count of each evaluated system and report it as
        # the per-atom error, paired with that system's frame count.
        batch = data_system.get_test()
        natoms = int(batch["type"].shape[1])
        nframes = int(batch["coord"].shape[0])
        seen_natoms.append(natoms)
        per_atom = (float(natoms), nframes)
        return {
            "mae_e_per_atom": per_atom,
            "rmse_e_per_atom": per_atom,
        }

    with tempfile.TemporaryDirectory() as tmpdir:
        dataset = LmdbDataset(
            _create_mixed_nloc_lmdb(f"{tmpdir}/mixed.lmdb"),
            type_map=["O", "H"],
            batch_size=2,
        )
        validator = FullValidator(
            validating_params=validating_params,
            validation_data=dataset,
            model=_DummyModel(),
            train_infos={},
            num_steps=10,
            rank=0,
            zero_stage=0,
            restart_training=False,
        )

        with patch.object(
            validator,
            "_evaluate_system",
            side_effect=fake_evaluate_system,
        ) as evaluate_system:
            metrics = validator.evaluate_all_systems()

        self.assertEqual(seen_natoms, [6, 9, 12])
        self.assertEqual(evaluate_system.call_count, 3)
        # The expected values correspond to frame-weighted aggregation over
        # the 4/4/2 frames written by _create_mixed_nloc_lmdb:
        # (6*4 + 9*4 + 12*2) / 10 = 8.4 and (36*4 + 81*4 + 144*2) / 10 = 75.6.
        self.assertAlmostEqual(metrics["mae_e_per_atom"], 8.4)
        self.assertAlmostEqual(metrics["rmse_e_per_atom"], np.sqrt(75.6))
332+
195333

196334
class TestValidationArgcheck(unittest.TestCase):
197335
def test_normalize_rejects_missing_validation_data(self) -> None:

0 commit comments

Comments
 (0)