88from pathlib import (
99 Path ,
1010)
11+ from unittest .mock import (
12+ patch ,
13+ )
1114
15+ import lmdb
16+ import msgpack
17+ import numpy as np
1218import torch
1319from dargs .dargs import (
1420 ArgumentValueError ,
2026 FullValidator ,
2127 resolve_full_validation_start_step ,
2228)
29+ from deepmd .pt .utils .lmdb_dataset import (
30+ LmdbDataset ,
31+ )
2332from deepmd .utils .argcheck import (
2433 normalize ,
2534)
@@ -45,6 +54,85 @@ def get_dim_aparam(self) -> int:
4554 return 0
4655
4756
57+ def _make_lmdb_frame (natoms : int , seed : int ) -> dict :
58+ """Create one synthetic LMDB frame for full-validation tests."""
59+ rng = np .random .RandomState (seed )
60+ n_type0 = max (1 , natoms // 3 )
61+ n_type1 = natoms - n_type0
62+ atype = np .array ([0 ] * n_type0 + [1 ] * n_type1 , dtype = np .int64 )
63+ return {
64+ "atom_names" : ["O" , "H" ],
65+ "atom_numbs" : [
66+ {
67+ "type" : "<i8" ,
68+ "shape" : (1 ,),
69+ "data" : np .array ([n_type0 ], dtype = np .int64 ).tobytes (),
70+ },
71+ {
72+ "type" : "<i8" ,
73+ "shape" : (1 ,),
74+ "data" : np .array ([n_type1 ], dtype = np .int64 ).tobytes (),
75+ },
76+ ],
77+ "atom_types" : {
78+ "type" : "<i8" ,
79+ "shape" : (natoms ,),
80+ "data" : atype .tobytes (),
81+ },
82+ "coords" : {
83+ "type" : "<f8" ,
84+ "shape" : (natoms , 3 ),
85+ "data" : rng .randn (natoms , 3 ).astype (np .float64 ).tobytes (),
86+ },
87+ "cells" : {
88+ "type" : "<f8" ,
89+ "shape" : (3 , 3 ),
90+ "data" : (np .eye (3 ) * 10.0 ).astype (np .float64 ).tobytes (),
91+ },
92+ "energies" : {
93+ "type" : "<f8" ,
94+ "shape" : (1 ,),
95+ "data" : rng .randn (1 ).astype (np .float64 ).tobytes (),
96+ },
97+ "forces" : {
98+ "type" : "<f8" ,
99+ "shape" : (natoms , 3 ),
100+ "data" : rng .randn (natoms , 3 ).astype (np .float64 ).tobytes (),
101+ },
102+ }
103+
104+
105+ def _create_mixed_nloc_lmdb (path : str ) -> str :
106+ """Create a mixed-nloc LMDB dataset with 6, 9, and 12-atom frames."""
107+ frame_specs = [(6 , 4 ), (9 , 4 ), (12 , 2 )]
108+ total_frames = sum (count for _ , count in frame_specs )
109+ env = lmdb .open (path , map_size = 10 * 1024 * 1024 )
110+ with env .begin (write = True ) as txn :
111+ metadata = {
112+ "nframes" : total_frames ,
113+ "frame_idx_fmt" : "012d" ,
114+ "type_map" : ["O" , "H" ],
115+ "system_info" : {
116+ "natoms" : [2 , 4 ],
117+ "formula" : "mixed" ,
118+ },
119+ }
120+ txn .put (b"__metadata__" , msgpack .packb (metadata , use_bin_type = True ))
121+ frame_idx = 0
122+ for natoms , count in frame_specs :
123+ for _ in range (count ):
124+ txn .put (
125+ format (frame_idx , "012d" ).encode (),
126+ msgpack .packb (
127+ _make_lmdb_frame (natoms = natoms , seed = frame_idx ),
128+ use_bin_type = True ,
129+ ),
130+ )
131+ frame_idx += 1
132+ env .close ()
133+ return path
134+
135+
48136def _make_single_task_config () -> dict :
49137 return {
50138 "model" : deepcopy (model_se_e2_a ),
@@ -192,6 +280,56 @@ def test_full_validator_restores_top_k_checkpoints(self) -> None:
192280 ["best.ckpt-10.t-2.pt" , "best.ckpt-20.t-1.pt" ],
193281 )
194282
283+ def test_full_validator_lmdb_full_validation_iterates_nloc_groups (self ) -> None :
284+ with tempfile .TemporaryDirectory () as tmpdir :
285+ lmdb_path = _create_mixed_nloc_lmdb (f"{ tmpdir } /mixed.lmdb" )
286+ validation_data = LmdbDataset (
287+ lmdb_path ,
288+ type_map = ["O" , "H" ],
289+ batch_size = 2 ,
290+ )
291+ validator = FullValidator (
292+ validating_params = {
293+ "full_validation" : True ,
294+ "validation_freq" : 1 ,
295+ "save_best" : False ,
296+ "max_best_ckpt" : 1 ,
297+ "validation_metric" : "E:MAE" ,
298+ "full_val_file" : "val.log" ,
299+ "full_val_start" : 0.0 ,
300+ },
301+ validation_data = validation_data ,
302+ model = _DummyModel (),
303+ train_infos = {},
304+ num_steps = 10 ,
305+ rank = 0 ,
306+ zero_stage = 0 ,
307+ restart_training = False ,
308+ )
309+ observed_natoms = []
310+
311+ def fake_evaluate_system (data_system ):
312+ test_data = data_system .get_test ()
313+ natoms = int (test_data ["type" ].shape [1 ])
314+ nframes = int (test_data ["coord" ].shape [0 ])
315+ observed_natoms .append (natoms )
316+ return {
317+ "mae_e_per_atom" : (float (natoms ), nframes ),
318+ "rmse_e_per_atom" : (float (natoms ), nframes ),
319+ }
320+
321+ with patch .object (
322+ validator ,
323+ "_evaluate_system" ,
324+ side_effect = fake_evaluate_system ,
325+ ) as evaluate_system :
326+ metrics = validator .evaluate_all_systems ()
327+
328+ self .assertEqual (observed_natoms , [6 , 9 , 12 ])
329+ self .assertEqual (evaluate_system .call_count , 3 )
330+ self .assertAlmostEqual (metrics ["mae_e_per_atom" ], 8.4 )
331+ self .assertAlmostEqual (metrics ["rmse_e_per_atom" ], np .sqrt (75.6 ))
332+
195333
196334class TestValidationArgcheck (unittest .TestCase ):
197335 def test_normalize_rejects_missing_validation_data (self ) -> None :
0 commit comments